# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Calculate CAM with CLIP model."""
import warnings

import clip
import cv2
import numpy as np
import torch

# pylint: disable=g-importing-member
# pylint: disable=g-bad-import-order
from modeling.model.cam import CAM
from modeling.model.cam import scale_cam_image
from modeling.model.utils import img_ms_and_flip
from modeling.model.utils import reshape_transform
from modeling.model.utils import scoremap2bbox

warnings.filterwarnings("ignore")


class ClipOutputTarget:
  """Selects the score of a single text/category from the model output."""

  def __init__(self, category):
    self.category = category

  def __call__(self, model_output):
    if len(model_output.shape) == 1:
      return model_output[self.category]
    return model_output[:, self.category]
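

# Illustrative note (not from the original file): given a score tensor of
# shape (num_texts,) or (batch, num_texts), ClipOutputTarget(k) returns the
# scores for text index k, which is what the CAM backward pass differentiates.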


def zeroshot_classifier(classnames, templates, model, device):
  """Builds L2-normalized CLIP text embeddings for the given class names.

  Args:
    classnames: Iterable of class name strings.
    templates: Iterable of prompt templates with a `{}` placeholder, or None
      to use the raw class names.
    model: CLIP model providing `encode_text`.
    device: Torch device on which to run the computation.

  Returns:
    Tensor of shape (num_classes, embed_dim), one embedding per class.
  """
  with torch.no_grad():
    zeroshot_weights = []
    for classname in classnames:
      if templates is None:
        texts = [classname]
      else:
        # Format each template with the class name.
        texts = [template.format(classname) for template in templates]
      texts = clip.tokenize(texts).to(device)  # tokenize
      class_embeddings = model.encode_text(texts)  # embed with text encoder
      class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
      class_embedding = class_embeddings.mean(dim=0)
      class_embedding /= class_embedding.norm()
      zeroshot_weights.append(class_embedding)
    zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
  return zeroshot_weights.t()
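

# Illustrative example (class names and template are placeholders, not from
# the original file): with a CLIP model loaded on `device`,
#   zeroshot_classifier(["cat", "dog"], ("a photo of a {}.",), model, device)
# would return a (2, embed_dim) tensor of L2-normalized class embeddings.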


class CLIPCAM:
  """Generate CAM with CLIP model."""

  def __init__(
      self,
      clip_model,
      device,
      text_template=None,
      threshold=0.4,
      bg_cls=None,
  ):
    self.device = device
    self.clip_model = clip_model.to(device)
    self.text_template = text_template
    self.threshold = threshold
    self.stride = self.clip_model.visual.patch_size
    self.bg_cls = bg_cls
    self.bg_text_features = None
    if self.bg_cls is not None:
      self.bg_text_features = zeroshot_classifier(
          self.bg_cls,
          ("a clean origami {}.",),
          self.clip_model,
          self.device,
      ).to(self.device)
    # Hook the layer norm of the last visual transformer block for CAM.
    self.target_layers = [
        self.clip_model.visual.transformer.resblocks[-1].ln_1
    ]
    self.cam = CAM(
        model=self.clip_model,
        target_layers=self.target_layers,
        reshape_transform=reshape_transform,
        use_cuda="cuda" in device,
        stride=self.stride,
    )

  def set_bg_cls(self, bg_cls):
    """Updates the background classes and their text embeddings."""
    if not bg_cls:
      self.bg_cls = None
      self.bg_text_features = None
    else:
      self.bg_cls = bg_cls
      self.bg_text_features = zeroshot_classifier(
          self.bg_cls,
          ("a clean origami {}.",),
          self.clip_model,
          self.device,
      ).to(self.device)

  def __call__(self, ori_img, text, scale=1.0):
    """Gets CAM masks and text features for the given image and text.

    Args:
      ori_img (Image): image to be searched.
      text (str): text (or list of texts) to be searched.
      scale (float): image scale.

    Returns:
      CAM masks and foreground text features.
    """
    ori_width = ori_img.size[0]
    ori_height = ori_img.size[1]
    if isinstance(text, str):
      text = [text]
    # convert image to bgr channel
    ms_imgs = img_ms_and_flip(ori_img, ori_height, ori_width, scales=[scale])
    image = ms_imgs[0]
    image = image.unsqueeze(0)
    h, w = image.shape[-2], image.shape[-1]
    image = image.to(self.device)
    image_features, attn_weight_list = self.clip_model.encode_image(
        image, h, w
    )
    highres_cam_to_save = []
    refined_cam_to_save = []
    bg_features_temp = None
    if self.bg_text_features is not None:
      bg_features_temp = self.bg_text_features.to(self.device)
    fg_features_temp = zeroshot_classifier(
        text, self.text_template, self.clip_model, self.device
    ).to(self.device)
    if bg_features_temp is None:
      text_features_temp = fg_features_temp
    else:
      text_features_temp = torch.cat(
          [fg_features_temp, bg_features_temp], dim=0
      )
    input_tensor = [
        image_features,
        text_features_temp.to(self.device),
        h,
        w,
    ]
# for idx, label in enumerate(label_list): | |
# keys.append(new_class_names.index(label)) | |
for idx, _ in enumerate(text): | |
targets = [ClipOutputTarget(idx)] | |
# torch.cuda.empty_cache() | |
grayscale_cam, _, attn_weight_last = self.cam( | |
input_tensor=input_tensor, targets=targets, target_size=None | |
) # (ori_width, ori_height)) | |
grayscale_cam = grayscale_cam[0, :] | |
if grayscale_cam.max() == 0: | |
input_tensor_fg = ( | |
image_features, | |
fg_features_temp.to(self.device), | |
h, | |
w, | |
) | |
grayscale_cam, _, attn_weight_last = self.cam( | |
input_tensor=input_tensor_fg, | |
targets=targets, | |
target_size=None, | |
) | |
grayscale_cam = grayscale_cam[0, :] | |
      grayscale_cam_highres = cv2.resize(
          grayscale_cam, (ori_width, ori_height)
      )
      highres_cam_to_save.append(torch.tensor(grayscale_cam_highres))

      if idx == 0:
        # Attention weights are image-level, so they are computed only once.
        attn_weight_list.append(attn_weight_last)
        attn_weight = [
            aw[:, 1:, 1:] for aw in attn_weight_list
        ]  # (b, hxw, hxw)
        attn_weight = torch.stack(attn_weight, dim=0)[-8:]
        attn_weight = torch.mean(attn_weight, dim=0)
        attn_weight = attn_weight[0].cpu().detach()
        attn_weight = attn_weight.float()

      # Build an affinity mask from boxes around high-scoring CAM regions.
      box, cnt = scoremap2bbox(
          scoremap=grayscale_cam,
          threshold=self.threshold,
          multi_contour_eval=True,
      )
      aff_mask = torch.zeros((grayscale_cam.shape[0], grayscale_cam.shape[1]))
      for i_ in range(cnt):
        x0_, y0_, x1_, y1_ = box[i_]
        aff_mask[y0_:y1_, x0_:x1_] = 1
      aff_mask = aff_mask.view(
          1, grayscale_cam.shape[0] * grayscale_cam.shape[1]
      )

      # Normalize the attention into a transition matrix and refine the CAM.
      aff_mat = attn_weight
      trans_mat = aff_mat / torch.sum(aff_mat, dim=0, keepdim=True)
      trans_mat = trans_mat / torch.sum(trans_mat, dim=1, keepdim=True)
      for _ in range(2):
        trans_mat = trans_mat / torch.sum(trans_mat, dim=0, keepdim=True)
        trans_mat = trans_mat / torch.sum(trans_mat, dim=1, keepdim=True)
      trans_mat = (trans_mat + trans_mat.transpose(1, 0)) / 2
      # This is copied from CLIP-ES.
      for _ in range(1):
        trans_mat = torch.matmul(trans_mat, trans_mat)
      trans_mat = trans_mat * aff_mask

      cam_to_refine = torch.FloatTensor(grayscale_cam)
      cam_to_refine = cam_to_refine.view(-1, 1)
      # (n, n) x (n, 1) -> (n, 1)
      cam_refined = torch.matmul(trans_mat, cam_to_refine).reshape(
          h // self.stride, w // self.stride
      )
      cam_refined = cam_refined.cpu().numpy().astype(np.float32)
      cam_refined_highres = scale_cam_image(
          [cam_refined], (ori_width, ori_height)
      )[0]
      refined_cam_to_save.append(torch.tensor(cam_refined_highres))

    cam_masks = torch.stack(refined_cam_to_save, dim=0)
    return cam_masks.to(self.device), fg_features_temp.to(self.device)
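

# Usage sketch (illustrative, not part of the original module). It assumes the
# project's patched `clip` package, i.e. a model whose visual encoder exposes
# `patch_size` and whose `encode_image` accepts (image, h, w) as used above;
# the checkpoint name and image path are placeholders.
if __name__ == "__main__":
  from PIL import Image

  device_name = "cuda" if torch.cuda.is_available() else "cpu"
  clip_model, _ = clip.load("ViT-B/16", device=device_name)
  cam_generator = CLIPCAM(
      clip_model,
      device_name,
      text_template=("a photo of a {}.",),
      threshold=0.4,
  )
  query_image = Image.open("example.jpg").convert("RGB")
  masks, text_features = cam_generator(query_image, "a dog")
  print(masks.shape, text_features.shape)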