# mask_adapter/sam_maskadapter.py
import numpy as np
import torch
from torch.nn import functional as F
import cv2
from detectron2.data import MetadataCatalog
from detectron2.structures import BitMasks
from detectron2.utils.visualizer import ColorMode, Visualizer
import open_clip
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from .modeling.meta_arch.mask_adapter_head import build_mask_adapter
from sam2.sam2_image_predictor import SAM2ImagePredictor
from PIL import Image
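# OpenAI CLIP normalization statistics, scaled from the [0, 1] range to 0-255 pixels.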
PIXEL_MEAN = [122.7709383, 116.7460125, 104.09373615]
PIXEL_STD = [68.5005327, 66.6321579, 70.32316305]
class OpenVocabVisualizer(Visualizer):
def __init__(self, img_rgb, metadata=None, scale=1.0, instance_mode=ColorMode.IMAGE, class_names=None):
super().__init__(img_rgb, metadata, scale, instance_mode)
self.class_names = class_names
def draw_sem_seg(self, sem_seg, area_threshold=None, alpha=0.6):
"""
Draw semantic segmentation predictions/labels.
Args:
sem_seg (Tensor or ndarray): the segmentation of shape (H, W).
Each value is the integer label of the pixel.
            area_threshold (int): segments with area smaller than `area_threshold` are not drawn.
alpha (float): the larger it is, the more opaque the segmentations are.
Returns:
output (VisImage): image object with visualizations.
"""
        if isinstance(sem_seg, torch.Tensor):
            sem_seg = sem_seg.cpu().numpy()  # move to CPU first so GPU tensors don't raise
labels, areas = np.unique(sem_seg, return_counts=True)
sorted_idxs = np.argsort(-areas).tolist()
labels = labels[sorted_idxs]
class_names = self.class_names if self.class_names is not None else self.metadata.stuff_classes
for label in filter(lambda l: l < len(class_names), labels):
try:
mask_color = [x / 255 for x in self.metadata.stuff_colors[label]]
except (AttributeError, IndexError):
mask_color = None
binary_mask = (sem_seg == label).astype(np.uint8)
text = class_names[label]
self.draw_binary_mask(
binary_mask,
color=mask_color,
edge_color=(1.0, 1.0, 240.0 / 255),
text=text,
alpha=alpha,
area_threshold=area_threshold,
)
return self.output
class SAMVisualizationDemo(object):
    def __init__(self, cfg, granularity, sam2, clip_model, mask_adapter, instance_mode=ColorMode.IMAGE, parallel=False):
self.metadata = MetadataCatalog.get(
cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused"
)
self.cpu_device = torch.device("cpu")
self.instance_mode = instance_mode
self.parallel = parallel
self.granularity = granularity
self.sam2 = sam2
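        # Automatic mask proposals over a point grid; these thresholds trade recall for demo speed.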
self.predictor = SAM2AutomaticMaskGenerator(sam2, points_per_batch=16,
pred_iou_thresh=0.8,
stability_score_thresh=0.7,
crop_n_layers=0,
crop_n_points_downscale_factor=2,
min_mask_region_area=100)
self.clip_model = clip_model
self.mask_adapter = mask_adapter
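    # Run the OpenCLIP ConvNeXt trunk stage by stage, keeping every intermediate
    # feature map; 'clip_vis_dense' is the final dense feature used for mask pooling.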
def extract_features_convnext(self, x):
out = {}
x = self.clip_model.visual.trunk.stem(x)
out['stem'] = x.contiguous() # os4
for i in range(4):
x = self.clip_model.visual.trunk.stages[i](x)
out[f'res{i+2}'] = x.contiguous() # res 2 (os4), 3 (os8), 4 (os16), 5 (os32)
x = self.clip_model.visual.trunk.norm_pre(x)
out['clip_vis_dense'] = x.contiguous()
return out
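    # Project pooled per-mask features through the CLIP head by treating each
    # query vector as a 1x1 feature map.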
def visual_prediction_forward_convnext(self, x):
batch, num_query, channel = x.shape
x = x.reshape(batch*num_query, channel, 1, 1) # fake 2D input
x = self.clip_model.visual.trunk.head(x)
x = self.clip_model.visual.head(x)
return x.view(batch, num_query, x.shape[-1]) # B x num_queries x 640
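    # Same projection as above, applied densely over the full 2D feature map
    # (channels-last for the head, then back to channels-first).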
def visual_prediction_forward_convnext_2d(self, x):
clip_vis_dense = self.clip_model.visual.trunk.head.norm(x)
clip_vis_dense = self.clip_model.visual.trunk.head.drop(clip_vis_dense.permute(0, 2, 3, 1))
clip_vis_dense = self.clip_model.visual.head(clip_vis_dense).permute(0, 3, 1, 2)
return clip_vis_dense
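    # Full open-vocabulary pipeline: SAM2 mask proposals -> mask-adapter pooling of
    # CLIP features -> per-mask classification against the given class names.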
def run_on_image(self, ori_image, class_names):
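        # Resize so the longer side is 896 px, preserving aspect ratio, before SAM2.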
height, width, _ = ori_image.shape
if width > height:
new_width = 896
new_height = int((new_width / width) * height)
else:
new_height = 896
new_width = int((new_height / height) * width)
image = cv2.resize(ori_image, (new_width, new_height))
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
ori_image = cv2.cvtColor(ori_image, cv2.COLOR_BGR2RGB)
visualizer = OpenVocabVisualizer(ori_image, self.metadata, instance_mode=self.instance_mode, class_names=class_names)
        with torch.no_grad():
masks = self.predictor.generate(image)
            pred_masks = [masks[i]['segmentation'][None, :, :] for i in range(len(masks))]
            pred_masks = np.vstack(pred_masks)  # np.row_stack is a deprecated alias, removed in NumPy 2.0
pred_masks = BitMasks(pred_masks)
image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
pixel_mean = torch.tensor(PIXEL_MEAN).view(-1, 1, 1)
pixel_std = torch.tensor(PIXEL_STD).view(-1, 1, 1)
image = (image - pixel_mean) / pixel_std
image = image.unsqueeze(0)
        if len(class_names) == 1:
            # CLIP needs at least one contrast class; add a generic 'others' without mutating the caller's list.
            class_names = class_names + ['others']
txts = [f'a photo of {cls_name}' for cls_name in class_names]
text = open_clip.tokenize(txts)
with torch.no_grad():
self.clip_model.cuda()
text_features = self.clip_model.encode_text(text.cuda())
text_features /= text_features.norm(dim=-1, keepdim=True)
features = self.extract_features_convnext(image.cuda().float())
clip_feature = features['clip_vis_dense']
clip_vis_dense = self.visual_prediction_forward_convnext_2d(clip_feature)
semantic_activation_maps = self.mask_adapter(clip_vis_dense, pred_masks.tensor.unsqueeze(0).float().cuda())
            maps_for_pooling = F.interpolate(semantic_activation_maps, size=clip_feature.shape[-2:],
                                             mode='bilinear', align_corners=False)
            B, C = clip_feature.size(0), clip_feature.size(1)
            N = maps_for_pooling.size(1)
            num_instances = N // 16  # the mask adapter emits 16 semantic activation maps per mask
            maps_for_pooling = F.softmax(F.logsigmoid(maps_for_pooling).view(B, N, -1), dim=-1)
            pooled_clip_feature = torch.bmm(maps_for_pooling, clip_feature.view(B, C, -1).permute(0, 2, 1))
            pooled_clip_feature = self.visual_prediction_forward_convnext(pooled_clip_feature)
            pooled_clip_feature = pooled_clip_feature.reshape(B, num_instances, 16, -1).mean(dim=-2).contiguous()
class_preds = (100.0 * pooled_clip_feature @ text_features.T).softmax(dim=-1)
class_preds = class_preds.squeeze(0)
            select_cls = torch.zeros_like(class_preds)
            # For each class, keep only the highest-scoring mask.
            max_scores, select_mask = torch.max(class_preds, dim=0)
            if len(class_names) == 2 and class_names[-1] == 'others':
                select_mask = select_mask[:-1]  # drop the synthetic 'others' class
            if self.granularity < 1:
                # Relaxed selection: keep every mask scoring above granularity * max for its class.
                thr_scores = max_scores * self.granularity
                select_mask = []
                if len(class_names) == 2 and class_names[-1] == 'others':
                    thr_scores = thr_scores[:-1]
                for i, thr in enumerate(thr_scores):
                    cls_pred = class_preds[:, i]
                    locs = torch.where(cls_pred > thr)
                    select_mask.extend(locs[0].tolist())
            for idx in select_mask:
                select_cls[idx] = class_preds[idx]
            semseg = torch.einsum("qc,qhw->chw", select_cls.float(), pred_masks.tensor.float().cuda())
            r = semseg
            blank_area = (r[0] == 0)
            pred_mask = r.argmax(dim=0).to('cpu')
            pred_mask[blank_area] = 255  # 255 marks pixels covered by no selected mask
            pred_mask = np.array(pred_mask, dtype=np.uint8)  # cv2.resize rejects int64 arrays; uint8 suffices for the demo's class count
            pred_mask = cv2.resize(pred_mask, (width, height), interpolation=cv2.INTER_NEAREST)
vis_output = visualizer.draw_sem_seg(
pred_mask
)
return None, vis_output
class SAMPointVisualizationDemo(object):
    def __init__(self, cfg, granularity, sam2, clip_model, mask_adapter, instance_mode=ColorMode.IMAGE, parallel=False):
self.metadata = MetadataCatalog.get(
cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused"
)
self.cpu_device = torch.device("cpu")
self.instance_mode = instance_mode
self.parallel = parallel
self.granularity = granularity
self.sam2 = sam2
self.predictor = SAM2ImagePredictor(sam2)
self.clip_model = clip_model
self.mask_adapter = mask_adapter
        # Class names (COCO panoptic + LVIS vocabularies) are loaded once by the helper below.
        self.class_names = self._load_class_names()
        # Precomputed CLIP text embeddings for the combined vocabulary.
        self.text_embedding = torch.from_numpy(np.load("./text_embedding/lvis_coco_text_embedding.npy")).to("cuda")
    def _load_class_names(self):
        from .data.datasets import openseg_classes
        COCO_CATEGORIES_pan = openseg_classes.get_coco_categories_with_prompt_eng()
        thing_classes = [k["name"] for k in COCO_CATEGORIES_pan if k["isthing"] == 1]
        stuff_classes = [k["name"] for k in COCO_CATEGORIES_pan]
        with open("./mask_adapter/data/datasets/lvis_1203_with_prompt_eng.txt", 'r') as f:
            lvis_classes = f.read().splitlines()
        lvis_classes = [x[x.find(':') + 1:] for x in lvis_classes]  # drop the leading 'id:' prefix
        return thing_classes + stuff_classes + lvis_classes
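    # The three helpers below mirror those on SAMVisualizationDemo above.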
def extract_features_convnext(self, x):
out = {}
x = self.clip_model.visual.trunk.stem(x)
out['stem'] = x.contiguous() # os4
for i in range(4):
x = self.clip_model.visual.trunk.stages[i](x)
out[f'res{i+2}'] = x.contiguous() # res 2 (os4), 3 (os8), 4 (os16), 5 (os32)
x = self.clip_model.visual.trunk.norm_pre(x)
out['clip_vis_dense'] = x.contiguous()
return out
def visual_prediction_forward_convnext(self, x):
batch, num_query, channel = x.shape
x = x.reshape(batch*num_query, channel, 1, 1) # fake 2D input
x = self.clip_model.visual.trunk.head(x)
x = self.clip_model.visual.head(x)
return x.view(batch, num_query, x.shape[-1]) # B x num_queries x 640
def visual_prediction_forward_convnext_2d(self, x):
clip_vis_dense = self.clip_model.visual.trunk.head.norm(x)
clip_vis_dense = self.clip_model.visual.trunk.head.drop(clip_vis_dense.permute(0, 2, 3, 1))
clip_vis_dense = self.clip_model.visual.head(clip_vis_dense).permute(0, 3, 1, 2)
return clip_vis_dense
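    # Segment the object at a clicked point with SAM2, then classify the resulting
    # mask by mask-pooling CLIP features against the precomputed text embeddings.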
def run_on_image_with_points(self, ori_image, points):
height, width, _ = ori_image.shape
image = ori_image
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
ori_image = cv2.cvtColor(ori_image, cv2.COLOR_BGR2RGB)
input_point = np.array(points)
input_label = np.array([1])
with torch.no_grad():
self.predictor.set_image(image)
masks, _, _ = self.predictor.predict(point_coords=input_point, point_labels=input_label, multimask_output=False)
pred_masks = BitMasks(masks)
image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
pixel_mean = torch.tensor(PIXEL_MEAN).view(-1, 1, 1)
pixel_std = torch.tensor(PIXEL_STD).view(-1, 1, 1)
image = (image - pixel_mean) / pixel_std
image = image.unsqueeze(0)
        with torch.no_grad():
            self.clip_model.cuda()
            # Text features for the fixed vocabulary are precomputed offline, so the
            # open_clip tokenize/encode_text step is skipped at inference time.
            text_features = self.text_embedding
features = self.extract_features_convnext(image.cuda().float())
clip_feature = features['clip_vis_dense']
clip_vis_dense = self.visual_prediction_forward_convnext_2d(clip_feature)
semantic_activation_maps = self.mask_adapter(clip_vis_dense, pred_masks.tensor.unsqueeze(0).float().cuda())
maps_for_pooling = F.interpolate(semantic_activation_maps, size=clip_feature.shape[-2:], mode='bilinear', align_corners=False)
            B, C = clip_feature.size(0), clip_feature.size(1)
            N = maps_for_pooling.size(1)
            num_instances = N // 16  # 16 semantic activation maps per mask, as above
            maps_for_pooling = F.softmax(F.logsigmoid(maps_for_pooling).view(B, N, -1), dim=-1)
            pooled_clip_feature = torch.bmm(maps_for_pooling, clip_feature.view(B, C, -1).permute(0, 2, 1))
            pooled_clip_feature = self.visual_prediction_forward_convnext(pooled_clip_feature)
            pooled_clip_feature = pooled_clip_feature.reshape(B, num_instances, 16, -1).mean(dim=-2).contiguous()
class_preds = (100.0 * pooled_clip_feature @ text_features.T).softmax(dim=-1)
class_preds = class_preds.squeeze(0)
        # Resize the predicted mask to the original image size.
        pred_mask = cv2.resize(masks.squeeze(0), (width, height), interpolation=cv2.INTER_NEAREST)
        # Create an overlay for the mask with alpha transparency.
        overlay = ori_image.copy()
        mask_colored = np.zeros_like(ori_image)
        mask_colored[pred_mask == 1] = [234, 103, 112]  # soft-red mask color (RGB)
# Apply the mask with transparency (alpha blending)
alpha = 0.5
cv2.addWeighted(mask_colored, alpha, overlay, 1 - alpha, 0, overlay)
# Draw boundary (contours) on the overlay
contours, _ = cv2.findContours(pred_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
cv2.drawContours(overlay, contours, -1, (255, 255, 255), 2) # White boundary
# Add label based on the class with the highest score
max_scores, max_score_idx = class_preds.max(dim=1) # Find the max score across the class predictions
label = f"{self.class_names[max_score_idx.item()]}: {max_scores.item():.2f}"
        # Place the label near the clicked point, clamped to stay inside the image.
        text_x = int(min(width - 200, points[0][0] + 20))
        text_y = int(min(height - 30, points[0][1] + 20))
# Put text near the point
cv2.putText(overlay, label, (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
return None, Image.fromarray(overlay)
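
if __name__ == "__main__":
    # Minimal usage sketch. This block is illustrative and not part of the original
    # demo: the config setup, checkpoint paths, pretrained tags, and the
    # build_mask_adapter arguments below are assumptions; adapt them to your setup.
    from detectron2.config import get_cfg

    cfg = get_cfg()  # a real run would merge the Mask-Adapter demo config here

    # Hypothetical SAM2 config/checkpoint paths.
    sam2 = build_sam2("sam2_hiera_l.yaml", "./checkpoints/sam2_hiera_large.pt")
    clip_model, _, _ = open_clip.create_model_and_transforms(
        "convnext_large_d_320", pretrained="laion2b_s29b_b131k_ft_soup"
    )
    clip_model = clip_model.cuda().eval()
    mask_adapter = build_mask_adapter(cfg, "MASKAdapterHead").cuda().eval()  # head name is an assumption

    demo = SAMVisualizationDemo(cfg, granularity=0.9, sam2=sam2,
                                clip_model=clip_model, mask_adapter=mask_adapter)
    img = cv2.imread("input.jpg")  # BGR image, as run_on_image expects
    _, vis_output = demo.run_on_image(img, ["cat", "dog"])
    vis_output.save("output.jpg")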