# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of CaR.""" | |
import os | |
import clip | |
import numpy as np | |
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
# pylint: disable=g-importing-member | |
# pylint: disable=g-bad-import-order | |
from modeling.model.clip_wrapper import CLIPWrapper | |
from modeling.model.clip_wrapper import forward_clip | |
from modeling.model.clipcam import CLIPCAM | |
from modeling.model.crf import PostProcess | |
from modeling.model.utils import apply_visual_prompts | |
from utils.visualize import viz_attn | |


class CaR(nn.Module):
"""CaR module.""" | |

  def __init__(
      self,
      cfg,
      device="cpu",
      visualize=False,
      confidence_threshold=0.45,
      save_path="save_path",
      seg_mode="refer",
      semantic_clip_model_name=None,
      semantic_pretrained_data=None,
      semantic_templates=None,
      text_template=None,
      visual_prompt_type="circle",
      clipes_threshold=0.4,
      cam_text_template="a clean origami {}.",
      bg_cls=None,
      iom_thres=0.6,
      min_pred_threshold=0.01,
      bg_factor=1.0,
      mask_threshold=0.5,
  ):
"""CaR model for image segmentation. | |
Args: | |
cfg: the config file. | |
device: the device to run the model. | |
visualize: whether to visualize the intermediate results | |
confidence_threshold: the confidence threshold for semantic | |
segmentation. If the confidence score is lower than the threshold, the | |
mask will be discarded. | |
save_path: the path to save the intermediate results | |
seg_mode: the segmentation mode, can be 'refer' or 'semantic' | |
semantic_clip_model_name: the name of the semantic segmentation model. | |
semantic_pretrained_data: the path to the pretrained semantic | |
segmentation model. | |
semantic_templates: the templates for semantic segmentation. | |
text_template: the template for visual prompting. | |
visual_prompt_type: the type of visual prompting. | |
clipes_threshold: the threshold for CLIPES. | |
cam_text_template: the template for CAM. | |
bg_cls: background classes. | |
iom_thres: IoM threshold. | |
min_pred_threshold: Prediction threshold. | |
bg_factor: Background factor. | |
mask_threshold: Mask threshold. | |
""" | |
    super(CaR, self).__init__()
    # CLIP parameters
    self.confidence_threshold = confidence_threshold
    self.device = device
    self.visualize = visualize
    self.save_path = save_path
    self.seg_mode = seg_mode
    self.semantic_clip_model_name = semantic_clip_model_name
    self.semantic_pretrained_data = semantic_pretrained_data
    self.semantic_templates = semantic_templates
    self.text_template = text_template
    self.visual_prompt_type = visual_prompt_type
    self.clipes_threshold = clipes_threshold
    self.cam_text_template = cam_text_template
    self.iom_thres = iom_thres
    self.min_pred_threshold = min_pred_threshold
    self.bg_cls = bg_cls
    self.bg_factor = bg_factor
    self.mask_threshold = mask_threshold
if not hasattr(cfg, "clip"): | |
raise ValueError("The config file should contain the CLIP parameters.") | |
if not hasattr(cfg, "car"): | |
raise ValueError("The config file should contain the car parameters.") | |
if hasattr(cfg, "cam"): | |
raise ValueError("cfg.cam is deprecated, please use cfg.car ") | |
for k, v in vars(cfg.clip).items(): | |
setattr(self, k, v) | |
for k, v in vars(cfg.car).items(): | |
setattr(self, k, v) | |
if hasattr(cfg, "sam"): | |
for k, v in vars(cfg.sam).items(): | |
setattr(self, k, v) | |
if not self.bg_cls: | |
self.bg_cls = None | |
print(f"The model is running on {self.device}") | |
    self.clip_model, self.preprocess = clip.load(
        self.clip_model_name, device=self.device
    )
    self.clip_model = CLIPWrapper(self.clip_model)
    self.post_process = PostProcess(device=self.device)
    self.mask_generator = CLIPCAM(
        self.clip_model,
        device=self.device,
        text_template=self.text_template,
        threshold=self.clipes_threshold,
        bg_cls=self.bg_cls,
    )
    self.semantic_clip_model, self.semantic_preprocess = clip.load(
        self.semantic_clip_model_name, device=self.device
    )
    self.semantic_clip_model = CLIPWrapper(self.semantic_clip_model)

  def get_confidence(self, cam_map, binary_cam_map):
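    """Computes the mean CAM activation inside the binarized mask.

    Args:
      cam_map: the raw CAM maps.
      binary_cam_map: the binarized CAM maps.

    Returns:
      The per-mask confidence scores.
    """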
    confidence_map = torch.sum(cam_map * binary_cam_map[None], dim=[2, 3])
    confidence_map = confidence_map / torch.sum(binary_cam_map, dim=[1, 2])
    confidence_score = confidence_map.squeeze()
    return confidence_score

  def set_visual_prompt_type(self, visual_prompt_type):
    self.visual_prompt_type = visual_prompt_type

  def set_bg_factor(self, bg_factor):
    self.bg_factor = bg_factor

  def set_confidence_threshold(self, confidence_threshold):
    self.confidence_threshold = confidence_threshold

  def set_mask_threshold(self, mask_threshold):
    self.mask_threshold = mask_threshold

  def apply_visual_prompts(self, image, mask):
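    """Draws the visual prompt defined by `mask` onto `image`.

    The mask is resized to the image resolution, binarized with
    `self.mask_threshold`, and rendered as a `self.visual_prompt_type`
    prompt (e.g. a circle). An (almost) empty mask returns the image
    unchanged.
    """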
    if torch.sum(mask).item() <= 1:
      return image
    image_array = np.array(image)
    img_h = image_array.shape[0]
    img_w = image_array.shape[1]
    mask = (
        F.interpolate(mask[None][None], size=(img_h, img_w), mode="nearest")
        .squeeze()
        .detach()
        .cpu()
        .numpy()
    )
    mask = (mask > self.mask_threshold).astype(np.uint8)
    prompted_image = apply_visual_prompts(
        image_array, mask, self.visual_prompt_type, self.visualize
    )
    return prompted_image

  def get_mask_confidence(self, prompted_images, prompt_text):
    """Gets the confidence for each mask via visual prompting."""
    # Preprocess the prompted images and score them against the text prompts.
    prompted_tensor = torch.stack(
        [self.semantic_preprocess(img) for img in prompted_images], dim=0
    )
    prompted_tensor = prompted_tensor.to(self.device)
    h, w = prompted_tensor.shape[-2:]
    text_prediction = forward_clip(
        self.semantic_clip_model, prompted_tensor, prompt_text, h, w
    )
    return text_prediction

  def _filter_texts(self, ori_mask_id, sem_scores, prompt_text):
    """Removes false-positive masks by score filtering.

    The backbone is then recalled to get the CAM maps for the texts that
    survive the filtering.
    """
    if not ori_mask_id:
      # No mask passed the threshold; fall back to the highest-scoring one.
      max_id = np.argmax(sem_scores)
      ori_mask_id.append(max_id)
    filtered_text = [prompt_text[i] for i in ori_mask_id]
    return filtered_text

  def _forward_stage(self, ori_img, cam_text, clip_text, semantic_prompt_text):
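    """Runs a single refinement stage.

    CAM mask proposals are generated for `cam_text`, each proposal is drawn
    onto the image as a visual prompt, and the prompted images are scored
    against `semantic_prompt_text` with the semantic CLIP model. Texts whose
    mask score exceeds `self.confidence_threshold` are kept.

    Returns:
      A tuple of (filtered texts, per-text scores, mask proposals).
    """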
    mask_proposals = self.get_mask_proposals(ori_img, cam_text)
    num_texts = len(cam_text)
    ori_mask_id = []
    sem_scores = torch.zeros((num_texts,), device=self.device).float()
    prompted_imgs = [
        self.apply_visual_prompts(ori_img, cam_map)
        for cam_map in mask_proposals
    ]
    text_scores = self.get_mask_confidence(prompted_imgs, semantic_prompt_text)
    mask_scores = torch.diagonal(text_scores)
    for mask_idx, mask_score in enumerate(mask_scores):
      # Record the indices of masks that pass the confidence threshold.
      if mask_score > self.confidence_threshold:
        ori_mask_id.append(mask_idx)
        sem_scores[mask_idx] = mask_score
    sem_scores = sem_scores.cpu().detach().numpy()
    filtered_texts = self._filter_texts(ori_mask_id, sem_scores, clip_text)
    # if isinstance(ori_img, list):
    #   ori_img = [ori_img[i] for i in ori_mask_id]
    all_scores = torch.zeros((num_texts,), device=self.device).float()
    sem_scores = torch.from_numpy(sem_scores).to(self.device)
    for new_id, ori_id in enumerate(ori_mask_id):
      if new_id >= len(mask_proposals):
        # The mask is filtered out.
        continue
      all_scores[ori_id] = sem_scores[ori_id]
    return filtered_texts, all_scores, mask_proposals

  def _get_save_path(self, text):
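    """Builds the output directory and per-text subdirectories for `text`."""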
    folder_name = "_".join([t.replace(" ", "_") for t in text])
    if len(folder_name) > 20:
      folder_name = folder_name[:20]
    output_path = os.path.join(self.save_path, folder_name)
    sub_output_path = [
        os.path.join(output_path, t.replace(" ", "_")) for t in text
    ]
    return output_path, sub_output_path

  def get_mask_proposals(self, img, text):
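    """Generates CAM mask proposals for the given texts.

    In 'refer' mode one CAM map is generated per text (pairing images with
    texts when `img` is a list); in 'semantic' mode all texts are processed
    in a single call.
    """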
    if self.seg_mode == "refer":
      if isinstance(img, list):
        cam_map_list = [self.mask_generator(i, t)[0] for i, t in zip(img, text)]
      else:
        cam_map_list = [self.mask_generator(img, t)[0] for t in text]
      return torch.cat(cam_map_list, dim=0)
    elif self.seg_mode == "semantic":
      return self.mask_generator(img, text)[0]
    else:
      raise ValueError(
          "Unknown segmentation mode. Only refer and semantic segmentation"
          " are supported."
      )

  def _forward_car(self, ori_img, text):
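    """Runs the recurrent text filtering and mask refinement.

    `_forward_stage` is called repeatedly, each time keeping only the texts
    whose masks score above the confidence threshold, until the number of
    positive masks stops changing. The surviving proposals are then refined
    with dense CRF and scattered back into the order of the input `text`.
    """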
    if isinstance(text, str):
      text = [text]
    _, sub_output_path = self._get_save_path(text)
    image_array = np.array(ori_img)
    clip_text = [self.cam_text_template.format(t) for t in text]
    cam_text = text
    init_clip_text = clip_text  # The CLIP text prompts differ from the raw text.
    semantic_prompt_text = clip_text
    # Apply semantic prompting augmentation.
    if self.semantic_templates is not None:
      semantic_prompt_text = []
      for template in self.semantic_templates:
        templated_text = [template.format(t) for t in text]
        semantic_prompt_text.append(templated_text)
    num_positive_last = 0
    run = 0
    while True:
      run += 1
      cur_texts, all_scores, mask_proposals = self._forward_stage(
          ori_img, cam_text, clip_text, semantic_prompt_text
      )
      if cur_texts:  # if there is no text, skip the refinement
        cam_text = cur_texts
        clip_text = cur_texts
      num_positive = (all_scores > 0).sum().item()
      if num_positive == num_positive_last:
        # Stop the refinement if the number of positive masks
        # does not change.
        break
      num_positive_last = num_positive
    # Apply dense CRF for refinement.
    # SAM is optional and is applied outside the model.
    refined_masks = self.post_process(
        ori_img,
        mask_proposals,
        separate=self.seg_mode == "refer",
        bg_factor=self.bg_factor,
    )
    predicted_class_idx = [init_clip_text.index(t) for t in cur_texts]
    if self.visualize:
      _ = [
          viz_attn(
              image_array,
              attn,
              prefix=sub_output_path[aid],
              img_name="semantic_mask",
          )
          for aid, attn in enumerate(refined_masks)
      ]
    final_predicted_masks = torch.zeros(len(text), *refined_masks[0].shape)
    final_all_scores = torch.zeros(len(text))
    for idx, mask, score in zip(predicted_class_idx, refined_masks, all_scores):
      final_predicted_masks[idx] = mask
      final_all_scores[idx] = score
    return final_predicted_masks, final_all_scores

  def forward(self, im_ori, text):
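    """Segments `im_ori` according to `text`.

    Returns:
      A tuple of (pseudo masks, confidence scores), ordered like `text`.
    """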
    # im_ori is the padded image input with shape (512, 512, 3).
    pseudo_masks, conf_scores = self._forward_car(im_ori, text)
    return pseudo_masks, conf_scores
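

# A minimal usage sketch (an assumption based on the attribute accesses
# above, not a documented API): `cfg` must expose `clip` and `car`
# namespaces, with `cfg.clip` providing at least `clip_model_name`, and the
# input image is a PIL.Image. Adapt to the config files shipped with the
# project.
#
#   from types import SimpleNamespace
#   from PIL import Image
#
#   cfg = SimpleNamespace(
#       clip=SimpleNamespace(clip_model_name="ViT-B/16"),
#       car=SimpleNamespace(),
#   )
#   model = CaR(cfg, device="cuda", seg_mode="semantic",
#               semantic_clip_model_name="ViT-L/14")
#   masks, scores = model(Image.open("example.jpg"), ["cat", "dog"])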