Spaces:
Running
on
Zero
Running
on
Zero
import numpy as np | |
import torch | |
from torch.nn import functional as F | |
import cv2 | |
from detectron2.data import MetadataCatalog | |
from detectron2.structures import BitMasks | |
from detectron2.utils.visualizer import ColorMode, Visualizer | |
import open_clip | |
from sam2.build_sam import build_sam2 | |
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator | |
from .modeling.meta_arch.mask_adapter_head import build_mask_adapter | |
from sam2.sam2_image_predictor import SAM2ImagePredictor | |
from PIL import Image | |
# Per-channel normalization statistics on the 0-255 pixel scale, applied to
# the RGB input image as (image - PIXEL_MEAN) / PIXEL_STD before the CLIP
# backbone. They equal the OpenAI CLIP mean/std constants multiplied by 255.
PIXEL_MEAN = [
    122.7709383,   # R
    116.7460125,   # G
    104.09373615,  # B
]
PIXEL_STD = [
    68.5005327,   # R
    66.6321579,   # G
    70.32316305,  # B
]
class OpenVocabVisualizer(Visualizer):
    """detectron2 ``Visualizer`` that labels semantic segments with a
    caller-supplied class-name list instead of ``metadata.stuff_classes``.

    Args:
        img_rgb: image to draw on, forwarded to ``Visualizer``.
        metadata: dataset metadata; ``stuff_colors`` is used for segment
            colors, and ``stuff_classes`` for labels when ``class_names`` is None.
        scale, instance_mode: forwarded unchanged to ``Visualizer``.
        class_names (list[str] or None): label text indexed by the integer
            values of the segmentation map.
    """

    def __init__(self, img_rgb, metadata=None, scale=1.0, instance_mode=ColorMode.IMAGE, class_names=None):
        super().__init__(img_rgb, metadata, scale, instance_mode)
        self.class_names = class_names

    def draw_sem_seg(self, sem_seg, area_threshold=None, alpha=0.6):
        """
        Draw semantic segmentation predictions/labels.

        Args:
            sem_seg (Tensor or ndarray): the segmentation of shape (H, W).
                Each value is the integer label of the pixel. Values outside
                ``[0, len(class_names))`` (e.g. a 255 "blank" fill or a
                negative ignore value) are not drawn.
            area_threshold (int): segments with less than `area_threshold` are not drawn.
            alpha (float): the larger it is, the more opaque the segmentations are.

        Returns:
            output (VisImage): image object with visualizations.
        """
        if isinstance(sem_seg, torch.Tensor):
            # Plain .numpy() fails on CUDA tensors and on tensors that require grad.
            sem_seg = sem_seg.detach().cpu().numpy()
        class_names = self.class_names if self.class_names is not None else self.metadata.stuff_classes
        num_classes = len(class_names)

        labels, areas = np.unique(sem_seg, return_counts=True)
        # Visit segments from the largest area to the smallest.
        labels = labels[np.argsort(-areas)]
        for label in labels:
            # Only values that are valid indices into class_names get drawn.
            if not 0 <= label < num_classes:
                continue
            try:
                mask_color = [x / 255 for x in self.metadata.stuff_colors[label]]
            except (AttributeError, IndexError):
                # No usable color table in the metadata: pass color=None and
                # let draw_binary_mask pick one.
                mask_color = None
            self.draw_binary_mask(
                (sem_seg == label).astype(np.uint8),
                color=mask_color,
                edge_color=(1.0, 1.0, 240.0 / 255),
                text=class_names[label],
                alpha=alpha,
                area_threshold=area_threshold,
            )
        return self.output
class SAMVisualizationDemo(object):
    """Open-vocabulary semantic segmentation demo.

    SAM2 proposes class-agnostic masks automatically. Each mask is embedded by
    pooling dense CLIP features with activation maps from ``mask_adapter``,
    then classified against text embeddings of the user's class names.

    Args:
        cfg: detectron2 config; ``cfg.DATASETS.TEST[0]`` (when present) selects
            the ``MetadataCatalog`` entry used when drawing.
        granularity (float): mask-selection knob, see ``_select_masks``.
        sam2: SAM2 model, wrapped in a ``SAM2AutomaticMaskGenerator``.
        clip_model: open_clip model whose ``visual.trunk`` exposes ``stem``,
            ``stages``, ``norm_pre`` and ``head`` (ConvNeXt-style, per the
            method names below).
        mask_adapter: callable mapping (dense CLIP features, stacked binary
            masks) to semantic activation maps.
        instance_mode (ColorMode): forwarded to the visualizer.
        parallel (bool): stored; not used by this class.
    """

    # Longer image side (pixels) used for mask generation and CLIP features.
    _INFERENCE_LONG_SIDE = 896
    # Number of consecutive activation maps averaged into one embedding per
    # mask (assumes the adapter emits this many maps per input mask, in mask
    # order -- TODO confirm against the mask_adapter config).
    _MAPS_PER_MASK = 16

    def __init__(self, cfg, granularity, sam2, clip_model, mask_adapter, instance_mode=ColorMode.IMAGE, parallel=False):
        self.metadata = MetadataCatalog.get(
            cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused"
        )
        self.cpu_device = torch.device("cpu")
        self.instance_mode = instance_mode
        self.parallel = parallel
        self.granularity = granularity
        self.sam2 = sam2
        self.predictor = SAM2AutomaticMaskGenerator(
            sam2,
            points_per_batch=16,
            pred_iou_thresh=0.8,
            stability_score_thresh=0.7,
            crop_n_layers=0,
            crop_n_points_downscale_factor=2,
            min_mask_region_area=100,
        )
        self.clip_model = clip_model
        self.mask_adapter = mask_adapter

    def extract_features_convnext(self, x):
        """Run the CLIP visual trunk on ``x`` and collect its feature maps.

        Returns a dict with ``'stem'``, ``'res2'``..``'res5'`` (outputs of
        stages 0-3) and ``'clip_vis_dense'`` (``norm_pre`` of the last stage).
        """
        out = {}
        x = self.clip_model.visual.trunk.stem(x)
        out['stem'] = x.contiguous()  # os4
        for i in range(4):
            x = self.clip_model.visual.trunk.stages[i](x)
            out[f'res{i+2}'] = x.contiguous()  # res 2 (os4), 3 (os8), 4 (os16), 5 (os32)
        x = self.clip_model.visual.trunk.norm_pre(x)
        out['clip_vis_dense'] = x.contiguous()
        return out

    def visual_prediction_forward_convnext(self, x):
        """Project pooled features ``x`` of shape (B, Q, C) through the trunk
        head and the visual head; returns (B, Q, D)."""
        batch, num_query, channel = x.shape
        x = x.reshape(batch * num_query, channel, 1, 1)  # fake 2D input
        x = self.clip_model.visual.trunk.head(x)
        x = self.clip_model.visual.head(x)
        return x.view(batch, num_query, x.shape[-1])  # B x num_queries x 640

    def visual_prediction_forward_convnext_2d(self, x):
        """Apply the head's norm/drop and the visual head at every spatial
        position of ``x`` (channels-first in and out, channels-last inside)."""
        clip_vis_dense = self.clip_model.visual.trunk.head.norm(x)
        clip_vis_dense = self.clip_model.visual.trunk.head.drop(clip_vis_dense.permute(0, 2, 3, 1))
        clip_vis_dense = self.clip_model.visual.head(clip_vis_dense).permute(0, 3, 1, 2)
        return clip_vis_dense

    def _classify_masks(self, image, pred_masks, text_features):
        """Return per-mask class probabilities of shape (num_masks, num_classes).

        ``image`` is the normalized (1, 3, H, W) tensor, ``pred_masks`` the
        ``BitMasks`` for that image, and ``text_features`` the L2-normalized
        text embeddings. Called under ``torch.no_grad()`` by ``run_on_image``.
        """
        features = self.extract_features_convnext(image.cuda().float())
        clip_feature = features['clip_vis_dense']
        clip_vis_dense = self.visual_prediction_forward_convnext_2d(clip_feature)
        semantic_activation_maps = self.mask_adapter(clip_vis_dense, pred_masks.tensor.unsqueeze(0).float().cuda())
        maps_for_pooling = F.interpolate(
            semantic_activation_maps, size=clip_feature.shape[-2:], mode='bilinear', align_corners=False
        )
        B, C = clip_feature.size(0), clip_feature.size(1)
        N = maps_for_pooling.size(1)
        num_instances = N // self._MAPS_PER_MASK
        # Turn each activation map into a weighting over spatial positions.
        maps_for_pooling = F.softmax(F.logsigmoid(maps_for_pooling).view(B, N, -1), dim=-1)
        pooled_clip_feature = torch.bmm(maps_for_pooling, clip_feature.view(B, C, -1).permute(0, 2, 1))
        pooled_clip_feature = self.visual_prediction_forward_convnext(pooled_clip_feature)
        pooled_clip_feature = (
            pooled_clip_feature.reshape(B, num_instances, self._MAPS_PER_MASK, -1).mean(dim=-2).contiguous()
        )
        class_preds = (100.0 * pooled_clip_feature @ text_features.T).softmax(dim=-1)
        return class_preds.squeeze(0)

    def _select_masks(self, class_preds, has_others):
        """Keep the score rows of selected masks and zero out all other rows.

        With ``self.granularity >= 1`` only the highest-scoring mask of each
        class is kept. Otherwise every mask whose score for a class exceeds
        ``self.granularity`` times that class's best score is kept, so smaller
        values keep more masks. When ``has_others`` is True the trailing
        ``'others'`` class does not select masks on its own.

        Args:
            class_preds (Tensor): (num_masks, num_classes) class probabilities.
            has_others (bool): whether the last class is the ``'others'`` filler.

        Returns:
            Tensor: same shape as ``class_preds``; unselected rows are zero.
        """
        max_scores, best_mask_per_class = torch.max(class_preds, dim=0)
        if self.granularity < 1:
            thr_scores = max_scores * self.granularity
            if has_others:
                thr_scores = thr_scores[:-1]
            selected = []
            for i, thr in enumerate(thr_scores):
                selected.extend(torch.where(class_preds[:, i] > thr)[0].tolist())
        else:
            selected = best_mask_per_class[:-1] if has_others else best_mask_per_class
        idx = torch.as_tensor(selected, dtype=torch.long, device=class_preds.device)
        select_cls = torch.zeros_like(class_preds)
        select_cls[idx] = class_preds[idx]
        return select_cls

    def run_on_image(self, ori_image, class_names):
        """Segment ``ori_image`` and label the selected masks with ``class_names``.

        Args:
            ori_image (ndarray): BGR image of shape (H, W, 3).
            class_names (Sequence[str]): vocabulary to classify masks against.
                The caller's sequence is never modified. A single class is
                paired internally with an ``'others'`` class, because a softmax
                over one class would always score 1.

        Returns:
            tuple: ``(None, VisImage)`` with the labeled segmentation drawn at
            the original resolution. If SAM2 proposes no masks, the image is
            returned without any segments drawn.

        Raises:
            ValueError: if ``class_names`` is empty.
        """
        if len(class_names) == 0:
            raise ValueError("class_names must contain at least one class name")
        # Private copy so appending 'others' never leaks into the caller's list.
        class_names = list(class_names)
        if len(class_names) == 1:
            class_names.append('others')
        has_others = len(class_names) == 2 and class_names[-1] == 'others'

        height, width, _ = ori_image.shape
        if width > height:
            new_width = self._INFERENCE_LONG_SIDE
            new_height = int((new_width / width) * height)
        else:
            new_height = self._INFERENCE_LONG_SIDE
            new_width = int((new_height / height) * width)
        image = cv2.resize(ori_image, (new_width, new_height))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        ori_image = cv2.cvtColor(ori_image, cv2.COLOR_BGR2RGB)
        visualizer = OpenVocabVisualizer(
            ori_image, self.metadata, instance_mode=self.instance_mode, class_names=class_names
        )

        with torch.no_grad():
            masks = self.predictor.generate(image)
        if not masks:
            return None, visualizer.output
        # (num_masks, h, w) stack of the SAM2 proposals.
        pred_masks = BitMasks(np.stack([m['segmentation'] for m in masks]))

        image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
        pixel_mean = torch.tensor(PIXEL_MEAN).view(-1, 1, 1)
        pixel_std = torch.tensor(PIXEL_STD).view(-1, 1, 1)
        image = ((image - pixel_mean) / pixel_std).unsqueeze(0)

        text = open_clip.tokenize([f'a photo of {cls_name}' for cls_name in class_names])
        with torch.no_grad():
            self.clip_model.cuda()
            text_features = self.clip_model.encode_text(text.cuda())
            text_features /= text_features.norm(dim=-1, keepdim=True)
            class_preds = self._classify_masks(image, pred_masks, text_features)
            select_cls = self._select_masks(class_preds, has_others)
            # (num_classes, h, w): per-pixel sum of the selected masks' class scores.
            semseg = torch.einsum("qc,qhw->chw", select_cls.float(), pred_masks.tensor.float().cuda())

        # Pixels covered by no selected mask score zero for every class. Mark
        # them 255 so draw_sem_seg skips them (holds while len(class_names) <= 255).
        # The mask is moved to the CPU so it lives on the same device as pred_mask.
        blank_area = (semseg == 0).all(dim=0).cpu()
        pred_mask = semseg.argmax(dim=0).cpu()
        pred_mask[blank_area] = 255
        pred_mask = pred_mask.numpy().astype(int)
        pred_mask = cv2.resize(pred_mask, (width, height), interpolation=cv2.INTER_NEAREST)
        return None, visualizer.draw_sem_seg(pred_mask)
class SAMPointVisualizationDemo(object):
    """Point-prompted demo.

    SAM2 segments the object under the clicked point(s). The resulting mask is
    classified against a fixed COCO + LVIS vocabulary using CLIP features
    pooled by ``mask_adapter``.

    Args:
        cfg: detectron2 config; ``cfg.DATASETS.TEST[0]`` (when present) selects
            the ``MetadataCatalog`` entry stored on ``self.metadata``.
        granularity: stored; not used by this class.
        sam2: SAM2 model, wrapped in a ``SAM2ImagePredictor``.
        clip_model: open_clip model whose ``visual.trunk`` exposes ``stem``,
            ``stages``, ``norm_pre`` and ``head``.
        mask_adapter: callable mapping (dense CLIP features, stacked binary
            masks) to semantic activation maps.
        instance_mode, parallel: stored; not used by this class.
    """

    # Number of consecutive activation maps averaged into one embedding per
    # mask (assumes the adapter emits this many maps per input mask -- TODO confirm).
    _MAPS_PER_MASK = 16
    _LVIS_CLASSES_PATH = "./mask_adapter/data/datasets/lvis_1203_with_prompt_eng.txt"
    _TEXT_EMBEDDING_PATH = "./text_embedding/lvis_coco_text_embedding.npy"
    # Overlay fill color in RGB order (a pinkish red).
    _MASK_COLOR = (234, 103, 112)

    def __init__(self, cfg, granularity, sam2, clip_model, mask_adapter, instance_mode=ColorMode.IMAGE, parallel=False):
        self.metadata = MetadataCatalog.get(
            cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused"
        )
        self.cpu_device = torch.device("cpu")
        self.instance_mode = instance_mode
        self.parallel = parallel
        self.granularity = granularity
        self.sam2 = sam2
        self.predictor = SAM2ImagePredictor(sam2)
        self.clip_model = clip_model
        self.mask_adapter = mask_adapter
        self.class_names = self._load_class_names()
        # Precomputed CLIP text embeddings, indexed like self.class_names
        # (assumption -- TODO confirm the .npy is regenerated whenever the
        # vocabulary changes).
        self.text_embedding = torch.from_numpy(np.load(self._TEXT_EMBEDDING_PATH)).to("cuda")

    def _load_class_names(self):
        """Return the vocabulary: COCO thing names, then every COCO panoptic
        category name (things included), then the LVIS names."""
        from .data.datasets import openseg_classes
        coco_categories = openseg_classes.get_coco_categories_with_prompt_eng()
        thing_classes = [k["name"] for k in coco_categories if k["isthing"] == 1]
        stuff_classes = [k["name"] for k in coco_categories]
        with open(self._LVIS_CLASSES_PATH, 'r', encoding='utf-8') as f:
            lvis_lines = f.read().splitlines()
        # Keep the text after the first ':' (presumably an index prefix); a line
        # without ':' is kept whole because find() then returns -1.
        lvis_classes = [line[line.find(':') + 1:] for line in lvis_lines]
        return thing_classes + stuff_classes + lvis_classes

    def extract_features_convnext(self, x):
        """Run the CLIP visual trunk on ``x`` and collect its feature maps.

        Returns a dict with ``'stem'``, ``'res2'``..``'res5'`` (outputs of
        stages 0-3) and ``'clip_vis_dense'`` (``norm_pre`` of the last stage).
        """
        out = {}
        x = self.clip_model.visual.trunk.stem(x)
        out['stem'] = x.contiguous()  # os4
        for i in range(4):
            x = self.clip_model.visual.trunk.stages[i](x)
            out[f'res{i+2}'] = x.contiguous()  # res 2 (os4), 3 (os8), 4 (os16), 5 (os32)
        x = self.clip_model.visual.trunk.norm_pre(x)
        out['clip_vis_dense'] = x.contiguous()
        return out

    def visual_prediction_forward_convnext(self, x):
        """Project pooled features ``x`` of shape (B, Q, C) through the trunk
        head and the visual head; returns (B, Q, D)."""
        batch, num_query, channel = x.shape
        x = x.reshape(batch * num_query, channel, 1, 1)  # fake 2D input
        x = self.clip_model.visual.trunk.head(x)
        x = self.clip_model.visual.head(x)
        return x.view(batch, num_query, x.shape[-1])  # B x num_queries x 640

    def visual_prediction_forward_convnext_2d(self, x):
        """Apply the head's norm/drop and the visual head at every spatial
        position of ``x`` (channels-first in and out, channels-last inside)."""
        clip_vis_dense = self.clip_model.visual.trunk.head.norm(x)
        clip_vis_dense = self.clip_model.visual.trunk.head.drop(clip_vis_dense.permute(0, 2, 3, 1))
        clip_vis_dense = self.clip_model.visual.head(clip_vis_dense).permute(0, 3, 1, 2)
        return clip_vis_dense

    def _classify_masks(self, image, pred_masks, text_features):
        """Return per-mask class probabilities of shape (num_masks, num_classes).

        ``image`` is the normalized (1, 3, H, W) tensor, ``pred_masks`` the
        ``BitMasks`` for that image, and ``text_features`` the text embeddings.
        Called under ``torch.no_grad()`` by ``run_on_image_with_points``.
        """
        features = self.extract_features_convnext(image.cuda().float())
        clip_feature = features['clip_vis_dense']
        clip_vis_dense = self.visual_prediction_forward_convnext_2d(clip_feature)
        semantic_activation_maps = self.mask_adapter(clip_vis_dense, pred_masks.tensor.unsqueeze(0).float().cuda())
        maps_for_pooling = F.interpolate(
            semantic_activation_maps, size=clip_feature.shape[-2:], mode='bilinear', align_corners=False
        )
        B, C = clip_feature.size(0), clip_feature.size(1)
        N = maps_for_pooling.size(1)
        num_instances = N // self._MAPS_PER_MASK
        # Turn each activation map into a weighting over spatial positions.
        maps_for_pooling = F.softmax(F.logsigmoid(maps_for_pooling).view(B, N, -1), dim=-1)
        pooled_clip_feature = torch.bmm(maps_for_pooling, clip_feature.view(B, C, -1).permute(0, 2, 1))
        pooled_clip_feature = self.visual_prediction_forward_convnext(pooled_clip_feature)
        pooled_clip_feature = (
            pooled_clip_feature.reshape(B, num_instances, self._MAPS_PER_MASK, -1).mean(dim=-2).contiguous()
        )
        class_preds = (100.0 * pooled_clip_feature @ text_features.T).softmax(dim=-1)
        return class_preds.squeeze(0)

    def run_on_image_with_points(self, ori_image, points):
        """Segment the object at ``points`` and draw its best-matching class.

        Args:
            ori_image (ndarray): BGR image of shape (H, W, 3).
            points: sequence of (x, y) pixel coordinates. Every point is
                passed to SAM2 with label 1 (a foreground click). The first
                point anchors the text label.

        Returns:
            tuple: ``(None, PIL.Image)``. The image is the input with the mask
            blended in, its outline in white, and a ``"<class>: <score>"`` label.

        Raises:
            ValueError: if ``points`` is empty.
        """
        if len(points) == 0:
            raise ValueError("run_on_image_with_points needs at least one point")
        height, width, _ = ori_image.shape
        rgb_image = cv2.cvtColor(ori_image, cv2.COLOR_BGR2RGB)

        input_point = np.array(points)
        # One label per point, so multi-click prompts are accepted too.
        input_label = np.ones(len(input_point), dtype=int)
        with torch.no_grad():
            self.predictor.set_image(rgb_image)
            # multimask_output=False -> a single mask, hence a single instance below.
            masks, _, _ = self.predictor.predict(
                point_coords=input_point, point_labels=input_label, multimask_output=False
            )
        pred_masks = BitMasks(masks)

        image = torch.as_tensor(rgb_image.astype("float32").transpose(2, 0, 1))
        pixel_mean = torch.tensor(PIXEL_MEAN).view(-1, 1, 1)
        pixel_std = torch.tensor(PIXEL_STD).view(-1, 1, 1)
        image = ((image - pixel_mean) / pixel_std).unsqueeze(0)

        with torch.no_grad():
            self.clip_model.cuda()
            class_preds = self._classify_masks(image, pred_masks, self.text_embedding)

        # Resize the mask to the original image size; nearest keeps it binary.
        pred_mask = cv2.resize(masks.squeeze(0), (width, height), interpolation=cv2.INTER_NEAREST)

        overlay = rgb_image.copy()
        mask_colored = np.zeros_like(rgb_image)
        mask_colored[pred_mask == 1] = self._MASK_COLOR
        alpha = 0.5
        # Blends the whole frame, so pixels outside the mask are also dimmed
        # to (1 - alpha) of their original brightness.
        cv2.addWeighted(mask_colored, alpha, overlay, 1 - alpha, 0, overlay)

        # White outline around the mask.
        contours, _ = cv2.findContours(pred_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(overlay, contours, -1, (255, 255, 255), 2)

        # Label with the highest-scoring class (class_preds has one row here).
        max_scores, max_score_idx = class_preds.max(dim=1)
        label = f"{self.class_names[max_score_idx.item()]}: {max_scores.item():.2f}"
        # cv2.putText requires integer coordinates. Offset the label from the
        # click and clamp it so it starts inside the frame.
        text_x = max(0, min(width - 200, int(points[0][0]) + 20))
        text_y = max(0, min(height - 30, int(points[0][1]) + 20))
        cv2.putText(overlay, label, (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
        return None, Image.fromarray(overlay)