# Adapted from https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
import argparse
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import cv2
import numpy as np
import requests
import torch
from PIL import Image
from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline
def create_palette():
    # Define a palette with 24 colors for labels 0-23 (example colors)
    palette = [
        0, 0, 0,        # Label 0 (black)
        255, 0, 0,      # Label 1 (red)
        0, 255, 0,      # Label 2 (green)
        0, 0, 255,      # Label 3 (blue)
        255, 255, 0,    # Label 4 (yellow)
        255, 0, 255,    # Label 5 (magenta)
        0, 255, 255,    # Label 6 (cyan)
        128, 0, 0,      # Label 7 (dark red)
        0, 128, 0,      # Label 8 (dark green)
        0, 0, 128,      # Label 9 (dark blue)
        128, 128, 0,    # Label 10 (olive)
        128, 0, 128,    # Label 11 (purple)
        0, 128, 128,    # Label 12 (teal)
        64, 0, 0,       # Label 13
        0, 64, 0,       # Label 14
        0, 0, 64,       # Label 15
        64, 64, 0,      # Label 16
        64, 0, 64,      # Label 17
        0, 64, 64,      # Label 18
        192, 192, 192,  # Label 19 (light gray)
        128, 128, 128,  # Label 20 (gray)
        255, 165, 0,    # Label 21 (orange)
        75, 0, 130,     # Label 22 (indigo)
        238, 130, 238,  # Label 23 (violet)
    ]
    # Extend the palette to 768 values (256 colors * 3 channels);
    # all labels >= 24 therefore map to black
    palette.extend([0] * (768 - len(palette)))
    return palette


PALETTE = create_palette()
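
# Quick sanity check of the palette layout (a minimal sketch, not part of the
# pipeline): entry i occupies bytes 3*i..3*i+3, so label 1 should map to red
# and any label >= 24 to the black padding.
assert PALETTE[3:6] == [255, 0, 0]   # label 1 -> red
assert PALETTE[72:75] == [0, 0, 0]   # label 24 -> padded black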

# Result Utils
@dataclass
class BoundingBox:
    xmin: int
    ymin: int
    xmax: int
    ymax: int

    @property
    def xyxy(self) -> List[float]:
        return [self.xmin, self.ymin, self.xmax, self.ymax]


@dataclass
class DetectionResult:
    score: Optional[float] = None
    label: Optional[str] = None
    box: Optional[BoundingBox] = None
    mask: Optional[np.ndarray] = None

    @classmethod
    def from_dict(cls, detection_dict: Dict) -> "DetectionResult":
        return cls(
            score=detection_dict["score"],
            label=detection_dict["label"],
            box=BoundingBox(
                xmin=detection_dict["box"]["xmin"],
                ymin=detection_dict["box"]["ymin"],
                xmax=detection_dict["box"]["xmax"],
                ymax=detection_dict["box"]["ymax"],
            ),
        )
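
# Illustrative only: the zero-shot-object-detection pipeline emits dicts of
# this shape, which from_dict converts into typed results.
#   d = {"score": 0.42, "label": "a cat.", "box": {"xmin": 1, "ymin": 2, "xmax": 3, "ymax": 4}}
#   r = DetectionResult.from_dict(d)
#   r.box.xyxy  # -> [1, 2, 3, 4]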

# Utils
def mask_to_polygon(mask: np.ndarray) -> List[List[int]]:
    # Find contours in the binary mask
    contours, _ = cv2.findContours(
        mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    # Guard against empty masks, where max() would raise on an empty sequence
    if not contours:
        return []
    # Keep only the contour with the largest area
    largest_contour = max(contours, key=cv2.contourArea)
    # Extract the vertices of the contour as a list of [x, y] pairs
    polygon = largest_contour.reshape(-1, 2).tolist()
    return polygon

def polygon_to_mask(
    polygon: List[Tuple[int, int]], image_shape: Tuple[int, int]
) -> np.ndarray:
    """
    Convert a polygon to a segmentation mask.

    Args:
    - polygon (list): List of (x, y) coordinates representing the vertices of the polygon.
    - image_shape (tuple): Shape of the image (height, width) for the mask.

    Returns:
    - np.ndarray: Segmentation mask with the polygon filled.
    """
    # Create an empty mask
    mask = np.zeros(image_shape, dtype=np.uint8)
    # An empty polygon yields an empty mask
    if not polygon:
        return mask
    # Convert polygon to an array of points
    pts = np.array(polygon, dtype=np.int32)
    # Fill the polygon with white (255)
    cv2.fillPoly(mask, [pts], color=(255,))
    return mask
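
# A quick round-trip sketch (hypothetical values, not executed by the script):
#   poly = mask_to_polygon(some_binary_mask)
#   filled = polygon_to_mask(poly, some_binary_mask.shape)
# The round trip keeps only the largest connected region and fills any holes,
# which is what the polygon_refinement option below relies on.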

def load_image(image_str: str) -> Image.Image:
    # Accept either a URL or a local file path
    if image_str.startswith("http"):
        image = Image.open(requests.get(image_str, stream=True).raw).convert("RGB")
    else:
        image = Image.open(image_str).convert("RGB")
    return image

def get_boxes(results: List[DetectionResult]) -> List[List[List[float]]]:
    boxes = []
    for result in results:
        xyxy = result.box.xyxy
        boxes.append(xyxy)
    # SAM's processor expects boxes nested per image: [batch][boxes][4]
    return [boxes]
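
# For a single image with two detections this yields, e.g.:
#   [[[10, 20, 110, 220], [5, 5, 60, 90]]]
# (one outer list because SAM is always run here on a batch of one image).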

def refine_masks(
    masks: torch.BoolTensor, polygon_refinement: bool = False
) -> List[np.ndarray]:
    # (num_boxes, num_channels, H, W) -> (num_boxes, H, W, num_channels)
    masks = masks.cpu().float()
    masks = masks.permute(0, 2, 3, 1)
    # Average over the channel dimension, then binarize
    masks = masks.mean(dim=-1)
    masks = (masks > 0).int()
    masks = masks.numpy().astype(np.uint8)
    masks = list(masks)
    if polygon_refinement:
        # Replace each mask by the filled polygon of its largest contour
        for idx, mask in enumerate(masks):
            shape = mask.shape
            polygon = mask_to_polygon(mask)
            mask = polygon_to_mask(polygon, shape)
            masks[idx] = mask
    return masks
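
# Indicative shapes (assuming SAM's multimask output for one image): a
# (num_boxes, num_proposals, H, W) tensor goes in; a list of num_boxes uint8
# arrays of shape (H, W) with values in {0, 1} comes out. Since the inputs are
# binary, mean > 0 marks a pixel wherever any proposal does, i.e. a union.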

# Post-processing Utils
def generate_colored_segmentation(label_image: np.ndarray) -> Image.Image:
    # Create a PIL Image from the label image (assuming it's a 2D numpy array)
    label_image_pil = Image.fromarray(label_image.astype(np.uint8), mode="P")
    # Apply the module-level palette to the image
    label_image_pil.putpalette(PALETTE)
    return label_image_pil


def plot_segmentation(
    image: Image.Image, detections: List[DetectionResult]
) -> Image.Image:
    # PIL size is (width, height); the mask array needs (height, width)
    seg_map = np.zeros(image.size[::-1], dtype=np.uint8)
    for i, detection in enumerate(detections):
        mask = detection.mask
        seg_map[mask > 0] = i + 1  # label 0 is reserved for background
    seg_map_pil = generate_colored_segmentation(seg_map)
    return seg_map_pil

# Grounded SAM
def prepare_model(
    device: str = "cuda",
    detector_id: Optional[str] = None,
    segmenter_id: Optional[str] = None,
):
    detector_id = (
        detector_id if detector_id is not None else "IDEA-Research/grounding-dino-tiny"
    )
    object_detector = pipeline(
        model=detector_id, task="zero-shot-object-detection", device=device
    )
    segmenter_id = segmenter_id if segmenter_id is not None else "facebook/sam-vit-base"
    processor = AutoProcessor.from_pretrained(segmenter_id)
    segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to(device)
    return object_detector, processor, segmentator
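
# Typical setup (a sketch; the checkpoints shown are the defaults above and
# can be swapped for any Grounding DINO / SAM checkpoint on the Hub):
#   detector, processor, sam = prepare_model(device="cuda")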

def detect(
    object_detector: Any,
    image: Image.Image,
    labels: List[str],
    threshold: float = 0.3,
) -> List[DetectionResult]:
    """
    Use Grounding DINO to detect a set of labels in an image in a zero-shot fashion.
    """
    # Grounding DINO expects each text query to end with a period
    labels = [label if label.endswith(".") else label + "." for label in labels]
    results = object_detector(image, candidate_labels=labels, threshold=threshold)
    results = [DetectionResult.from_dict(result) for result in results]
    return results
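
# Example call (a sketch; "cats.png" is a hypothetical local file):
#   detections = detect(detector, load_image("cats.png"), ["a cat"], threshold=0.3)
#   [(d.label, round(d.score, 2)) for d in detections]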

def segment(
    processor: Any,
    segmentator: Any,
    image: Image.Image,
    boxes: Optional[List[List[List[float]]]] = None,
    detection_results: Optional[List[DetectionResult]] = None,
    polygon_refinement: bool = False,
) -> List[DetectionResult]:
    """
    Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes.
    """
    if detection_results is None and boxes is None:
        raise ValueError("Either detection_results or boxes must be provided.")
    if boxes is None:
        boxes = get_boxes(detection_results)
    inputs = processor(images=image, input_boxes=boxes, return_tensors="pt").to(
        segmentator.device, segmentator.dtype
    )
    with torch.no_grad():
        outputs = segmentator(**inputs)
    masks = processor.post_process_masks(
        masks=outputs.pred_masks,
        original_sizes=inputs.original_sizes,
        reshaped_input_sizes=inputs.reshaped_input_sizes,
    )[0]
    masks = refine_masks(masks, polygon_refinement)
    # Attach a mask to each detection (or create bare results when only boxes were given)
    if detection_results is None:
        detection_results = [DetectionResult() for _ in masks]
    for detection_result, mask in zip(detection_results, masks):
        detection_result.mask = mask
    return detection_results

def grounded_segmentation(
    object_detector,
    processor,
    segmentator,
    image: Union[Image.Image, str],
    labels: Union[str, List[str]],
    threshold: float = 0.3,
    polygon_refinement: bool = False,
) -> Tuple[np.ndarray, List[DetectionResult], Image.Image]:
    if isinstance(image, str):
        image = load_image(image)
    if isinstance(labels, str):
        labels = labels.split(",")
    detections = detect(object_detector, image, labels, threshold)
    # Pass by keyword: the fourth positional slot of segment() is `boxes`,
    # so a positional call would silently misbind the arguments
    detections = segment(
        processor,
        segmentator,
        image,
        detection_results=detections,
        polygon_refinement=polygon_refinement,
    )
    seg_map_pil = plot_segmentation(image, detections)
    return np.array(image), detections, seg_map_pil
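
# End-to-end sketch (URL and labels are placeholders):
#   img, dets, seg = grounded_segmentation(
#       detector, processor, sam,
#       image="http://images.cocodataset.org/val2017/000000039769.jpg",
#       labels=["a cat", "a remote control"],
#       polygon_refinement=True,
#   )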

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--image", type=str, required=True)
    parser.add_argument("--labels", type=str, nargs="+", required=True)
    parser.add_argument("--output", type=str, default="./", help="Output directory")
    parser.add_argument("--threshold", type=float, default=0.3)
    parser.add_argument(
        "--detector_id", type=str, default="IDEA-Research/grounding-dino-base"
    )
    parser.add_argument("--segmenter_id", type=str, default="facebook/sam-vit-base")
    args = parser.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    object_detector, processor, segmentator = prepare_model(
        device=device, detector_id=args.detector_id, segmenter_id=args.segmenter_id
    )
    image_array, detections, seg_map_pil = grounded_segmentation(
        object_detector,
        processor,
        segmentator,
        image=args.image,
        labels=args.labels,
        threshold=args.threshold,
        polygon_refinement=True,
    )
    os.makedirs(args.output, exist_ok=True)
    seg_map_pil.save(os.path.join(args.output, "segmentation.png"))
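
# Example invocation (script name, image path, and labels are illustrative):
#   python grounded_sam.py --image cats.jpg --labels "a cat" "a remote control" \
#       --output ./out --threshold 0.3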