# coding=utf-8 | |
# Copyright 2024 The Google Research Authors. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""CAM utils.""" | |
# pylint: disable=g-importing-member | |
import os | |
import cv2 | |
import numpy as np | |
from PIL import Image | |
from scipy.ndimage import binary_fill_holes | |
import torch | |
from torchvision.transforms import Compose | |
from torchvision.transforms import Normalize | |
from torchvision.transforms import Resize | |
from torchvision.transforms import ToTensor | |
# pylint: disable=g-import-not-at-top | |
try: | |
from torchvision.transforms import InterpolationMode | |
BICUBIC = InterpolationMode.BICUBIC | |
except ImportError: | |
BICUBIC = Image.BICUBIC | |
_CONTOUR_INDEX = 1 if cv2.__version__.split('.')[0] == '3' else 0 | |
def _convert_image_to_rgb(image): | |
return image.convert('RGB') | |
def _transform_resize(h, w):
  """Builds the CLIP-style preprocessing pipeline at a fixed (h, w) size.

  Resizes with bicubic interpolation, forces RGB, converts to a tensor and
  normalizes with the CLIP mean/std constants.
  """
  steps = [
      Resize((h, w), interpolation=BICUBIC),
      _convert_image_to_rgb,
      ToTensor(),
      Normalize(
          (0.48145466, 0.4578275, 0.40821073),
          (0.26862954, 0.26130258, 0.27577711),
      ),
  ]
  return Compose(steps)
def img_ms_and_flip(image, ori_height, ori_width, scales=1.0, patch_size=16):
  """Resizes the image at one or more scales and adds horizontal flips.

  Args:
    image: Input PIL image.
    ori_height: Original image height.
    ori_width: Original image width.
    scales: A float or a list of floats; one resized copy is produced per
      scale. Target sides are rounded up to a multiple of `patch_size`.
    patch_size: Patch size the output dimensions must be divisible by.

  Returns:
    A list of tensors: for each scale, the preprocessed image followed by
    its horizontal (last-dimension) flip.
  """
  if isinstance(scales, float):
    scales = [scales]
  all_imgs = []
  for scale in scales:
    preprocess = _transform_resize(
        int(np.ceil(scale * int(ori_height) / patch_size) * patch_size),
        int(np.ceil(scale * int(ori_width) / patch_size) * patch_size),
    )
    # Bug fix: do NOT rebind `image` here. The original wrote
    # `image = preprocess(image)`, so on the second scale the pipeline
    # received an already-converted tensor instead of the PIL image,
    # breaking multi-scale processing (Resize/convert/ToTensor expect PIL).
    image_ori = preprocess(image)
    image_flip = torch.flip(image_ori, [-1])
    all_imgs.append(image_ori)
    all_imgs.append(image_flip)
  return all_imgs
def reshape_transform(tensor, height=28, width=28):
  """Reshapes a (seq, batch, dim) activation into a (batch, dim, h, w) map.

  The first token (class token) is dropped; the remaining height*width
  patch tokens are laid out on the spatial grid with channels first,
  matching the CNN layout expected by CAM tooling.
  """
  batch_first = tensor.permute(1, 0, 2)
  patches = batch_first[:, 1:, :]
  grid = patches.reshape(
      batch_first.size(0), height, width, batch_first.size(2)
  )
  # (batch, h, w, c) -> (batch, c, h, w); equivalent to the original's
  # transpose(2, 3).transpose(1, 2) chain.
  return grid.permute(0, 3, 1, 2)
def vis_mask(image, mask, mask_color):
  """Blends `mask_color` over the foreground (mask > 0.5) region of `image`.

  The mask is resized to the image's spatial size if needed. Foreground
  pixels become 30% image + 70% mask color; the result is a PIL image.
  """
  img_h, img_w = image.shape[0], image.shape[1]
  if mask.shape[0] != img_h or mask.shape[1] != img_w:
    # cv2.resize takes (width, height).
    mask = cv2.resize(mask, (img_w, img_h))
  overlay = np.copy(image)
  foreground = mask > 0.5
  blended = overlay[foreground] * 0.3 + np.array(mask_color) * 0.7
  overlay[foreground] = blended.astype(np.uint8)
  return Image.fromarray(overlay)
def scoremap2bbox(scoremap, threshold, multi_contour_eval=False):
  """Get bounding boxes from scoremap.

  Args:
    scoremap: 2-D float array of activation scores (scaled by 255 below, so
      values are presumably in [0, 1] — TODO confirm with callers).
    threshold: Fraction of the scoremap maximum used to binarize. It is
      relaxed in 0.1 steps until at least one pixel survives (or it drops
      to <= 0).
    multi_contour_eval: If True, one box per detected contour; otherwise
      only the largest-area contour is boxed.

  Returns:
    A tuple (boxes, n) where boxes is an (N, 4) int array of
    [x0, y0, x1, y1] (x1/y1 clamped to width-1/height-1) and n is the
    number of contours used. Returns ([[0, 0, 0, 0]], 1) if no contour
    is found.
  """
  height, width = scoremap.shape
  # Scale to 8-bit and add a trailing channel axis for OpenCV.
  scoremap_image = np.expand_dims((scoremap * 255).astype(np.uint8), 2)
  # Binarize; if the threshold keeps no pixel, relax it and retry.
  while True:
    _, thr_gray_heatmap = cv2.threshold(
        src=scoremap_image,
        thresh=int(threshold * np.max(scoremap_image)),
        maxval=255,
        type=cv2.THRESH_BINARY,
    )
    if thr_gray_heatmap.max() > 0 or threshold <= 0:
      break
    threshold -= 0.1
  # _CONTOUR_INDEX handles the OpenCV 3 vs 4 findContours return layout.
  contours = cv2.findContours(
      image=thr_gray_heatmap, mode=cv2.RETR_TREE, method=cv2.CHAIN_APPROX_SIMPLE
  )[_CONTOUR_INDEX]
  if not contours:
    return np.asarray([[0, 0, 0, 0]]), 1
  if not multi_contour_eval:
    contours = [max(contours, key=cv2.contourArea)]
  estimated_boxes = []
  for contour in contours:
    x, y, w, h = cv2.boundingRect(contour)
    x0, y0, x1, y1 = x, y, x + w, y + h
    # Clamp the exclusive corner into the image bounds.
    x1 = min(x1, width - 1)
    y1 = min(y1, height - 1)
    estimated_boxes.append([x0, y0, x1, y1])
  return np.asarray(estimated_boxes), len(contours)
def mask2chw(arr):
  """Returns ((center_y, center_x), height, width) of the region arr == 1.

  Center is the integer-truncated mean of the foreground coordinates;
  height/width span the foreground's bounding box. Assumes the mask has at
  least one foreground element.
  """
  ys, xs = np.where(arr == 1)
  center = (int(np.mean(ys)), int(np.mean(xs)))
  bbox_h = ys.max() - ys.min() + 1
  bbox_w = xs.max() - xs.min() + 1
  return center, bbox_h, bbox_w
def unpad(image_array, pad=None):
  """Crops (left, top, width, height) padding from an HWC array.

  Returns the array unchanged when `pad` is None.
  """
  if pad is None:
    return image_array
  left, top, width, height = pad
  return image_array[top : top + height, left : left + width, :]
def apply_visual_prompts(
    image_array,
    mask,
    visual_prompt_type=('circle',),
    visualize=False,
    color=(255, 0, 0),
    thickness=1,
    blur_strength=(15, 15),
):
  """Applies visual prompts to the image.

  Args:
    image_array: HWC uint8 image array.
    mask: 2-D foreground mask; assumed binary (0/1) and the same spatial
      size as `image_array` — TODO confirm with callers.
    visual_prompt_type: Iterable of prompt names; any combination of
      'blur', 'gray', 'black', 'circle', 'rectangle', 'contour', applied
      in that fixed order.
    visualize: If True, also writes the result to 'masked_img.png'.
    color: Color tuple for drawn shapes and contours.
    thickness: Line thickness for drawn shapes and contours.
    blur_strength: Gaussian kernel (width, height) for the 'blur' prompt.

  Returns:
    A PIL image with the requested prompts applied.
  """
  prompted_image = image_array.copy()
  # 8-bit mask for cv2.bitwise_and; hoisted — the original recomputed this
  # identically in the blur/gray/black branches.
  mask_u8 = np.clip(mask, 0, 255).astype(np.uint8)
  if 'blur' in visual_prompt_type:
    # Blur the whole image, keep the masked region sharp, then recombine.
    blurred = cv2.GaussianBlur(prompted_image.copy(), blur_strength, 0)
    sharp_region = cv2.bitwise_and(
        prompted_image.copy(), prompted_image.copy(), mask=mask_u8
    )
    inv_mask = 1 - mask
    blurred_region = (blurred * inv_mask[:, :, None]).astype(np.uint8)
    prompted_image = cv2.add(sharp_region, blurred_region)
  if 'gray' in visual_prompt_type:
    # Desaturate the background while keeping the masked region in color.
    gray = cv2.cvtColor(prompted_image.copy(), cv2.COLOR_BGR2GRAY)
    gray = np.stack([gray, gray, gray], axis=-1)  # back to 3 channels
    color_region = cv2.bitwise_and(
        prompted_image.copy(), prompted_image.copy(), mask=mask_u8
    )
    inv_mask = 1 - mask
    gray_region = (gray * inv_mask[:, :, None]).astype(np.uint8)
    prompted_image = cv2.add(color_region, gray_region)
  if 'black' in visual_prompt_type:
    # Zero out everything outside the mask.
    prompted_image = cv2.bitwise_and(
        prompted_image.copy(), prompted_image.copy(), mask=mask_u8
    )
  if 'circle' in visual_prompt_type:
    # Ellipse centered on the mask, spanning its bounding box.
    mask_center, mask_height, mask_width = mask2chw(mask)
    center_coordinates = (mask_center[1], mask_center[0])  # cv2 wants (x, y)
    axes_length = (mask_width // 2, mask_height // 2)
    prompted_image = cv2.ellipse(
        prompted_image,
        center_coordinates,
        axes_length,
        0,
        0,
        360,
        color,
        thickness,
    )
  if 'rectangle' in visual_prompt_type:
    # Axis-aligned box around the mask's bounding region.
    mask_center, mask_height, mask_width = mask2chw(mask)
    start_point = (
        mask_center[1] - mask_width // 2,
        mask_center[0] - mask_height // 2,
    )
    end_point = (
        mask_center[1] + mask_width // 2,
        mask_center[0] + mask_height // 2,
    )
    prompted_image = cv2.rectangle(
        prompted_image, start_point, end_point, color, thickness
    )
  if 'contour' in visual_prompt_type:
    # Fill holes first so the contour traces only the outer boundary.
    mask = binary_fill_holes(mask)
    contours, _ = cv2.findContours(
        mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
    )
    prompted_image = cv2.drawContours(
        prompted_image.copy(), contours, -1, color, thickness
    )
  if visualize:
    # Fixed: the original wrapped the filename in a no-op single-argument
    # os.path.join(); the written path is unchanged.
    cv2.imwrite('masked_img.png', prompted_image)
  return Image.fromarray(prompted_image.astype(np.uint8))