Spaces:
Running
Running
""" | |
Copyright (c) 2024-present Naver Cloud Corp. | |
This source code is based on code from the Segment Anything Model (SAM) | |
(https://github.com/facebookresearch/segment-anything). | |
This source code is licensed under the license found in the | |
LICENSE file in the root directory of this source tree. | |
""" | |
import os, sys | |
sys.path.append(os.getcwd()) | |
# Gradio demo, comparison SAM vs ZIM | |
import os | |
import torch | |
import gradio as gr | |
from gradio_image_prompter import ImagePrompter | |
import numpy as np | |
import cv2 | |
from zim import zim_model_registry, ZimPredictor, ZimAutomaticMaskGenerator | |
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator | |
from zim.utils import show_mat_anns | |
def get_shortest_axis(image):
    """Return the length of the shorter spatial side of ``image``.

    Args:
        image: Array exposing ``.shape`` whose first two entries are the
            height and width. Any trailing axes (e.g. channels) are ignored.

    Returns:
        The smaller of the height and the width.
    """
    # Slice instead of unpacking exactly three values so that 2-D
    # (single-channel) arrays are accepted as well as H x W x C ones.
    height, width = image.shape[:2]
    return min(height, width)
def reset_image(image, prompts):
    """Register a freshly uploaded image with both predictors and clear outputs.

    Args:
        image: Value of the point/box prompter component; when not ``None``
            its ``'image'`` entry is used as the picture.
        prompts: Current prompt state. It is not read; a new empty dict is
            returned in its place.

    Returns:
        A 9-tuple: the picture four times, an all-zero uint8 mask of the
        picture's height x width four times, and an empty prompt dict.
    """
    # Without an upload, fall back to a black 1024x1024 RGB canvas.
    picture = (
        np.zeros((1024, 1024, 3), dtype=np.uint8)
        if image is None
        else image['image']
    )

    zim_predictor.set_image(picture)
    sam_predictor.set_image(picture)

    blank = np.zeros(picture.shape[:2], dtype=np.uint8)
    return (picture,) * 4 + (blank,) * 4 + ({},)
def reset_example_image(image, prompts):
    """Register a selected example image with both predictors and clear outputs.

    Args:
        image: The example picture as an array, or ``None``.
        prompts: Current prompt state. It is not read; a new empty dict is
            returned in its place.

    Returns:
        A 10-tuple: the picture, a ``{'image', 'prompts'}`` dict wrapping the
        picture and the new empty prompt dict, the picture three more times,
        an all-zero uint8 mask of the picture's height x width four times,
        and the same empty prompt dict.
    """
    # Without a picture, fall back to a black 1024x1024 RGB canvas.
    if image is None:
        image = np.zeros((1024, 1024, 3), dtype=np.uint8)

    zim_predictor.set_image(image)
    sam_predictor.set_image(image)

    fresh_prompts = {}
    blank = np.zeros(image.shape[:2], dtype=np.uint8)
    # The wrapper shares the very same (empty) prompt dict that is returned
    # as the last element.
    prompter_value = {'image': image, 'prompts': fresh_prompts}

    return (image, prompter_value) + (image,) * 3 + (blank,) * 4 + (fresh_prompts,)
def run_amg(image):
    """Run automatic mask generation with ZIM and SAM on ``image``.

    Args:
        image: The picture handed to both mask generators.

    Returns:
        A pair ``(zim_visualisation, sam_visualisation)`` produced by
        ``show_mat_anns`` from each generator's output.
    """
    gr.Info('Checkout ZIM Auto Mask tab.', duration=3)

    # ZIM first, then SAM; each result is rendered right after generation.
    zim_vis = show_mat_anns(image, zim_mask_generator.generate(image))
    sam_vis = show_mat_anns(image, sam_mask_generator.generate(image))

    return zim_vis, sam_vis
def run_model(image, prompts):
    """Predict one mask with ZIM and one with SAM from the given prompts.

    Args:
        image: Not read by this function; presumably the predictors work on
            the image previously passed to their ``set_image`` (confirm).
        prompts: Dict that may hold ``"point"`` (iterable of
            ``(label, coords)`` pairs), ``"bbox"`` and/or ``"scribble"``
            (iterable of coordinates). Scribble points, when present,
            replace any point prompts.

    Returns:
        ``(zim_mask, sam_mask)``, each the predictor output with its leading
        axis squeezed, multiplied by 255 and cast to uint8.

    Raises:
        gr.Error: If ``prompts`` is empty, or a ``"scribble"`` entry has no
            points.
    """
    if not prompts:
        raise gr.Error('Please input any point or BBox')
    gr.Info('Checkout ZIM Mask tab.', duration=3)

    point_coords = point_labels = boxes = None

    if "point" in prompts:
        labelled_points = prompts["point"]
        point_coords = np.array([coords for _, coords in labelled_points])
        point_labels = np.array([label for label, _ in labelled_points])

    if "bbox" in prompts:
        boxes = np.array(prompts['bbox'])

    if "scribble" in prompts:
        # Each scribble coordinate is reversed with np.flip before use, and
        # every scribble point gets label 1.
        flipped = [np.flip(coords) for coords in prompts["scribble"]]
        if not flipped:
            raise gr.Error("Please input any scribbles.")
        point_coords = np.array(flipped)
        point_labels = np.array([1] * len(flipped))

    def _predict_uint8(predictor):
        # Single-mask prediction, squeezed and rescaled to 0..255 uint8.
        raw_mask, _, _ = predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            box=boxes,
            multimask_output=False,
        )
        return np.uint8(np.squeeze(raw_mask, axis=0) * 255)

    # run ZIM, then SAM, with identical prompts
    zim_mask = _predict_uint8(zim_predictor)
    sam_mask = _predict_uint8(sam_predictor)

    return zim_mask, sam_mask
def reset_scribble(image, scribble, prompts):
    """Clear the scribble editor value and blank both mask outputs.

    Args:
        image: The current picture; only its height and width are used.
        scribble: Dict value of the scribble editor. Every entry is set to
            ``None`` in place.
        prompts: Prompt state dict. Every entry is replaced in place by an
            empty list.

    Returns:
        ``(scribble, black, black)`` where ``black`` is an all-zero uint8
        array of shape ``(height, width)``.
    """
    # Existing keys are only re-assigned (never added or removed), so
    # iterating the dicts while updating them is safe.
    for key in prompts:
        prompts[key] = []

    for key in scribble:
        scribble[key] = None

    # Masks are single-channel: use only the two spatial axes, matching the
    # blank masks built by the other reset callbacks in this file.
    black = np.zeros(image.shape[:2], dtype=np.uint8)

    return scribble, black, black
def update_scribble(image, scribble, prompts):
    """Turn the drawn scribble into point prompts and run both models.

    Args:
        image: Forwarded to ``run_model``.
        scribble: Dict value of the scribble editor; the last channel of
            ``scribble["layers"][0]`` is read and every pixel with a value
            greater than 0 counts as scribbled.
        prompts: Previous prompt state. Its ``"point"`` and ``"bbox"``
            entries are removed in place, then it is replaced by a new dict
            holding only the scribble prompt.

    Returns:
        ``(zim_mask, sam_mask, prompts)`` with the new prompt dict.

    Raises:
        gr.Error: If nothing has been drawn (no scribble layer available),
            or via ``run_model`` when the drawn layer has no marked pixels.
    """
    # Drop stale point/box prompts from the incoming state, then start over
    # with a fresh dict that will only carry the scribble.
    prompts.pop("point", None)
    prompts.pop("bbox", None)
    prompts = dict()  # reset prompt

    # Guard against an editor value without any drawn layer, which would
    # otherwise surface as a bare IndexError/TypeError.
    layers = scribble.get("layers") if scribble else None
    if layers is None or len(layers) == 0:
        raise gr.Error("Please input any scribbles.")

    scribble_mask = layers[0][..., -1] > 0
    scribble_coords = np.argwhere(scribble_mask)

    # Keep at most 24 points, evenly spaced over the argwhere ordering.
    n_points = min(len(scribble_coords), 24)
    indices = np.linspace(0, len(scribble_coords) - 1, n_points, dtype=int)
    prompts["scribble"] = scribble_coords[indices]

    zim_mask, sam_mask = run_model(image, prompts)

    return zim_mask, sam_mask, prompts
def draw_point(img, pt, size, color):
    """Draw a filled, white-rimmed dot on ``img`` in place.

    Args:
        img: Image array that ``cv2.circle`` draws onto.
        pt: Sequence whose first two entries are the dot centre.
        size: Base radius; the white disc uses ``size * 1.3`` and the
            coloured disc ``size * 0.9`` (both truncated to int).
        color: Colour tuple of the inner disc.
    """
    center = (int(pt[0]), int(pt[1]))
    white = (255, 255, 255)

    # Larger white disc first, then the smaller coloured disc on top, which
    # leaves a white boundary ring around the point.
    cv2.circle(img, center, int(size * 1.3), white, -1)
    cv2.circle(img, center, int(size * 0.9), color, -1)
def draw_images(image, mask, prompts):
    """Blend a purple highlight onto ``image`` wherever ``mask`` is set.

    Args:
        image: H x W x 3 picture array.
        mask: H x W array with values in 0..255; used as per-pixel opacity
            of the highlight after division by 255.
        prompts: Current prompt state. Only its emptiness is checked;
            prompt markers are not drawn on the result.

    Returns:
        A pair ``(image, img_with_mask)``, matching the two outputs declared
        by the ``.change`` listeners that call this function. ``image`` is
        the unmodified input. When there is nothing to draw (no image, no
        mask, no prompts, or a mask that is one column wide) the input
        image is returned in both positions.
    """
    # Always return exactly two values so the result lines up with the
    # two-component output lists of the event listeners.
    if image is None or mask is None or not prompts or mask.shape[1] == 1:
        return image, image

    base = np.float32(image)
    # Per-pixel opacity in 0..1, with a trailing axis to broadcast over
    # the colour channels.
    alpha = (np.float32(mask) / 255)[:, :, None]

    # 50/50 mix of the picture with a solid (108, 0, 192) colour.
    highlight = np.zeros_like(base, dtype=np.float32)
    highlight[:, :, :] = [108, 0, 192]
    tinted = (base * 0.5) + (highlight * 0.5)

    img_with_mask = np.uint8(alpha * tinted + (1 - alpha) * base)

    return image, img_with_mask
def get_point_or_box_prompts(img, prompts):
    """Parse the prompter value into point/box prompts and run both models.

    Args:
        img: Dict with ``'image'`` (the picture) and ``'points'`` (list of
            six-entry records; entries 2 and 5 encode the record kind).
        prompts: Prompt state dict, updated in place: ``'scribble'`` is
            removed, ``'point'`` and ``'bbox'`` are set or removed depending
            on what was parsed.

    Returns:
        ``(image, zim_mask, sam_mask, prompts)``.

    Raises:
        gr.Error: If more than one box record is present (and whatever
            ``run_model`` raises).
    """
    image = img['image']
    records = img['points']

    point_prompts, box_prompts = [], []

    for prompt in records:
        # Normalise every entry of the record to int, in place.
        prompt[:] = [int(value) for value in prompt]

        if prompt[2] == 2 and prompt[5] == 3:  # box prompt
            if box_prompts:
                raise gr.Error("Please input only one BBox.", duration=3)
            box_prompts.append([prompt[0], prompt[1], prompt[3], prompt[4]])
        elif prompt[2] in (0, 1) and prompt[5] == 4:
            # 1 = positive point, 0 = negative point; the code doubles as
            # the stored label.
            point_prompts.append((prompt[2], (prompt[0], prompt[1])))

    # Point/box prompts replace any earlier scribble prompt.
    prompts.pop("scribble", None)

    for key, collected in (('point', point_prompts), ('bbox', box_prompts)):
        if collected:
            prompts[key] = collected
        else:
            prompts.pop(key, None)

    zim_mask, sam_mask = run_model(image, prompts)

    return image, zim_mask, sam_mask, prompts
def get_examples():
    """Return the paths of all entries in the ``examples`` directory.

    The directory is looked up next to this file.

    Returns:
        A list of paths (directory joined with entry name), sorted by entry
        name.

    Raises:
        OSError: If the ``examples`` directory cannot be listed.
    """
    assets_dir = os.path.join(os.path.dirname(__file__), 'examples')
    # os.listdir returns entries in arbitrary order; sort so the example
    # gallery is laid out the same way on every run and platform.
    return [os.path.join(assets_dir, name) for name in sorted(os.listdir(assets_dir))]
if __name__ == "__main__":
    # Script entry point: load the ZIM and SAM models, build the Gradio UI,
    # wire the callbacks and launch the demo. The predictors and mask
    # generators bound here are module globals read by the callback
    # functions defined earlier in this file.
    # NOTE(review): the nesting of the layout below was reconstructed from a
    # copy whose indentation was lost — confirm against the original file.
    backbone = "vit_b"

    # load ZIM
    ckpt_mat = "ckpts/zim_vit_b_2043"
    zim = zim_model_registry[backbone](checkpoint=ckpt_mat)
    if torch.cuda.is_available():
        zim.cuda()
    zim_predictor = ZimPredictor(zim)
    zim_mask_generator = ZimAutomaticMaskGenerator(
        zim,
        pred_iou_thresh=0.7,
        points_per_batch=8,
        stability_score_thresh=0.9,
    )

    # load SAM (same backbone key as ZIM)
    ckpt_sam = "ckpts/sam_vit_b_01ec64.pth"
    sam = sam_model_registry[backbone](checkpoint=ckpt_sam)
    if torch.cuda.is_available():
        sam.cuda()
    sam_predictor = SamPredictor(sam)
    sam_mask_generator = SamAutomaticMaskGenerator(
        sam,
        points_per_batch=8,
    )

    with gr.Blocks() as demo:
        gr.Markdown("# <center> [Demo] ZIM: Zero-Shot Image Matting for Anything")

        # Per-session prompt dict plus two hidden image holders: `img` keeps
        # the current picture for the callbacks, `example_image` receives
        # the example selected through gr.Examples.
        prompts = gr.State(dict())
        img = gr.Image(visible=False)
        example_image = gr.Image(visible=False)

        with gr.Row():
            with gr.Column():
                # Point and Bbox prompt
                with gr.Tab(label="Point or Box"):
                    img_with_point_or_box = ImagePrompter(
                        label="query image",
                        sources="upload"
                    )
                    interactions = "Left Click (Pos) | Middle/Right Click (Neg) | Press Move (Box)"
                    gr.Markdown("<h3 style='text-align: center'> {} </h3>".format(interactions))
                    run_bttn = gr.Button("Run")
                    amg_bttn = gr.Button("Automatic Mask Generation")

                # Scribble prompt
                with gr.Tab(label="Scribble"):
                    img_with_scribble = gr.ImageEditor(
                        label="Scribble",
                        brush=gr.Brush(colors=["#00FF00"], default_size=15),
                        sources="upload",
                        transforms=None,
                        layers=False
                    )
                    interactions = "Press Move (Scribble)"
                    gr.Markdown("<h3 style='text-align: center'> Step 1. Select Draw button </h3>")
                    gr.Markdown("<h3 style='text-align: center'> Step 2. {} </h3>".format(interactions))
                    scribble_bttn = gr.Button("Run")
                    scribble_reset_bttn = gr.Button("Reset Scribbles")
                    amg_scribble_bttn = gr.Button("Automatic Mask Generation")

                # Example image
                gr.Examples(get_examples(), inputs=[example_image])

            # with gr.Row():
            # ZIM result tabs: overlay image, raw mask, automatic masks.
            with gr.Column():
                with gr.Tab(label="ZIM Image"):
                    img_with_zim_mask = gr.Image(
                        label="ZIM Image",
                        interactive=False
                    )
                with gr.Tab(label="ZIM Mask"):
                    zim_mask = gr.Image(
                        label="ZIM Mask",
                        image_mode="L",
                        interactive=False
                    )
                with gr.Tab(label="ZIM Auto Mask"):
                    zim_amg = gr.Image(
                        label="ZIM Auto Mask",
                        interactive=False
                    )

            # SAM result tabs, mirroring the ZIM column.
            with gr.Column():
                with gr.Tab(label="SAM Image"):
                    img_with_sam_mask = gr.Image(
                        label="SAM image",
                        interactive=False
                    )
                with gr.Tab(label="SAM Mask"):
                    sam_mask = gr.Image(
                        label="SAM Mask",
                        image_mode="L",
                        interactive=False
                    )
                with gr.Tab(label="SAM Auto Mask"):
                    sam_amg = gr.Image(
                        label="SAM Auto Mask",
                        interactive=False
                    )

        # Selecting an example resets every view; the ten outputs line up
        # with the ten values returned by reset_example_image.
        example_image.change(
            reset_example_image,
            [example_image, prompts],
            [
                img,
                img_with_point_or_box,
                img_with_scribble,
                img_with_zim_mask,
                img_with_sam_mask,
                zim_amg,
                sam_amg,
                zim_mask,
                sam_mask,
                prompts,
            ]
        )

        # Uploading into the point/box prompter resets the other nine
        # components (the prompter itself already holds the upload).
        img_with_point_or_box.upload(
            reset_image,
            [img_with_point_or_box, prompts],
            [
                img,
                img_with_scribble,
                img_with_zim_mask,
                img_with_sam_mask,
                zim_amg,
                sam_amg,
                zim_mask,
                sam_mask,
                prompts,
            ],
        )

        # Both "Automatic Mask Generation" buttons trigger the same callback.
        amg_bttn.click(
            run_amg,
            [img],
            [zim_amg, sam_amg]
        )
        amg_scribble_bttn.click(
            run_amg,
            [img],
            [zim_amg, sam_amg]
        )

        # Point/box "Run": parse the prompts and predict both masks.
        run_bttn.click(
            get_point_or_box_prompts,
            [img_with_point_or_box, prompts],
            [img, zim_mask, sam_mask, prompts]
        )

        # Whenever a mask component changes, refresh the matching overlay
        # image. Each listener declares two outputs.
        zim_mask.change(
            draw_images,
            [img, zim_mask, prompts],
            [
                img, img_with_zim_mask,
            ],
        )
        sam_mask.change(
            draw_images,
            [img, sam_mask, prompts],
            [
                img, img_with_sam_mask,
            ],
        )

        # Scribble tab: reset button clears editor and masks, "Run" predicts
        # from the drawn scribble.
        scribble_reset_bttn.click(
            reset_scribble,
            [img, img_with_scribble, prompts],
            [img_with_scribble, zim_mask, sam_mask],
        )
        scribble_bttn.click(
            update_scribble,
            [img, img_with_scribble, prompts],
            [zim_mask, sam_mask, prompts],
        )

    demo.queue()
    demo.launch()