import numpy as np
import torch
from torch.nn import functional as F
import cv2
from detectron2.data import MetadataCatalog
from detectron2.structures import BitMasks
from detectron2.utils.visualizer import ColorMode, Visualizer
import open_clip
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from .modeling.meta_arch.mask_adapter_head import build_mask_adapter
from sam2.sam2_image_predictor import SAM2ImagePredictor
from PIL import Image

# Per-channel statistics subtracted from / divided into the RGB image before
# it is fed to the CLIP visual trunk (see `_normalize_image`).
PIXEL_MEAN = [122.7709383, 116.7460125, 104.09373615]
PIXEL_STD = [68.5005327, 66.6321579, 70.32316305]

# Length (in pixels) the longer image side is resized to before SAM2/CLIP.
_LONG_SIDE = 896
# Number of activation maps grouped (and averaged) per mask.
# NOTE(review): presumably this must match the mask adapter's configuration -
# confirm against the adapter head before changing it.
_MAPS_PER_MASK = 16
# Label written to pixels that no selected mask covers.
_IGNORE_LABEL = 255


class OpenVocabVisualizer(Visualizer):
    """Visualizer whose semantic-segmentation labels come from a caller-supplied list."""

    def __init__(self, img_rgb, metadata=None, scale=1.0, instance_mode=ColorMode.IMAGE, class_names=None):
        """
        Args:
            img_rgb: image forwarded unchanged to `Visualizer.__init__`.
            metadata: dataset metadata forwarded to `Visualizer.__init__`.
            scale (float): forwarded to `Visualizer.__init__`.
            instance_mode (ColorMode): forwarded to `Visualizer.__init__`.
            class_names (list[str] or None): names indexed by label id. When
                None, `metadata.stuff_classes` is used by `draw_sem_seg`.
        """
        super().__init__(img_rgb, metadata, scale, instance_mode)
        self.class_names = class_names

    def draw_sem_seg(self, sem_seg, area_threshold=None, alpha=0.6):
        """
        Draw semantic segmentation predictions/labels.

        Labels are drawn from the largest segment to the smallest. Labels that
        are not a valid index into the class-name list (e.g. an ignore label)
        are skipped.

        Args:
            sem_seg (Tensor or ndarray): the segmentation of shape (H, W).
                Each value is the integer label of the pixel.
            area_threshold (int): segments with less than `area_threshold` are not drawn.
            alpha (float): the larger it is, the more opaque the segmentations are.

        Returns:
            output (VisImage): image object with visualizations.
        """
        if isinstance(sem_seg, torch.Tensor):
            # detach + cpu so GPU tensors and tensors tracking gradients also convert.
            sem_seg = sem_seg.detach().cpu().numpy()

        labels, areas = np.unique(sem_seg, return_counts=True)
        labels = labels[np.argsort(-areas)]
        class_names = self.class_names if self.class_names is not None else self.metadata.stuff_classes
        labels = labels[labels < len(class_names)]

        for label in labels:
            try:
                mask_color = [x / 255 for x in self.metadata.stuff_colors[label]]
            except (AttributeError, IndexError):
                # No colour registered for this label: let draw_binary_mask pick one.
                mask_color = None

            self.draw_binary_mask(
                (sem_seg == label).astype(np.uint8),
                color=mask_color,
                edge_color=(1.0, 1.0, 240.0 / 255),
                text=class_names[label],
                alpha=alpha,
                area_threshold=area_threshold,
            )
        return self.output


class _MaskAdapterClipBase(object):
    """Shared state and CLIP / mask-adapter plumbing for the demo classes.

    Holds the code that was duplicated between `SAMVisualizationDemo` and
    `SAMPointVisualizationDemo`: the ConvNeXt feature helpers and the pipeline
    that turns a set of binary masks into per-mask class probabilities.
    """

    def __init__(self, cfg, granularity, sam2, clip_model, mask_adapter,
                 instance_mode=ColorMode.IMAGE, parallel=False):
        self.metadata = MetadataCatalog.get(
            cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused"
        )
        self.cpu_device = torch.device("cpu")
        self.instance_mode = instance_mode
        self.parallel = parallel
        self.granularity = granularity
        self.sam2 = sam2
        self.clip_model = clip_model
        self.mask_adapter = mask_adapter

    def extract_features_convnext(self, x):
        """Run the CLIP visual trunk and return its intermediate feature maps.

        Returns:
            dict: keys 'stem', 'res2'..'res5' (outputs of the four trunk
            stages) and 'clip_vis_dense' (output of `norm_pre`).
        """
        trunk = self.clip_model.visual.trunk
        out = {}
        x = trunk.stem(x)
        out['stem'] = x.contiguous()  # os4
        for i in range(4):
            x = trunk.stages[i](x)
            out[f'res{i+2}'] = x.contiguous()  # res 2 (os4), 3 (os8), 4 (os16), 5 (os32)
        x = trunk.norm_pre(x)
        out['clip_vis_dense'] = x.contiguous()
        return out

    def visual_prediction_forward_convnext(self, x):
        """Project pooled features of shape (batch, num_query, channel) through the CLIP heads.

        Returns:
            Tensor of shape (batch, num_query, embed_dim).
        """
        batch, num_query, channel = x.shape
        x = x.reshape(batch * num_query, channel, 1, 1)  # fake 2D input
        x = self.clip_model.visual.trunk.head(x)
        x = self.clip_model.visual.head(x)
        return x.view(batch, num_query, x.shape[-1])  # B x num_queries x 640

    def visual_prediction_forward_convnext_2d(self, x):
        """Apply the CLIP head norm/drop/projection densely to a 4-D feature map.

        The map is permuted to channels-last for the head and back to
        channels-first afterwards.
        """
        trunk_head = self.clip_model.visual.trunk.head
        clip_vis_dense = trunk_head.norm(x)
        clip_vis_dense = trunk_head.drop(clip_vis_dense.permute(0, 2, 3, 1))
        return self.clip_model.visual.head(clip_vis_dense).permute(0, 3, 1, 2)

    @staticmethod
    def _normalize_image(image_rgb):
        """Convert an (H, W, 3) RGB array to a normalized (1, 3, H, W) float tensor."""
        tensor = torch.as_tensor(image_rgb.astype("float32").transpose(2, 0, 1))
        pixel_mean = torch.tensor(PIXEL_MEAN).view(-1, 1, 1)
        pixel_std = torch.tensor(PIXEL_STD).view(-1, 1, 1)
        return ((tensor - pixel_mean) / pixel_std).unsqueeze(0)

    def _classify_masks(self, image_rgb, pred_masks, text_features):
        """Score every mask against every text embedding.

        Args:
            image_rgb (ndarray): (H, W, 3) RGB image the masks were computed on.
            pred_masks (BitMasks): masks to classify.
            text_features (Tensor): one embedding per class, on the GPU.
                NOTE(review): presumably expected to be L2-normalized, as
                `SAMVisualizationDemo.run_on_image` does - confirm for the
                precomputed embeddings.

        Returns:
            Tensor: softmax class probabilities with the leading batch
            dimension squeezed away, i.e. one row per mask.
        """
        with torch.no_grad():
            self.clip_model.cuda()
            image = self._normalize_image(image_rgb)
            clip_feature = self.extract_features_convnext(image.cuda().float())['clip_vis_dense']
            clip_vis_dense = self.visual_prediction_forward_convnext_2d(clip_feature)

            semantic_activation_maps = self.mask_adapter(
                clip_vis_dense, pred_masks.tensor.unsqueeze(0).float().cuda()
            )
            maps_for_pooling = F.interpolate(
                semantic_activation_maps,
                size=clip_feature.shape[-2:],
                mode='bilinear',
                align_corners=False,
            )

            B, C = clip_feature.size(0), clip_feature.size(1)
            N = maps_for_pooling.size(1)
            num_instances = N // _MAPS_PER_MASK

            # Turn each activation map into spatial weights that sum to one, then
            # use them to pool the dense CLIP features into one vector per map.
            maps_for_pooling = F.softmax(F.logsigmoid(maps_for_pooling).view(B, N, -1), dim=-1)
            pooled = torch.bmm(maps_for_pooling, clip_feature.view(B, C, -1).permute(0, 2, 1))
            pooled = self.visual_prediction_forward_convnext(pooled)
            # Average the maps that belong to the same mask.
            pooled = pooled.reshape(B, num_instances, _MAPS_PER_MASK, -1).mean(dim=-2).contiguous()

            class_preds = (100.0 * pooled @ text_features.T).softmax(dim=-1)
            return class_preds.squeeze(0)


class SAMVisualizationDemo(_MaskAdapterClipBase):
    """Open-vocabulary semantic segmentation demo: SAM2 automatic masks + CLIP labels."""

    def __init__(self, cfg, granularity, sam2, clip_model ,mask_adapter, instance_mode=ColorMode.IMAGE, parallel=False,):
        """
        Args:
            cfg: config; `cfg.DATASETS.TEST[0]` (if any) selects the metadata.
            granularity (float): values below 1 keep, per class, every mask
                scoring above `granularity * best_score`; otherwise only the
                best-scoring mask per class is kept.
            sam2: SAM2 model handed to `SAM2AutomaticMaskGenerator`.
            clip_model: open_clip model with a ConvNeXt visual trunk.
            mask_adapter: callable taking (dense CLIP features, masks).
            instance_mode (ColorMode): forwarded to the visualizer.
            parallel (bool): stored on the instance; not used in this module.
        """
        super().__init__(cfg, granularity, sam2, clip_model, mask_adapter,
                         instance_mode=instance_mode, parallel=parallel)
        self.predictor = SAM2AutomaticMaskGenerator(
            sam2,
            points_per_batch=16,
            pred_iou_thresh=0.8,
            stability_score_thresh=0.7,
            crop_n_layers=0,
            crop_n_points_downscale_factor=2,
            min_mask_region_area=100,
        )

    def _select_masks(self, class_preds, drop_last_class):
        """Return the indices of the masks that contribute to the final label map.

        Args:
            class_preds (Tensor): per-mask class probabilities (masks x classes).
            drop_last_class (bool): ignore the last class (the synthetic
                'others' entry) when picking masks.
        """
        max_scores, best_mask = torch.max(class_preds, dim=0)
        if self.granularity < 1:
            thresholds = max_scores * self.granularity
            if drop_last_class:
                thresholds = thresholds[:-1]
            picked = []
            for cls_idx, thr in enumerate(thresholds):
                picked.extend(torch.where(class_preds[:, cls_idx] > thr)[0].tolist())
            return picked
        return best_mask[:-1] if drop_last_class else best_mask

    def run_on_image(self, ori_image, class_names):
        """Segment `ori_image` into the given classes and draw the result.

        Args:
            ori_image (ndarray): (H, W, 3) image in BGR channel order.
            class_names (list[str]): class vocabulary. It is not modified; when
                it holds a single name, an 'others' class is added internally
                so the softmax has something to compete against.

        Returns:
            tuple: `(None, vis_output)` where `vis_output` is the VisImage
            produced by the visualizer. If SAM2 yields no masks the image is
            returned without any overlay.

        Raises:
            ValueError: if `class_names` is empty.
        """
        # Work on a copy so the caller's list is never mutated.
        class_names = list(class_names)
        if not class_names:
            raise ValueError("class_names must contain at least one class name")
        if len(class_names) == 1:
            class_names.append('others')
        has_others = len(class_names) == 2 and class_names[-1] == 'others'

        height, width, _ = ori_image.shape
        # max(1, ...) guards against a zero-sized side for extreme aspect ratios.
        if width > height:
            new_width = _LONG_SIDE
            new_height = max(1, int((new_width / width) * height))
        else:
            new_height = _LONG_SIDE
            new_width = max(1, int((new_height / height) * width))
        image = cv2.cvtColor(cv2.resize(ori_image, (new_width, new_height)), cv2.COLOR_BGR2RGB)
        ori_image = cv2.cvtColor(ori_image, cv2.COLOR_BGR2RGB)

        visualizer = OpenVocabVisualizer(
            ori_image, self.metadata, instance_mode=self.instance_mode, class_names=class_names
        )

        with torch.no_grad():
            masks = self.predictor.generate(image)
        if not masks:
            # Nothing to classify: return the untouched visualization.
            return None, visualizer.output
        pred_masks = BitMasks(np.stack([m['segmentation'] for m in masks]))

        txts = [f'a photo of {cls_name}' for cls_name in class_names]
        text = open_clip.tokenize(txts)

        with torch.no_grad():
            self.clip_model.cuda()
            text_features = self.clip_model.encode_text(text.cuda())
            text_features /= text_features.norm(dim=-1, keepdim=True)

            class_preds = self._classify_masks(image, pred_masks, text_features)

            selected = torch.as_tensor(
                self._select_masks(class_preds, has_others),
                dtype=torch.long,
                device=class_preds.device,
            )
            # Keep the class scores of the selected masks, zero for all others.
            select_cls = torch.zeros_like(class_preds)
            select_cls[selected] = class_preds[selected]
            semseg = torch.einsum(
                "qc,qhw->chw", select_cls.float(), pred_masks.tensor.float().cuda()
            )

        # A pixel is blank when no selected mask contributes to ANY class there.
        # Scores are non-negative, so checking the per-pixel maximum is enough.
        blank_area = (semseg.amax(dim=0) == 0).cpu()
        pred_mask = semseg.argmax(dim=0).cpu()
        # Use a label that can never be a valid class index, so blank pixels
        # are always skipped by the visualizer.
        pred_mask[blank_area] = max(_IGNORE_LABEL, len(class_names))
        pred_mask = pred_mask.numpy().astype(np.int32)
        pred_mask = cv2.resize(pred_mask, (width, height), interpolation=cv2.INTER_NEAREST)

        vis_output = visualizer.draw_sem_seg(pred_mask)
        return None, vis_output


class SAMPointVisualizationDemo(_MaskAdapterClipBase):
    """Point-prompted demo: SAM2 segments the clicked object, CLIP names it."""

    def __init__(self, cfg, granularity, sam2, clip_model ,mask_adapter, instance_mode=ColorMode.IMAGE, parallel=False):
        """
        Args:
            cfg: config; `cfg.DATASETS.TEST[0]` (if any) selects the metadata.
            granularity (float): stored on the instance; not used by
                `run_on_image_with_points`.
            sam2: SAM2 model handed to `SAM2ImagePredictor`.
            clip_model: open_clip model with a ConvNeXt visual trunk.
            mask_adapter: callable taking (dense CLIP features, masks).
            instance_mode (ColorMode): stored on the instance.
            parallel (bool): stored on the instance; not used in this module.

        Note:
            Reads the class-name file and the precomputed text embeddings from
            paths relative to the current working directory, and moves the
            embeddings to CUDA.
        """
        super().__init__(cfg, granularity, sam2, clip_model, mask_adapter,
                         instance_mode=instance_mode, parallel=parallel)
        self.predictor = SAM2ImagePredictor(sam2)
        self.class_names = self._load_class_names()
        # NOTE(review): presumably one row per entry of `self.class_names`, in
        # the same order, encoded from 'a photo of {name}' prompts - confirm.
        self.text_embedding = torch.from_numpy(
            np.load("./text_embedding/lvis_coco_text_embedding.npy")
        ).to("cuda")

    def _load_class_names(self):
        """Build the label vocabulary: COCO thing names + all COCO names + LVIS names.

        The second group is built from every COCO category, so thing names
        appear twice in the returned list.
        NOTE(review): presumably required to stay aligned with the rows of the
        precomputed text embedding - verify before de-duplicating.
        """
        from .data.datasets import openseg_classes

        coco_categories = openseg_classes.get_coco_categories_with_prompt_eng()
        thing_classes = [k["name"] for k in coco_categories if k["isthing"] == 1]
        stuff_classes = [k["name"] for k in coco_categories]

        with open("./mask_adapter/data/datasets/lvis_1203_with_prompt_eng.txt", 'r') as f:
            lvis_lines = f.read().splitlines()
        # Keep the text after the first ':' on each line (the whole line if there is none).
        lvis_classes = [x[x.find(':')+1:] for x in lvis_lines]

        return thing_classes + stuff_classes + lvis_classes

    def run_on_image_with_points(self, ori_image, points):
        """Segment the object under the clicked point(s) and label it.

        Args:
            ori_image (ndarray): (H, W, 3) image in BGR channel order.
            points (sequence): one or more `(x, y)` pixel coordinates. All of
                them are passed to SAM2 as positive prompts; the text label is
                placed next to the first one.

        Returns:
            tuple: `(None, image)` where `image` is a PIL image with the mask
            tinted, its outline drawn in white, and "<class>: <score>" written
            near the first point.

        Raises:
            ValueError: if `points` is not a non-empty sequence of (x, y) pairs.
        """
        input_point = np.array(points)
        if input_point.ndim != 2 or input_point.shape[0] == 0 or input_point.shape[1] != 2:
            raise ValueError("points must be a non-empty sequence of (x, y) pairs")
        # One positive (foreground) label per prompt point.
        input_label = np.ones(input_point.shape[0], dtype=int)

        height, width, _ = ori_image.shape
        image = cv2.cvtColor(ori_image, cv2.COLOR_BGR2RGB)

        with torch.no_grad():
            self.predictor.set_image(image)
            masks, _, _ = self.predictor.predict(
                point_coords=input_point,
                point_labels=input_label,
                multimask_output=False,
            )
        pred_masks = BitMasks(masks)

        class_preds = self._classify_masks(image, pred_masks, self.text_embedding)

        # Resize mask to match original image size
        pred_mask = cv2.resize(masks.squeeze(0), (width, height), interpolation=cv2.INTER_NEAREST)

        # Blend a solid-colour mask layer with the image. Pixels outside the
        # mask are blended with black, so the background is dimmed.
        overlay = image.copy()
        mask_colored = np.zeros_like(image)
        mask_colored[pred_mask == 1] = [234, 103, 112]  # mask fill colour (RGB)
        alpha = 0.5
        cv2.addWeighted(mask_colored, alpha, overlay, 1 - alpha, 0, overlay)

        # Draw boundary (contours) on the overlay
        contours, _ = cv2.findContours(
            pred_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )
        cv2.drawContours(overlay, contours, -1, (255, 255, 255), 2)  # White boundary

        # Label with the highest-scoring class
        max_scores, max_score_idx = class_preds.max(dim=1)
        label = f"{self.class_names[max_score_idx.item()]}: {max_scores.item():.2f}"

        # Place the label near the first point, kept inside the image.
        # int(): cv2.putText needs integer coordinates; max(0, ...): small images.
        text_x = max(0, min(width - 200, int(points[0][0]) + 20))
        text_y = max(0, min(height - 30, int(points[0][1]) + 20))
        cv2.putText(overlay, label, (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)

        return None, Image.fromarray(overlay)