# Source: uploaded by scyonggg in commit 9860a06 ("Initial commit"), 13.1 kB.
"""
Copyright (c) 2024-present Naver Cloud Corp.
This source code is based on code from the Segment Anything Model (SAM)
(https://github.com/facebookresearch/segment-anything).
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
import os, sys
sys.path.append(os.getcwd())
# Gradio demo: side-by-side comparison of SAM and ZIM.
import os
import torch
import gradio as gr
from gradio_image_prompter import ImagePrompter
import numpy as np
import cv2
from zim import zim_model_registry, ZimPredictor, ZimAutomaticMaskGenerator
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
from zim.utils import show_mat_anns
def get_shortest_axis(image):
    """Return the smaller spatial dimension (height or width) of a 3-D image array."""
    height, width, _ = image.shape
    return min(height, width)
def reset_image(image, prompts):
    """Handle a fresh upload in the point/box prompter.

    Registers the image with both predictors and returns values that reset every
    view and the prompt state.

    Args:
        image: ImagePrompter value dict holding the picture under ``'image'``, or
            ``None`` (a black 1024x1024 RGB canvas is used instead).
        prompts: Previous prompt state; ignored and replaced by a new empty dict.

    Returns:
        A 9-tuple: the image four times (hidden image, scribble editor, both overlay
        views), four all-zero single-channel ``uint8`` masks, and the new prompt dict.
    """
    frame = (
        np.zeros((1024, 1024, 3), dtype=np.uint8)
        if image is None
        else image['image']
    )
    for predictor in (zim_predictor, sam_predictor):
        predictor.set_image(frame)
    blank = np.zeros(frame.shape[:2], dtype=np.uint8)
    return (frame,) * 4 + (blank,) * 4 + ({},)
def reset_example_image(image, prompts):
    """Load a gallery example into every view and reset the demo state.

    Args:
        image: The selected example as a numpy array, or ``None`` (a black
            1024x1024 RGB canvas is used instead).
        prompts: Previous prompt state; ignored and replaced by a new empty dict.

    Returns:
        A 10-tuple in the order of the ``example_image.change`` outputs: the image
        for the hidden image component, the ImagePrompter value, the image for the
        scribble editor and both overlay views, four all-zero single-channel masks,
        and the new empty prompt state.
    """
    if image is None:
        image = np.zeros((1024, 1024, 3), dtype=np.uint8)
    zim_predictor.set_image(image)
    sam_predictor.set_image(image)
    prompts = dict()
    black = np.zeros(image.shape[:2], dtype=np.uint8)
    # The prompter's value uses the keys 'image' and 'points' (the same keys that
    # get_point_or_box_prompts reads back); a new example starts with no clicks.
    prompter_value = {'image': image, 'points': []}
    return (image, prompter_value, image, image, image, black, black, black, black, prompts)
def run_amg(image):
    """Run prompt-free automatic mask generation with both ZIM and SAM.

    Args:
        image: The query image as a numpy array, or ``None`` when nothing is loaded.

    Returns:
        ``(zim_vis, sam_vis)``: the ``show_mat_anns`` visualisations of each model's
        generated masks.

    Raises:
        gr.Error: if no image has been loaded yet.
    """
    if image is None:
        # The hidden image component is empty until an upload or an example fills it;
        # report that clearly instead of letting the mask generator fail on None.
        raise gr.Error('Please upload an image first.')
    gr.Info('Checkout ZIM Auto Mask tab.', duration=3)
    zim_vis = show_mat_anns(image, zim_mask_generator.generate(image))
    sam_vis = show_mat_anns(image, sam_mask_generator.generate(image))
    return zim_vis, sam_vis
def _prompts_to_arrays(prompts):
    """Convert the prompt state dict into ``(point_coords, point_labels, boxes)`` arrays.

    Any entry is ``None`` when that prompt type is absent or empty. Scribble samples,
    when present, replace point prompts.

    Raises:
        gr.Error: if a ``"scribble"`` entry exists but holds no samples.
    """
    point_coords = None
    point_labels = None
    boxes = None
    points = prompts.get("point")
    if points:
        # Each point prompt is a (label, (x, y)) pair.
        point_coords = np.array([xy for _, xy in points])
        point_labels = np.array([label for label, _ in points])
    if prompts.get("bbox"):
        boxes = np.array(prompts["bbox"])
    if "scribble" in prompts:
        scribble = prompts["scribble"]
        if len(scribble) == 0:
            raise gr.Error("Please input any scribbles.")
        # Scribble samples are (row, col) pairs; flip each to (x, y). All are positive.
        point_coords = np.array([np.flip(pt) for pt in scribble])
        point_labels = np.array([1] * len(scribble))
    return point_coords, point_labels, boxes


def _predict_single_mask(predictor, point_coords, point_labels, boxes):
    """Run ``predictor`` with ``multimask_output=False``; return the mask as uint8 in [0, 255]."""
    masks, _, _ = predictor.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        box=boxes,
        multimask_output=False,
    )
    # Drop the leading single-mask axis, then scale the [0, 1] / boolean mask to 0-255.
    return np.uint8(np.squeeze(masks, axis=0) * 255)


def run_model(image, prompts):
    """Predict one mask with ZIM and one with SAM from the current prompts.

    Both predictors must already hold the query image (it is set on upload), so
    ``image`` itself is not used here.

    Args:
        image: The query image (unused; kept for the event-handler signature).
        prompts: Prompt state dict with any of ``"point"`` (list of
            ``(label, (x, y))``), ``"bbox"`` (list holding one ``[x1, y1, x2, y2]``)
            or ``"scribble"`` (array of sampled ``(row, col)`` pixels).

    Returns:
        ``(zim_mask, sam_mask)`` as ``uint8`` arrays in ``[0, 255]``.

    Raises:
        gr.Error: if no usable point, box or scribble prompt is present.
    """
    if not prompts:
        raise gr.Error('Please input any point or BBox')
    gr.Info('Checkout ZIM Mask tab.', duration=3)
    point_coords, point_labels, boxes = _prompts_to_arrays(prompts)
    if point_coords is None and boxes is None:
        # Keys present but all prompt lists empty (e.g. after resetting scribbles).
        raise gr.Error('Please input any point or BBox')
    zim_mask = _predict_single_mask(zim_predictor, point_coords, point_labels, boxes)
    sam_mask = _predict_single_mask(sam_predictor, point_coords, point_labels, boxes)
    return zim_mask, sam_mask
def reset_scribble(image, scribble, prompts):
    """Clear the scribble editor, the prompt lists, and both mask outputs.

    Args:
        image: The current query image as a numpy array, or ``None`` if none is loaded.
        scribble: ImageEditor value dict (or ``None``); every entry is set to ``None``
            in place.
        prompts: Prompt state dict; every entry is replaced by an empty list in place.

    Returns:
        ``(scribble, zim_mask, sam_mask)``. The two masks are the same all-zero
        single-channel ``uint8`` array sized to ``image``'s height and width, matching
        the ``"L"``-mode mask outputs, or ``None`` when no image is loaded.
    """
    for key in prompts:
        prompts[key] = []
    for key in (scribble or {}):
        scribble[key] = None
    # Use (H, W), like the other reset paths, rather than a 3-channel array.
    black = None if image is None else np.zeros(image.shape[:2], dtype=np.uint8)
    return scribble, black, black
def update_scribble(image, scribble, prompts):
    """Turn the drawn scribble into sampled point prompts and run both models.

    Args:
        image: The query image (passed through to ``run_model``).
        scribble: ImageEditor value dict. The last channel of its first layer is
            treated as the stroke mask (presumably the alpha channel of an RGBA
            layer; confirm against the editor's image_mode).
        prompts: Previous prompt state. Point/box entries are removed from it, and a
            new dict holding only the scribble samples is returned.

    Returns:
        ``(zim_mask, sam_mask, prompts)``.

    Raises:
        gr.Error: via ``run_model`` when nothing has been drawn.
    """
    # Scribbles replace point/box prompts.
    prompts.pop("point", None)
    prompts.pop("bbox", None)
    prompts = dict()
    layers = scribble.get("layers") if scribble else None
    first_layer = layers[0] if layers is not None and len(layers) > 0 else None
    if first_layer is None:
        # No drawable layer: send an empty sample set so run_model reports
        # "Please input any scribbles." instead of crashing on the index.
        coords = np.empty((0, 2), dtype=np.int64)
    else:
        coords = np.argwhere(first_layer[..., -1] > 0)
    # Keep at most 24 (row, col) samples, evenly spread over the stroke pixels.
    n_points = min(len(coords), 24)
    indices = np.linspace(0, len(coords) - 1, n_points, dtype=int)
    prompts["scribble"] = coords[indices]
    zim_mask, sam_mask = run_model(image, prompts)
    return zim_mask, sam_mask, prompts
def draw_point(img, pt, size, color):
    """Paint a filled marker at ``pt`` on ``img`` in place, surrounded by a white ring.

    A white disc of radius ``1.3 * size`` is drawn first, then a ``color`` disc of
    radius ``0.9 * size`` on top of it.
    """
    center = (int(pt[0]), int(pt[1]))
    layers = (
        (int(size * 1.3), (255, 255, 255)),
        (int(size * 0.9), color),
    )
    for radius, fill in layers:
        cv2.circle(img, center, radius, fill, -1)
def draw_images(image, mask, prompts):
    """Render a mask overlay, with point-prompt markers, for display.

    Args:
        image: Query image as an ``(H, W, 3)`` numpy array, or ``None``.
        mask: Predicted mask as an array in ``[0, 255]`` (255 = selected), or ``None``.
        prompts: Prompt state dict. When present, ``prompts["point"]`` holds
            ``(label, (x, y))`` pairs; the label is 1/"Positive" or 0/"Negative".

    Returns:
        ``(image, overlay)``: the input image unchanged, wired back to the hidden
        image component, and a ``uint8`` copy in which masked pixels are tinted with
        (108, 0, 192). Positive points are marked with (0, 0, 255) and negative
        points with (255, 0, 0).
    """
    # Nothing to overlay: no image or mask yet, no prompts, or a 1-pixel-wide mask.
    if image is None or mask is None or len(prompts) == 0 or mask.shape[1] == 1:
        return image, image
    if mask.ndim == 3:
        # Assumes a grey mask whose channels are identical; keep the first channel.
        mask = mask[..., 0]
    marker_size = int(min(image.shape[:2]) / 80)
    base = np.float32(image)
    alpha = (np.float32(mask) / 255)[:, :, None]
    # The tint is a 50/50 mix of the image and the overlay colour.
    tint = base * 0.5 + np.array([108, 0, 192], dtype=np.float32) * 0.5
    overlay = np.uint8(alpha * tint + (1 - alpha) * base)
    # Labels are stored as 1/0 by get_point_or_box_prompts; also accept the string forms.
    for label, xy in prompts.get("point", ()):
        if label in (1, "Positive"):
            draw_point(overlay, xy, marker_size, (0, 0, 255))
        elif label in (0, "Negative"):
            draw_point(overlay, xy, marker_size, (255, 0, 0))
    return image, overlay
def get_point_or_box_prompts(img, prompts):
    """Parse ImagePrompter clicks and boxes into the prompt state, then run both models.

    Args:
        img: ImagePrompter value dict with ``'image'`` and ``'points'``, or ``None``.
            Each entry of ``'points'`` is ``[x1, y1, kind1, x2, y2, kind2]``.
        prompts: Prompt state dict, updated in place.

    Returns:
        ``(image, zim_mask, sam_mask, prompts)``.

    Raises:
        gr.Error: if no image is loaded, more than one box is drawn, or ``run_model``
            finds no usable prompt.
    """
    if img is None or img.get('image') is None:
        raise gr.Error('Please upload an image first.')
    image = img['image']
    point_prompts = []
    box_prompts = []
    for raw in img.get('points') or []:
        # Convert to ints without mutating the component's value.
        x1, y1, kind1, x2, y2, kind2 = (int(v) for v in raw[:6])
        if kind1 == 2 and kind2 == 3:  # box prompt
            if len(box_prompts) != 0:
                raise gr.Error("Please input only one BBox.", duration=3)
            # Normalise to XYXY (x1 <= x2, y1 <= y2) in case the box was dragged
            # right-to-left or bottom-to-top; this is a no-op for ordered boxes.
            box_prompts.append([min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)])
        elif kind1 == 1 and kind2 == 4:  # positive point prompt
            point_prompts.append((1, (x1, y1)))
        elif kind1 == 0 and kind2 == 4:  # negative point prompt
            point_prompts.append((0, (x1, y1)))
    # Point/box prompts replace any earlier scribble prompt.
    prompts.pop("scribble", None)
    if point_prompts:
        prompts['point'] = point_prompts
    else:
        prompts.pop('point', None)
    if box_prompts:
        prompts['bbox'] = box_prompts
    else:
        prompts.pop('bbox', None)
    zim_mask, sam_mask = run_model(image, prompts)
    return image, zim_mask, sam_mask, prompts
def get_examples(assets_dir=None):
    """Return sorted paths of the example images offered under the prompt panel.

    Args:
        assets_dir: Directory to scan. Defaults to ``examples/`` next to this script.

    Returns:
        Paths of the regular, non-hidden files in ``assets_dir``, sorted by name.
    """
    if assets_dir is None:
        assets_dir = os.path.join(os.path.dirname(__file__), 'examples')
    # os.listdir order is arbitrary; sort so the gallery order is stable across runs,
    # and skip hidden files (e.g. .DS_Store) and subdirectories that are not examples.
    paths = (os.path.join(assets_dir, name) for name in sorted(os.listdir(assets_dir))
             if not name.startswith('.'))
    return [path for path in paths if os.path.isfile(path)]
if __name__ == "__main__":
    # Both models are loaded with the same backbone name, looked up in each registry.
    backbone = "vit_b"

    # Load ZIM (checkpoint path is relative to the working directory).
    ckpt_mat = "ckpts/zim_vit_b_2043"
    zim = zim_model_registry[backbone](checkpoint=ckpt_mat)
    if torch.cuda.is_available():
        zim.cuda()
    # These module-level names are read as globals by the handler functions above.
    zim_predictor = ZimPredictor(zim)
    zim_mask_generator = ZimAutomaticMaskGenerator(
        zim,
        pred_iou_thresh=0.7,
        points_per_batch=8,
        stability_score_thresh=0.9,
    )

    # Load SAM.
    ckpt_sam = "ckpts/sam_vit_b_01ec64.pth"
    sam = sam_model_registry[backbone](checkpoint=ckpt_sam)
    if torch.cuda.is_available():
        sam.cuda()
    sam_predictor = SamPredictor(sam)
    sam_mask_generator = SamAutomaticMaskGenerator(
        sam,
        points_per_batch=8,
    )

    with gr.Blocks() as demo:
        gr.Markdown("# <center> [Demo] ZIM: Zero-Shot Image Matting for Anything")
        # Per-session prompt state shared by all handlers.
        prompts = gr.State(dict())
        # Hidden holder for the current query image; run_amg and draw_images read it.
        img = gr.Image(visible=False)
        # Hidden target for gr.Examples; its change event loads the chosen example.
        example_image = gr.Image(visible=False)

        with gr.Row():
            # Left column: prompt inputs.
            with gr.Column():
                # Point and Bbox prompt
                with gr.Tab(label="Point or Box"):
                    img_with_point_or_box = ImagePrompter(
                        label="query image",
                        sources="upload"
                    )
                    interactions = "Left Click (Pos) | Middle/Right Click (Neg) | Press Move (Box)"
                    gr.Markdown("<h3 style='text-align: center'> {} </h3>".format(interactions))
                    run_bttn = gr.Button("Run")
                    amg_bttn = gr.Button("Automatic Mask Generation")

                # Scribble prompt
                with gr.Tab(label="Scribble"):
                    img_with_scribble = gr.ImageEditor(
                        label="Scribble",
                        brush=gr.Brush(colors=["#00FF00"], default_size=15),
                        sources="upload",
                        transforms=None,
                        layers=False
                    )
                    interactions = "Press Move (Scribble)"
                    gr.Markdown("<h3 style='text-align: center'> Step 1. Select Draw button </h3>")
                    gr.Markdown("<h3 style='text-align: center'> Step 2. {} </h3>".format(interactions))
                    scribble_bttn = gr.Button("Run")
                    scribble_reset_bttn = gr.Button("Reset Scribbles")
                    amg_scribble_bttn = gr.Button("Automatic Mask Generation")

                # Example images
                gr.Examples(get_examples(), inputs=[example_image])

            # Middle column: ZIM results.
            with gr.Column():
                with gr.Tab(label="ZIM Image"):
                    img_with_zim_mask = gr.Image(
                        label="ZIM Image",
                        interactive=False
                    )
                with gr.Tab(label="ZIM Mask"):
                    zim_mask = gr.Image(
                        label="ZIM Mask",
                        image_mode="L",
                        interactive=False
                    )
                with gr.Tab(label="ZIM Auto Mask"):
                    zim_amg = gr.Image(
                        label="ZIM Auto Mask",
                        interactive=False
                    )

            # Right column: SAM results.
            with gr.Column():
                with gr.Tab(label="SAM Image"):
                    img_with_sam_mask = gr.Image(
                        label="SAM image",
                        interactive=False
                    )
                with gr.Tab(label="SAM Mask"):
                    sam_mask = gr.Image(
                        label="SAM Mask",
                        image_mode="L",
                        interactive=False
                    )
                with gr.Tab(label="SAM Auto Mask"):
                    sam_amg = gr.Image(
                        label="SAM Auto Mask",
                        interactive=False
                    )

        # Choosing an example loads it into every view and clears prompts/results.
        example_image.change(
            reset_example_image,
            [example_image, prompts],
            [
                img,
                img_with_point_or_box,
                img_with_scribble,
                img_with_zim_mask,
                img_with_sam_mask,
                zim_amg,
                sam_amg,
                zim_mask,
                sam_mask,
                prompts,
            ]
        )

        # A new upload in the prompter resets every other view and the prompt state.
        img_with_point_or_box.upload(
            reset_image,
            [img_with_point_or_box, prompts],
            [
                img,
                img_with_scribble,
                img_with_zim_mask,
                img_with_sam_mask,
                zim_amg,
                sam_amg,
                zim_mask,
                sam_mask,
                prompts,
            ],
        )

        # Both AMG buttons run the same prompt-free generation on the hidden image.
        amg_bttn.click(
            run_amg,
            [img],
            [zim_amg, sam_amg]
        )
        amg_scribble_bttn.click(
            run_amg,
            [img],
            [zim_amg, sam_amg]
        )

        run_bttn.click(
            get_point_or_box_prompts,
            [img_with_point_or_box, prompts],
            [img, zim_mask, sam_mask, prompts]
        )

        # Whenever a mask output changes, redraw the corresponding overlay view.
        zim_mask.change(
            draw_images,
            [img, zim_mask, prompts],
            [
                img, img_with_zim_mask,
            ],
        )
        sam_mask.change(
            draw_images,
            [img, sam_mask, prompts],
            [
                img, img_with_sam_mask,
            ],
        )

        scribble_reset_bttn.click(
            reset_scribble,
            [img, img_with_scribble, prompts],
            [img_with_scribble, zim_mask, sam_mask],
        )
        scribble_bttn.click(
            update_scribble,
            [img, img_with_scribble, prompts],
            [zim_mask, sam_mask, prompts],
        )

    # Enable the request queue, then start the web server.
    demo.queue()
    demo.launch()