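# Gradio demo: zero-shot object detection with Grounding DINO, followed by
# SAM (Segment Anything) segmentation and optional cropping of the results.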
import gradio as gr
import torch
from matplotlib import pyplot as plt
import numpy as np
from groundingdino.util.inference import load_model, load_image, predict
from segment_anything import SamPredictor, sam_model_registry
from torchvision.ops import box_convert
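# Paths to the SAM checkpoint and the Grounding DINO config/weights;
# the checkpoints are assumed to be downloaded into weights/ beforehand.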
model_type = "vit_b"
sam_checkpoint = "weights/sam_vit_b.pth"
config = "groundingdino/config/GroundingDINO_SwinT_OGC.py"
dino_checkpoint = "weights/groundingdino_swint_ogc.pth"
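# Load SAM for mask prediction and Grounding DINO for text-conditioned box detection (CPU inference).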
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
predictor = SamPredictor(sam)
device = "cpu"
model = load_model(config, dino_checkpoint, device)
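# Grounding DINO confidence thresholds for box detection and phrase matching.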
box_threshold = 0.35
text_threshold = 0.25
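# Overlay a single segmentation mask on a matplotlib axis as a translucent colour layer.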
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
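# Draw a bounding box (and optional label) on a matplotlib axis.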
def show_box(box, ax, label=None):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='red', facecolor=(0, 0, 0, 0), lw=2))
    if label is not None:
        ax.text(x0, y0, label, fontsize=12, color='white', backgroundcolor='red', ha='left', va='top')
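# Return an RGBA image containing only the masked object; all other pixels are transparent.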
def extract_object_with_transparent_background(image, masks):
    mask_expanded = np.expand_dims(masks[0], axis=-1)
    mask_expanded = np.repeat(mask_expanded, 3, axis=-1)
    segment = image * mask_expanded
    rgba_segment = np.zeros((segment.shape[0], segment.shape[1], 4), dtype=np.uint8)
    rgba_segment[:, :, :3] = segment
    rgba_segment[:, :, 3] = masks[0] * 255
    return rgba_segment
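# Return the image with the masked object zeroed out (blacked out), keeping everything else.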
def extract_remaining_image(image, masks):
    inverse_mask = np.logical_not(masks[0])
    inverse_mask_expanded = np.expand_dims(inverse_mask, axis=-1)
    inverse_mask_expanded = np.repeat(inverse_mask_expanded, 3, axis=-1)
    remaining_image = image * inverse_mask_expanded
    return remaining_image
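# Render the image with optional mask and box overlays and return the figure canvas as an RGBA array.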
def overlay_masks_boxes_on_image(image, masks, boxes, labels, show_masks, show_boxes):
    fig, ax = plt.subplots()
    ax.imshow(image)
    if show_masks:
        for mask in masks:
            show_mask(mask, ax, random_color=False)
    if show_boxes:
        for input_box, label in zip(boxes, labels):
            show_box(input_box, ax, label)
    ax.axis('off')
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0)
    plt.margins(0, 0)
    fig.canvas.draw()
    output_image = np.array(fig.canvas.buffer_rgba())
    plt.close(fig)
    return output_image
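# Full pipeline: Grounding DINO finds boxes for the text prompt, SAM turns each
# box into a mask via point prompts, and the result is overlaid and/or cropped.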
def detect_objects(image, prompt, show_masks, show_boxes, crop_options):
    image_source, image = load_image(image)
    predictor.set_image(image_source)
    boxes, logits, phrases = predict(
        model=model,
        image=image,
        caption=prompt,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        device=device
    )
    h, w, _ = image_source.shape
    boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy") * torch.Tensor([w, h, w, h])
    boxes = np.round(boxes.numpy()).astype(int)
    labels = [f"{phrase} {logit:.2f}" for phrase, logit in zip(phrases, logits)]
    masks_list = []
    for input_box, label in zip(boxes, labels):
        x1, y1, x2, y2 = input_box
        width = x2 - x1
        height = y2 - y1
        avg_size = (width + height) / 2
        d = avg_size * 0.1
        center_point = np.array([(x1 + x2) / 2, (y1 + y2) / 2])
        points = []
        points.append([center_point[0], center_point[1] - d])
        points.append([center_point[0], center_point[1] + d])
        points.append([center_point[0] - d, center_point[1]])
        points.append([center_point[0] + d, center_point[1]])
        input_point = np.array(points)
        input_label = np.array([1] * len(input_point))
        masks, scores, logits = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=True,
        )
        mask_input = logits[np.argmax(scores), :, :]
        masks, _, _ = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            mask_input=mask_input[None, :, :],
            multimask_output=False
        )
        masks_list.append(masks)
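    # Apply the selected crop mode before rendering the final overlay.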
if crop_options == "Crop":
composite_image = np.zeros_like(image_source)
for masks in masks_list:
rgba_segment = extract_object_with_transparent_background(image_source, masks)
composite_image = np.maximum(composite_image, rgba_segment[:, :, :3])
output_image = overlay_masks_boxes_on_image(composite_image, masks_list, boxes, labels, show_masks, show_boxes)
elif crop_options == "Inverse Crop":
remaining_image = image_source.copy()
for masks in masks_list:
remaining_image = extract_remaining_image(remaining_image, masks)
output_image = overlay_masks_boxes_on_image(remaining_image, masks_list, boxes, labels, show_masks, show_boxes)
else:
output_image = overlay_masks_boxes_on_image(image_source, masks_list, boxes, labels, show_masks, show_boxes)
output_image_path = 'output_image.jpeg'
plt.imsave(output_image_path, output_image)
return output_image_path
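# Build the Gradio interface.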
block = gr.Blocks(css=".gradio-container {background-color: #f8f8f8; font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif}")
with block:
    gr.HTML("""
        <style>
            body {
                background-color: #f5f5f5;
                font-family: 'Roboto', sans-serif;
                padding: 30px;
            }
        </style>
    """)
    gr.HTML("<h1 style='text-align: center;'>Segment Any Image</h1>")
    gr.HTML("<h3 style='text-align: center;'>Zero-Shot Object Detection, Segmentation and Cropping</h3>")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type='filepath', label="Upload Image")
        with gr.Column():
            output_image = gr.Image(type='filepath', label="Result")
    with gr.Row():
        with gr.Column():
            object_search = gr.Textbox(
                label="Object to Detect",
                placeholder="Enter any text, comma separated if multiple objects needed",
                show_label=True,
                lines=1,
            )
        with gr.Column():
            show_masks = gr.Checkbox(label="Show Masks", value=True)
            show_boxes = gr.Checkbox(label="Show Boxes", value=True)
        with gr.Column():
            crop_options = gr.Radio(choices=["None", "Crop", "Inverse Crop"], label="Crop Options", value="None")
    with gr.Row():
        submit = gr.Button(value="Send", variant="secondary").style(full_width=True)
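    # Clickable example inputs: image path, prompt, show_masks, show_boxes.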
    gr.Examples(
        examples=[
            ["images/tiger.jpeg", "animal from cat family", True, True],
            ["images/car.jpeg", "a blue sports car", True, False],
            ["images/bags.jpeg", "black bag next to the red bag", False, True],
            ["images/deer.jpeg", "deer jumping and running across the road", True, True],
            ["images/penn.jpeg", "sign board", True, False],
        ],
        inputs=[input_image, object_search, show_masks, show_boxes],
    )
gr.HTML("""
<div style="text-align:center">
<p>Developed by <a href='https://www.linkedin.com/in/dekay/'>Github and Huggingface: Volkopat</a></p>
<p>Powered by <a href='https://segment-anything.com'>Segment Anything</a> and <a href='https://arxiv.org/abs/2303.05499'>Grounding DINO</a></p>
<p>Just upload an image and enter the objects to detect, segment, crop, etc. That's all folks!</p>
<p>What's Zero-Shot? It means you can detect objects without any training samples!</p>
<p>This project is for demonstration purposes. Credits for State of the Art models go to Meta AI and IDEA Research.</p>
</div>
<style>
p {
margin-bottom: 10px;
font-size: 16px;
}
a {
color: #3867d6;
text-decoration: none;
}
a:hover {
text-decoration: underline;
}
</style>
""")
    submit.click(fn=detect_objects,
                 inputs=[input_image, object_search, show_masks, show_boxes, crop_options],
                 outputs=[output_image])
block.launch(width=800)