# --------------------------------------------------------
# Semantic-SAM: Segment and Recognize Anything at Any Granularity
# Copyright (c) 2023 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Hao Zhang (hzhangcx@connect.ust.hk)
# --------------------------------------------------------

import torch
import torch.nn.functional as F
import numpy as np
from torchvision import transforms
from task_adapter.utils.visualizer import Visualizer
from typing import Tuple
from PIL import Image
from detectron2.data import MetadataCatalog
import matplotlib.pyplot as plt
import cv2
import io
from .automatic_mask_generator import SeemAutomaticMaskGenerator
metadata = MetadataCatalog.get('coco_2017_train_panoptic')

from segment_anything.utils.amg import (
    MaskData,
    area_from_rle,
    batch_iterator,
    batched_mask_to_box,
    box_xyxy_to_xywh,
    build_all_layer_point_grids,
    calculate_stability_score,
    coco_encode_rle,
    generate_crop_boxes,
    is_box_near_crop_edge,
    mask_to_rle_pytorch,
    remove_small_regions,
    rle_to_mask,
    uncrop_boxes_xyxy,
    uncrop_masks,
    uncrop_points,
)


def inference_seem_interactive(model, image, spatial_masks, text_size, label_mode='1', alpha=0.1, anno_mode=None):
    """Segment ``image`` with SEEM using mask prompts and draw the numbered results.

    Args:
        model: Wrapper whose ``.model`` provides ``evaluate_demo`` and accepts a
            ``metadata`` attribute.
        image: PIL image. It is resized with ``transforms.Resize(int(text_size))``,
            which matches the shorter edge to ``text_size``.
        spatial_masks: Prompt masks, presumably shaped (N, H, W). A channel dim is
            added and they are bicubically resized to the resized image before
            being thresholded at 0. TODO confirm the shape with callers.
        text_size: Target size of the shorter image edge.
        label_mode: Forwarded to ``Visualizer.draw_binary_mask_with_number``.
        alpha: Mask overlay opacity forwarded to the visualizer.
        anno_mode: Annotation modes forwarded to the visualizer. ``None`` (the
            default) means ``['Mask']``.

    Returns:
        ``(im, sorted_anns)``:
        - ``im`` is the visualizer's rendered image with every mask drawn and
          labelled 1..N, largest area first.
        - ``sorted_anns`` is the list of SAM-style annotation dicts, sorted by
          ``area`` in descending order.

    Requires a CUDA device: the image and prompts are moved with ``.cuda()``.
    """
    if anno_mode is None:
        # Avoid a shared mutable default list.
        anno_mode = ['Mask']

    resize = transforms.Compose([transforms.Resize(int(text_size), interpolation=Image.BICUBIC)])
    image_ori = np.asarray(resize(image))
    # HWC uint8 array -> CHW tensor. The copy is needed because np.asarray of a PIL image may be read-only.
    images = torch.from_numpy(image_ori.copy()).permute(2, 0, 1).cuda()

    orig_h, orig_w = images.shape[-2:]
    crop_box = [0, 0, orig_w, orig_h]
    data = {"image": images, "height": orig_h, "width": orig_w}

    # Bring the prompt masks to the resized image resolution and binarise them.
    spatial_masks = spatial_masks[:, None].float().cuda()
    spatial_masks = F.interpolate(spatial_masks, size=(orig_h, orig_w), mode='bicubic', align_corners=False) > 0
    data['spatial_query'] = {'rand_shape': spatial_masks}

    model.model.metadata = metadata
    masks, _ = model.model.evaluate_demo([data])
    masks = masks > 0.0

    # evaluate_demo's second output is discarded, so the SAM-style score/point
    # fields are filled with neutral placeholders.
    num_masks = masks.shape[0]
    mask_data = MaskData(
        masks=masks,
        iou_preds=torch.ones(num_masks, dtype=torch.float32),
        points=torch.zeros((num_masks, 2), dtype=torch.float32),
    )
    mask_data["stability_score"] = torch.ones(num_masks, dtype=torch.float32)
    del masks

    mask_data["boxes"] = batched_mask_to_box(mask_data["masks"])
    mask_data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(mask_data["boxes"]))])

    # Compress to RLE, then decode back to dense per-mask arrays for the records.
    mask_data["masks"] = uncrop_masks(mask_data["masks"], crop_box, orig_h, orig_w)
    mask_data["rles"] = mask_to_rle_pytorch(mask_data["masks"])
    del mask_data["masks"]
    mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]

    # Write mask records
    outputs = []
    for idx, (segmentation, rle) in enumerate(zip(mask_data["segmentations"], mask_data["rles"])):
        outputs.append({
            "segmentation": segmentation,
            "area": area_from_rle(rle),
            "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
            "predicted_iou": mask_data["iou_preds"][idx].item(),
            "point_coords": [mask_data["points"][idx].tolist()],
            "stability_score": mask_data["stability_score"][idx].item(),
            "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
        })

    visual = Visualizer(image_ori, metadata=metadata)
    sorted_anns = sorted(outputs, key=lambda ann: ann['area'], reverse=True)

    # Draw largest masks first so smaller ones stay visible on top; labels are 1-based.
    demo = None
    for label, ann in enumerate(sorted_anns, start=1):
        demo = visual.draw_binary_mask_with_number(
            ann['segmentation'], text=str(label), label_mode=label_mode, alpha=alpha, anno_mode=anno_mode)

    if demo is None:
        # No masks were predicted. Return the undecorated canvas rather than
        # failing on an unbound ``demo``. This assumes the Visualizer keeps its
        # canvas in ``.output``, as detectron2's Visualizer does.
        # TODO confirm for task_adapter's copy.
        im = visual.output.get_image()
    else:
        im = demo.get_image()
    return im, sorted_anns


def remove_small_regions(
    mask: np.ndarray, area_thresh: float, mode: str
) -> Tuple[np.ndarray, bool]:
    """
    Removes small disconnected regions and holes in a mask. Returns the
    mask and an indicator of if the mask has been modified.

    Args:
        mask: 2-D mask array, presumably boolean. It is XOR-ed with a bool and
            passed to ``cv2.connectedComponentsWithStats``.
        area_thresh: Components with fewer pixels than this are "small".
        mode: ``"holes"`` fills small holes; ``"islands"`` drops small islands.

    Returns:
        ``(mask, changed)``. When ``changed`` is False the input mask is
        returned unchanged; otherwise a new boolean mask is returned.

    Note: this definition shadows the ``remove_small_regions`` imported from
    ``segment_anything.utils.amg`` above.
    """
    assert mode in ["holes", "islands"]
    correct_holes = mode == "holes"
    # In "holes" mode the mask is inverted so that holes become foreground components.
    working_mask = (correct_holes ^ mask).astype(np.uint8)
    n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
    sizes = stats[:, -1][1:]  # Row 0 is background label; last stat column is the area
    small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
    if not small_regions:
        return mask, False
    fill_labels = [0] + small_regions
    if not correct_holes:
        excluded = set(fill_labels)
        fill_labels = [i for i in range(n_labels) if i not in excluded]
        # If every region is below threshold, keep largest
        if not fill_labels:
            fill_labels = [int(np.argmax(sizes)) + 1]
    mask = np.isin(regions, fill_labels)
    return mask, True


def show_anns(anns):
    """Overlay annotation masks on the current matplotlib axes.

    Masks are drawn largest ``area`` first. Each one is filled with a random
    RGB colour at 0.35 opacity, taken from its ``segmentation`` array.
    Does nothing when ``anns`` is empty.
    """
    if not anns:
        return
    ax = plt.gca()
    ax.set_autoscale_on(False)
    for ann in sorted(anns, key=lambda x: x['area'], reverse=True):
        m = ann['segmentation']
        img = np.ones((m.shape[0], m.shape[1], 3))
        color_mask = np.random.random((1, 3)).tolist()[0]
        img[:, :] = color_mask  # broadcast the RGB colour over every pixel
        ax.imshow(np.dstack((img, m * 0.35)))