Spaces:
Sleeping
Sleeping
File size: 5,023 Bytes
dfdcd97 a3ee867 e9cd6fd c95f3e0 9a34a8b e0d4d2f c95f3e0 9a34a8b 3cd1243 9a34a8b e0d4d2f 9a34a8b 72f4c5c 26c0f04 3cd1243 73989e5 26c0f04 2dd8fe8 26c0f04 9a34a8b 3cd1243 9a34a8b 73989e5 9a34a8b 73989e5 2dd8fe8 9a34a8b 73989e5 3cd1243 3ba1061 73989e5 e0d4d2f e9cd6fd e0d4d2f e9cd6fd 3ba1061 3cd1243 e9cd6fd 3ba1061 9a34a8b 3ba1061 9a34a8b e0d4d2f e9cd6fd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import gradio as gr
import torch
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import io
from ultralytics import FastSAM
from ultralytics.models.fastsam import FastSAMPrompt
# Pick the compute device: start from CPU and switch to CUDA only when
# torch reports that a CUDA device is available.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Load the FastSAM checkpoint once at import time so every request reuses it.
model = FastSAM("FastSAM-s.pt")  # or FastSAM-x.pt
def fig2img(fig):
    """Return *fig* rendered into an in-memory buffer and opened as a PIL image.

    Args:
        fig: Object exposing ``savefig(file_like)`` (here, the figure built
            by ``plot``).

    Returns:
        The image obtained by opening the saved buffer with ``Image.open``.
    """
    buffer = io.BytesIO()
    fig.savefig(buffer)
    # Rewind so the reader starts at the first byte that was written.
    buffer.seek(0)
    return Image.open(buffer)
def plot(annotations, prompt_process, mask_random_color=True, better_quality=True, retina=True, with_contours=True):
    """Draw the masks of a segmentation result on top of its source image.

    A single image is returned, so only the first element of *annotations*
    is rendered.

    Args:
        annotations: Iterable of result objects exposing ``orig_img``,
            ``orig_shape`` (height, width) and ``masks`` (``None`` or an
            object with a ``data`` attribute holding the mask stack).
        prompt_process: Object providing ``fast_show_mask``; it is called to
            paint the masks onto the current axes.
        mask_random_color: Forwarded to ``fast_show_mask`` as ``random_color``.
        better_quality: When true, each mask is cleaned with a 3x3
            morphological close followed by an 8x8 morphological open.
        retina: Forwarded to ``fast_show_mask`` as ``retinamask``. When
            false, each mask is resized to the original image size before
            its contours are extracted.
        with_contours: When true, mask outlines are drawn over the image.

    Returns:
        The rendered figure converted by ``fig2img``, or ``None`` when
        *annotations* is empty.
    """
    ann = next(iter(annotations), None)
    if ann is None:
        # Nothing to draw; avoid creating a figure at all.
        return None

    image = ann.orig_img[..., ::-1]  # BGR to RGB
    original_h, original_w = ann.orig_shape
    fig = plt.figure(figsize=(original_w / 100, original_h / 100))
    try:
        # Remove every margin and tick so the image fills the whole canvas.
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        plt.margins(0, 0)
        ax = plt.gca()
        ax.xaxis.set_major_locator(plt.NullLocator())
        ax.yaxis.set_major_locator(plt.NullLocator())
        plt.imshow(image)

        if ann.masks is not None:
            masks = ann.masks.data
            # Always work on a NumPy array: the contour code below calls
            # ``astype``, which torch tensors do not have, and checking the
            # container itself (not ``masks[0]``) is safe for an empty stack.
            if isinstance(masks, torch.Tensor):
                masks = np.array(masks.cpu())

            if better_quality:
                close_kernel = np.ones((3, 3), np.uint8)
                open_kernel = np.ones((8, 8), np.uint8)
                for i, mask in enumerate(masks):
                    closed = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, close_kernel)
                    masks[i] = cv2.morphologyEx(closed.astype(np.uint8), cv2.MORPH_OPEN, open_kernel)

            prompt_process.fast_show_mask(
                masks,
                ax,
                random_color=mask_random_color,
                bbox=None,
                points=None,
                pointlabel=None,
                retinamask=retina,
                target_height=original_h,
                target_width=original_w,
            )

            if with_contours:
                contour_all = []
                temp = np.zeros((original_h, original_w, 1))
                for mask in masks:
                    mask = mask.astype(np.uint8)
                    if not retina:
                        mask = cv2.resize(mask, (original_w, original_h), interpolation=cv2.INTER_NEAREST)
                    contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
                    contour_all.extend(contours)
                cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
                # RGBA colour applied wherever a contour was drawn.
                color = np.array([0 / 255, 0 / 255, 1.0, 0.8])
                contour_mask = temp / 255 * color.reshape(1, 1, -1)
                plt.imshow(contour_mask)

        plt.axis("off")
        # Render while the figure is still open, then release it below.
        return fig2img(fig)
    finally:
        # Close this specific figure even if drawing failed, so figures do
        # not accumulate in pyplot's registry across requests.
        plt.close(fig)
def segment_image(input_image, object_name):
    """Segment the object named by *object_name* in *input_image*.

    Args:
        input_image: Image array as delivered by the Gradio image input, or
            ``None`` when nothing was uploaded.
        object_name: Text prompt describing the object to segment.

    Returns:
        A ``(image, status_message)`` tuple. ``image`` is the rendered
        overlay on success, the unmodified input (as an RGB PIL image) when
        no mask was produced, and ``None`` for invalid input or on error.
    """
    try:
        if input_image is None:
            return None, "Please upload an image before submitting."
        if not object_name or not object_name.strip():
            # An empty prompt cannot identify an object; skip the model run.
            return None, "Please specify an object to segment."

        input_image = Image.fromarray(input_image).convert("RGB")

        # Run FastSAM model with adjusted parameters
        everything_results = model(input_image, retina_masks=True, imgsz=1024, conf=0.25, iou=0.7)

        # Prepare a Prompt Process object
        prompt_process = FastSAMPrompt(input_image, everything_results, device=device)

        # Use text prompt to segment the specified object
        results = prompt_process.text_prompt(text=object_name)

        # A non-empty result list whose entries carry no masks also means
        # nothing was found, so do not report it as a success.
        if not results or all(ann.masks is None for ann in results):
            return input_image, f"Could not find '{object_name}' in the image."

        # Post-process the masks: close, open, then dilate with one shared
        # 5x5 kernel (more aggressive than the clean-up done while plotting).
        kernel = np.ones((5, 5), np.uint8)
        for ann in results:
            if ann.masks is None:
                continue
            masks = ann.masks.data
            # Check the container rather than ``masks[0]`` so an empty mask
            # stack does not raise IndexError.
            if isinstance(masks, torch.Tensor):
                masks = np.array(masks.cpu())
            for i, mask in enumerate(masks):
                cleaned = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel)
                cleaned = cv2.morphologyEx(cleaned, cv2.MORPH_OPEN, kernel)
                masks[i] = cv2.dilate(cleaned, kernel, iterations=2)
            ann.masks.data = masks

        # Plot the results
        result_image = plot(annotations=results, prompt_process=prompt_process)
        return result_image, f"Segmented '{object_name}' in the image."
    except Exception as e:
        # UI boundary: surface the failure as a status message instead of
        # letting the request crash.
        return None, f"An error occurred: {str(e)}"
# Create Gradio interface: an image plus a text prompt go in, the rendered
# segmentation and a status message come out (matching the two values
# returned by segment_image).
iface = gr.Interface(
    fn=segment_image,
    inputs=[
        gr.Image(type="numpy", label="Upload an image"),
        gr.Textbox(label="Specify object to segment (e.g., dog, cat, grass)"),
    ],
    outputs=[
        gr.Image(type="pil", label="Segmented Image"),
        gr.Textbox(label="Status"),
    ],
    title="FastSAM Segmentation with Object Specification",
    description="Upload an image and specify an object to segment using FastSAM.",
)

# Launch the interface
iface.launch()