""" Copyright (c) 2024-present Naver Cloud Corp. This source code is based on code from the Segment Anything Model (SAM) (https://github.com/facebookresearch/segment-anything). This source code is licensed under the license found in the LICENSE file in the root directory of this source tree. """ import os, sys sys.path.append(os.getcwd()) # Gradio demo, comparison SAM vs ZIM import os import torch import gradio as gr from gradio_image_prompter import ImagePrompter import numpy as np import cv2 from zim import zim_model_registry, ZimPredictor, ZimAutomaticMaskGenerator from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator from zim.utils import show_mat_anns from huggingface_hub import hf_hub_download def get_shortest_axis(image): h, w, _ = image.shape return h if h < w else w def reset_image(image, prompts): if image is None: image = np.zeros((1024, 1024, 3), dtype=np.uint8) else: image = image['image'] zim_predictor.set_image(image) sam_predictor.set_image(image) prompts = dict() black = np.zeros(image.shape[:2], dtype=np.uint8) return (image, image, image, image, black, black, black, black, prompts) def reset_example_image(image, prompts): if image is None: image = np.zeros((1024, 1024, 3), dtype=np.uint8) zim_predictor.set_image(image) sam_predictor.set_image(image) prompts = dict() black = np.zeros(image.shape[:2], dtype=np.uint8) image_dict = {} image_dict['image'] = image image_dict['prompts'] = prompts return (image, image_dict, image, image, image, black, black, black, black, prompts) def run_amg(image): gr.Info('Checkout ZIM Auto Mask tab.', duration=3) zim_masks = zim_mask_generator.generate(image) zim_masks_vis = show_mat_anns(image, zim_masks) sam_masks = sam_mask_generator.generate(image) sam_masks_vis = show_mat_anns(image, sam_masks) return zim_masks_vis, sam_masks_vis def run_model(image, prompts): if not prompts: raise gr.Error(f'Please input any point or BBox') gr.Info('Checkout ZIM Mask tab.', duration=3) point_coords = None point_labels = None boxes = None if "point" in prompts: point_coords, point_labels = [], [] for type, pts in prompts["point"]: point_coords.append(pts) point_labels.append(type) point_coords = np.array(point_coords) point_labels = np.array(point_labels) if "bbox" in prompts: boxes = prompts['bbox'] boxes = np.array(boxes) if "scribble" in prompts: point_coords, point_labels = [], [] for pts in prompts["scribble"]: point_coords.append(np.flip(pts)) point_labels.append(1) if len(point_coords) == 0: raise gr.Error("Please input any scribbles.") point_coords = np.array(point_coords) point_labels = np.array(point_labels) # run ZIM zim_mask, _, _ = zim_predictor.predict( point_coords=point_coords, point_labels=point_labels, box=boxes, multimask_output=False, ) zim_mask = np.squeeze(zim_mask, axis=0) zim_mask = np.uint8(zim_mask * 255) # run SAM sam_mask, _, _ = sam_predictor.predict( point_coords=point_coords, point_labels=point_labels, box=boxes, multimask_output=False, ) sam_mask = np.squeeze(sam_mask, axis=0) sam_mask = np.uint8(sam_mask * 255) return zim_mask, sam_mask def reset_scribble(image, scribble, prompts): # scribble = dict() for k in prompts.keys(): prompts[k] = [] for k, v in scribble.items(): scribble[k] = None black = np.zeros(image.shape[:3], dtype=np.uint8) return scribble, black, black def update_scribble(image, scribble, prompts): if "point" in prompts: del prompts["point"] if "bbox" in prompts: del prompts["bbox"] prompts = dict() # reset prompt scribble_mask = scribble["layers"][0][..., -1] > 0 scribble_coords = np.argwhere(scribble_mask) n_points = min(len(scribble_coords), 24) indices = np.linspace(0, len(scribble_coords)-1, n_points, dtype=int) scribble_sampled = scribble_coords[indices] prompts["scribble"] = scribble_sampled zim_mask, sam_mask = run_model(image, prompts) return zim_mask, sam_mask, prompts def draw_point(img, pt, size, color): # draw circle with white boundary region cv2.circle(img, (int(pt[0]), int(pt[1])), int(size * 1.3), (255, 255, 255), -1) cv2.circle(img, (int(pt[0]), int(pt[1])), int(size * 0.9), color, -1) def draw_images(image, mask, prompts): if len(prompts) == 0 or mask.shape[1] == 1: return image, image, image minor = get_shortest_axis(image) size = int(minor / 80) image = np.float32(image) def blending(image, mask): mask = np.float32(mask) / 255 blended_image = np.zeros_like(image, dtype=np.float32) blended_image[:, :, :] = [108, 0, 192] blended_image = (image * 0.5) + (blended_image * 0.5) img_with_mask = mask[:, :, None] * blended_image + (1 - mask[:, :, None]) * image img_with_mask = np.uint8(img_with_mask) return img_with_mask img_with_mask = blending(image, mask) img_with_point = img_with_mask.copy() if "point" in prompts: for type, pts in prompts["point"]: if type == "Positive": color = (0, 0, 255) draw_point(img_with_point, pts, size, color) elif type == "Negative": color = (255, 0, 0) draw_point(img_with_point, pts, size, color) size = int(minor / 200) return ( img, img_with_mask, ) def get_point_or_box_prompts(img, prompts): image, img_prompts = img['image'], img['points'] point_prompts = [] box_prompts = [] for prompt in img_prompts: for p in range(len(prompt)): prompt[p] = int(prompt[p]) if prompt[2] == 2 and prompt[5] == 3: # box prompt if len(box_prompts) != 0: raise gr.Error("Please input only one BBox.", duration=3) box_prompts.append([prompt[0], prompt[1], prompt[3], prompt[4]]) elif prompt[2] == 1 and prompt[5] == 4: # Positive point prompt point_prompts.append((1, (prompt[0], prompt[1]))) elif prompt[2] == 0 and prompt[5] == 4: # Negative point prompt point_prompts.append((0, (prompt[0], prompt[1]))) if "scribble" in prompts: del prompts["scribble"] if len(point_prompts) > 0: prompts['point'] = point_prompts elif 'point' in prompts: del prompts['point'] if len(box_prompts) > 0: prompts['bbox'] = box_prompts elif 'bbox' in prompts: del prompts['bbox'] zim_mask, sam_mask = run_model(image, prompts) return image, zim_mask, sam_mask, prompts def get_examples(): assets_dir = os.path.join(os.path.dirname(__file__), 'examples') images = os.listdir(assets_dir) return [os.path.join(assets_dir, img) for img in images] def download_onnx_weights(repo_id="naver-iv/zim-anything-vitb", file_dir="zim_vit_b_2043"): hf_hub_download(repo_id=repo_id, filename=f"{file_dir}/encoder.onnx") filepath = hf_hub_download(repo_id=repo_id, filename=f"{file_dir}/decoder.onnx") return os.path.dirname(filepath) if __name__ == "__main__": backbone = "vit_b" # load ZIM zim = zim_model_registry[backbone](checkpoint=download_onnx_weights()) if torch.cuda.is_available(): zim.cuda() zim_predictor = ZimPredictor(zim) zim_mask_generator = ZimAutomaticMaskGenerator( zim, pred_iou_thresh=0.7, points_per_batch=8, stability_score_thresh=0.9, ) # load SAM ckpt_sam = "ckpts/sam_vit_b_01ec64.pth" sam = sam_model_registry[backbone](checkpoint=ckpt_sam) if torch.cuda.is_available(): sam.cuda() sam_predictor = SamPredictor(sam) sam_mask_generator = SamAutomaticMaskGenerator( sam, points_per_batch=8, ) with gr.Blocks() as demo: gr.Markdown("#