import gradio as gr
import torch
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: figures are rendered off-screen for the web UI
from matplotlib import pyplot as plt
import numpy as np
from groundingdino.util.inference import load_model, load_image, predict
from segment_anything import SamPredictor, sam_model_registry
from torchvision.ops import box_convert

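# Zero-shot pipeline: Grounding DINO turns a free-text prompt into candidate
# boxes, and SAM (Segment Anything) turns each box into a pixel-accurate mask.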
model_type = "vit_b"
sam_checkpoint = "weights/sam_vit_b.pth"
config = "groundingdino/config/GroundingDINO_SwinT_OGC.py"
dino_checkpoint = "weights/groundingdino_swint_ogc.pth"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
predictor = SamPredictor(sam)
device = "cpu"
model = load_model(config, dino_checkpoint, device)

box_threshold = 0.35
text_threshold = 0.25


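# Draw a mask on a matplotlib axis as a semi-transparent colored overlay.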
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_box(box, ax, label=None):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='red', facecolor=(0, 0, 0, 0), lw=2))
    if label is not None:
        ax.text(x0, y0, label, fontsize=12, color='white', backgroundcolor='red', ha='left', va='top')


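# Keep only the masked object: RGB pixels inside the mask, with alpha 255
# inside and 0 outside, so the background becomes fully transparent.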
def extract_object_with_transparent_background(image, masks):
    mask_expanded = np.expand_dims(masks[0], axis=-1)
    mask_expanded = np.repeat(mask_expanded, 3, axis=-1)
    segment = image * mask_expanded
    rgba_segment = np.zeros((segment.shape[0], segment.shape[1], 4), dtype=np.uint8)
    rgba_segment[:, :, :3] = segment
    rgba_segment[:, :, 3] = masks[0] * 255
    return rgba_segment


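# Inverse of the above: zero out the masked object and keep the rest.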
def extract_remaining_image(image, masks):
    inverse_mask = np.logical_not(masks[0])
    inverse_mask_expanded = np.expand_dims(inverse_mask, axis=-1)
    inverse_mask_expanded = np.repeat(inverse_mask_expanded, 3, axis=-1)
    remaining_image = image * inverse_mask_expanded
    return remaining_image


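# Render the image plus optional mask/box overlays through a matplotlib
# figure, then read the pixels back as an RGBA numpy array.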
def overlay_masks_boxes_on_image(image, masks, boxes, labels, show_masks, show_boxes):
    fig, ax = plt.subplots()
    ax.imshow(image)

    if show_masks:
        for mask in masks:
            show_mask(mask, ax, random_color=False)

    if show_boxes:
        for input_box, label in zip(boxes, labels):
            show_box(input_box, ax, label)

    ax.axis('off')
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0)
    plt.margins(0, 0)

    fig.canvas.draw()
    output_image = np.array(fig.canvas.buffer_rgba())

    plt.close(fig)
    return output_image


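# Full pipeline for one request: detect boxes with Grounding DINO, refine each
# box into a mask with SAM, then compose the requested output view.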
def detect_objects(image, prompt, show_masks, show_boxes, crop_options):
    image_source, image = load_image(image)
    predictor.set_image(image_source)

    boxes, logits, phrases = predict(
        model=model,
        image=image,
        caption=prompt,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        device=device
    )

    # Grounding DINO returns normalized (cx, cy, w, h) boxes; scale them to
    # absolute (x1, y1, x2, y2) pixel coordinates.
    h, w, _ = image_source.shape
    boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy") * torch.Tensor([w, h, w, h])
    boxes = np.round(boxes.numpy()).astype(int)

    labels = [f"{phrase} {logit:.2f}" for phrase, logit in zip(phrases, logits)]

    masks_list = []

    for input_box, label in zip(boxes, labels):
        # Sample four positive point prompts around the box center, offset by
        # 10% of the average box side, so SAM is anchored on the object
        # interior rather than near the box edges.
        x1, y1, x2, y2 = input_box
        width = x2 - x1
        height = y2 - y1
        avg_size = (width + height) / 2
        d = avg_size * 0.1

        center_point = np.array([(x1 + x2) / 2, (y1 + y2) / 2])
        points = []
        points.append([center_point[0], center_point[1] - d])
        points.append([center_point[0], center_point[1] + d])
        points.append([center_point[0] - d, center_point[1]])
        points.append([center_point[0] + d, center_point[1]])
        input_point = np.array(points)
        input_label = np.array([1] * len(input_point))  # 1 marks foreground points

        # First pass: request several candidate masks and keep the low-res
        # logits of the best-scoring one.
        masks, scores, logits = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=True,
        )
        mask_input = logits[np.argmax(scores), :, :]

        # Second pass: feed those logits back as a mask prompt to refine the
        # prediction into a single final mask.
        masks, _, _ = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            mask_input=mask_input[None, :, :],
            multimask_output=False
        )
        masks_list.append(masks)

    if crop_options == "Crop":
        # Keep only the detected objects, composited onto a black background.
        composite_image = np.zeros_like(image_source)
        for masks in masks_list:
            rgba_segment = extract_object_with_transparent_background(image_source, masks)
            composite_image = np.maximum(composite_image, rgba_segment[:, :, :3])
        output_image = overlay_masks_boxes_on_image(composite_image, masks_list, boxes, labels, show_masks, show_boxes)
    elif crop_options == "Inverse Crop":
        # Black out the detected objects and keep everything else.
        remaining_image = image_source.copy()
        for masks in masks_list:
            remaining_image = extract_remaining_image(remaining_image, masks)
        output_image = overlay_masks_boxes_on_image(remaining_image, masks_list, boxes, labels, show_masks, show_boxes)
    else:
        output_image = overlay_masks_boxes_on_image(image_source, masks_list, boxes, labels, show_masks, show_boxes)

    # The rendered buffer is RGBA, which JPEG cannot store; save as PNG instead.
    output_image_path = 'output_image.png'
    plt.imsave(output_image_path, output_image)

    return output_image_path


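# Gradio UI: two image panes, prompt/visibility controls, and a submit button.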
block = gr.Blocks(css=".gradio-container {background-color: #f8f8f8; font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif}")

with block:
    gr.HTML("""
    <style>
    body {
        background-color: #f5f5f5;
        font-family: 'Roboto', sans-serif;
        padding: 30px;
    }
    </style>
    """)

    gr.HTML("<h1 style='text-align: center;'>Segment Any Image</h1>")
    gr.HTML("<h3 style='text-align: center;'>Zero-Shot Object Detection, Segmentation and Cropping</h3>")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type='filepath', label="Upload Image")
        with gr.Column():
            output_image = gr.Image(type='filepath', label="Result")
    with gr.Row():
        with gr.Column():
            object_search = gr.Textbox(
                label="Object to Detect",
                placeholder="Enter any text, comma separated if multiple objects needed",
                show_label=True,
                lines=1,
            )
        with gr.Column():
            show_masks = gr.Checkbox(label="Show Masks", value=True)
            show_boxes = gr.Checkbox(label="Show Boxes", value=True)
        with gr.Column():
            crop_options = gr.Radio(choices=["None", "Crop", "Inverse Crop"], label="Crop Options", value="None")
    with gr.Row():
        submit = gr.Button(value="Send", variant="secondary").style(full_width=True)

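    # Clickable sample inputs; the crop option keeps its current selection.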
    gr.Examples(
        examples=[
            ["images/tiger.jpeg", "animal from cat family", True, True],
            ["images/car.jpeg", "a blue sports car", True, False],
            ["images/bags.jpeg", "black bag next to the red bag", False, True],
            ["images/deer.jpeg", "deer jumping and running across the road", True, True],
            ["images/penn.jpeg", "sign board", True, False],
        ],
        inputs=[input_image, object_search, show_masks, show_boxes],
    )

    gr.HTML("""
    <div style="text-align:center">
        <p>Developed by <a href='https://www.linkedin.com/in/dekay/'>Github and Huggingface: Volkopat</a></p>
        <p>Powered by <a href='https://segment-anything.com'>Segment Anything</a> and <a href='https://arxiv.org/abs/2303.05499'>Grounding DINO</a></p>
        <p>Just upload an image and enter the objects to detect, segment, crop, etc. That's all folks!</p>
        <p>What's Zero-Shot? It means you can detect objects without any training samples!</p>
        <p>This project is for demonstration purposes. Credits for the state-of-the-art models go to Meta AI and IDEA Research.</p>
    </div>
    <style>
    p {
        margin-bottom: 10px;
        font-size: 16px;
    }
    a {
        color: #3867d6;
        text-decoration: none;
    }
    a:hover {
        text-decoration: underline;
    }
    </style>
    """)

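    # Wire the button to the detection pipeline.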
    submit.click(fn=detect_objects,
                 inputs=[input_image, object_search, show_masks, show_boxes, crop_options],
                 outputs=[output_image])

block.launch(width=800)