# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Post processing.""" | |
import torch | |
import torch.nn.functional as F | |
# pylint: disable=g-bad-import-order | |
# pylint: disable=g-importing-member | |
from modeling.post_process.object_discovery import get_instances | |
from utils.metrics import IoM | |
# This should be a abstract function to generate masks for the input image. | |
# However, we first hack it due to the time limit. | |
def generate_masks_from_sam(
    image_path, save_path, pipeline, img_sam=None, visualize=True
):
  """Run SAM automatic mask generation and return the masks as a tensor.

  Args:
    image_path: Path of the image to segment.
    save_path: Where the pipeline writes its visualization output.
    pipeline: Segmentation pipeline exposing `segment_automask`.
    img_sam: Optional pre-loaded image passed through to the pipeline.
    visualize: Whether the pipeline should save a visualization.

  Returns:
    A tuple of (float mask tensor, list of per-mask metadata dicts).
  """
  masks, _, mask_list = pipeline.segment_automask(
      image_path=image_path,
      visualize=visualize,
      save_path=save_path,
      image=img_sam,
  )
  # Convert the numpy mask stack to a float tensor for downstream torch ops.
  mask_tensor = torch.from_numpy(masks).float()
  return mask_tensor, mask_list
def match_masks(
    mask_tensor, attn_map, mask_list, iom_thres=0.0, min_pred_threshold=0.2
):
  """Match masks with the attention map according to the IoU.

  Args:
    mask_tensor: A torch.Tensor for the masks with shape [num_masks, height,
      width].
    attn_map: A torch.Tensor for the attention map with shape [1, 1, height,
      width].
    mask_list: A list of masks with shape [num_masks, height, width]
    iom_thres: A float for the threshold to apply to the attention map.
    min_pred_threshold: The prediction score threshold.

  Returns:
    A list of matched_masks with shape [num_masks, height, width],
    len(matched_masks) = number of captions
  """
  prediction = attn_map.squeeze(1).detach()
  iom = IoM(prediction, mask_tensor, min_pred_threshold=min_pred_threshold)
  keep = iom > iom_thres
  # Retain only the mask entries whose overlap score clears the threshold.
  matched = [entry for entry, keep_it in zip(mask_list, keep) if keep_it]
  if not matched:
    # Nothing cleared the threshold: fall back to the single best overlap.
    matched = [mask_list[torch.argmax(iom)]]
  return matched
def post_process_mask(attn_masks, pad=None, min_area_ratio=0.15):
  """Crop padding off the masks and drop relatively tiny masks.

  Args:
    attn_masks: A torch.Tensor of masks with shape [num_masks, height, width].
    pad: Optional crop box [left, top, width, height]; when given, the masks
      are cropped to that region first.
    min_area_ratio: Keep a mask only if its area fraction (its area divided by
      the total area of all masks) exceeds this value.

  Returns:
    A tensor of the surviving masks. If no mask clears the ratio, the single
    largest mask is kept. If the input contains zero masks, an all-zero
    [1, height, width] tensor is returned.
  """
  if pad is not None:
    left, top, width, height = pad
    attn_masks = attn_masks[Ellipsis, top : top + height, left : left + width]
  else:
    # Bug fix: height/width used to stay None here, which made the
    # empty-input fallback below crash in torch.zeros. Use the masks' own
    # spatial size instead.
    height, width = attn_masks.shape[-2], attn_masks.shape[-1]
  mask_area = attn_masks.sum(dim=(1, 2))
  total_area = mask_area.sum()
  keep_mask = mask_area / total_area > min_area_ratio
  if torch.sum(keep_mask) == 0:
    if keep_mask.shape[0] == 0:
      # No masks at all: return a single empty mask of the right size.
      return torch.zeros(
          (1, height, width), device=attn_masks.device, dtype=attn_masks.dtype
      )
    # All masks were below the ratio: keep the largest one.
    keep_mask[torch.argmax(mask_area)] = True
  attn_masks = attn_masks[keep_mask]
  return attn_masks
def filter_masks(
    attn_masks,
    pad=None,
    mask_threshold=0.3,
    min_area_ratio=0.15,
    return_largest=False,
    device=None,
    return_instances=False,
):
  """Filter attention mask below the threshold.

  NOTE: `attn_masks` is thresholded in place.
  """
  attn_masks[attn_masks < mask_threshold] = 0
  # Connected-component extraction (get_instances) runs on CPU.
  instance_masks = get_instances(attn_masks, return_largest=return_largest)
  processed = []
  for inst in instance_masks:
    cleaned = post_process_mask(inst, pad, min_area_ratio)
    if cleaned is not None:
      processed.append(cleaned.to(device))
  if return_instances:
    return processed
  # Collapse each per-text instance stack into one binary mask.
  return [
      torch.any(mask, dim=0, keepdim=True).to(mask.dtype) for mask in processed
  ]
def post_process(
    input_array,
    attn_masks,
    pad=None,
    mask_threshold=0.3,
    return_largest=False,
    min_area_ratio=0.15,
    return_instances=False,
):
  """post process the input tensor with the attention masks.

  Args:
    input_array: A np.ndarray input array to be post processed with shape
      [width, height, 3, batch_size]
    attn_masks: A torch.Tensor for the attention masks with shape [1,
      num_texts, width, height]
    pad: A list of padding: [pad_left, pad_top, width, height], where
      pad_left, pad_top and width, height are int values.
    mask_threshold: The threshold to binarize the mask.
    return_largest: If true, return the largest connected component.
    min_area_ratio: Keep the mask if its area is larger than this threshold.
    return_instances: Whether to return instances or not.

  Returns:
    attn_masks: A list of tensors with shape [num_instances, height, width]
    x num_texts, where len(attn_masks) = num_texts.
    NOTE: the number_instances for each text (class) may vary.
    The output is a binary tensor.
  """
  if attn_masks.dim() == 3:
    # Add the missing batch dimension expected by F.interpolate.
    attn_masks = attn_masks.unsqueeze(0)
  # NOTE(review): the first two axes of input_array are read as
  # (width, height) per the docstring above — confirm against callers.
  img_width, img_height = input_array.shape[:2]
  # Upsample the masks to the image resolution, then drop the batch axis.
  attn_masks = F.interpolate(
      attn_masks, size=(img_height, img_width), mode='bicubic'
  ).squeeze(0)
  output_masks = filter_masks(
      attn_masks,
      pad=pad,
      mask_threshold=mask_threshold,
      min_area_ratio=min_area_ratio,
      return_largest=return_largest,
      device=attn_masks.device,
      return_instances=return_instances,
  )
  if pad is not None:
    left, top, width, height = pad
    # Crop the image the same way the masks were cropped.
    input_array = input_array[top : top + height, left : left + width]
  return input_array, output_masks