import gradio as gr
import torch
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import io
from ultralytics import FastSAM
from ultralytics.models.fastsam import FastSAMPrompt

# Set up device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load FastSAM model
model = FastSAM("FastSAM-s.pt")  # or FastSAM-x.pt
def fig2img(fig):
    # Render a Matplotlib figure into an in-memory PNG and return it as a PIL Image
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    img = Image.open(buf)
    return img
def plot(annotations, prompt_process, mask_random_color=True, better_quality=True, retina=True, with_contours=True):
    # NOTE: the original plot function body was elided in the source
    # ("keep the existing plot function as is"); a minimal stand-in follows.
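    # Minimal sketch of a mask-overlay renderer, assuming each annotation is an
    # Ultralytics Results object exposing `orig_img` (BGR ndarray) and `masks.data`
    # at the original image resolution (retina_masks=True). It tints every mask with
    # a translucent color and returns the composite as a PIL Image via fig2img.
    image = cv2.cvtColor(annotations[0].orig_img, cv2.COLOR_BGR2RGB)
    fig, ax = plt.subplots(figsize=(image.shape[1] / 100, image.shape[0] / 100), dpi=100)
    ax.imshow(image)
    ax.axis("off")
    for ann in annotations:
        if ann.masks is None:
            continue
        masks = ann.masks.data
        if isinstance(masks, torch.Tensor):
            masks = masks.cpu().numpy()
        for mask in masks:
            color = np.random.rand(3) if mask_random_color else np.array([0.2, 0.4, 1.0])
            overlay = np.zeros((*mask.shape, 4), dtype=np.float32)
            overlay[mask > 0] = np.append(color, 0.5)  # RGBA, 50% opacity
            ax.imshow(overlay)
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
    result = fig2img(fig)
    plt.close(fig)
    return result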
def segment_image(input_image, object_name):
    try:
        if input_image is None:
            return None, "Please upload an image before submitting."

        input_image = Image.fromarray(input_image).convert("RGB")

        # Run FastSAM model with adjusted parameters
        everything_results = model(input_image, retina_masks=True, imgsz=1024, conf=0.25, iou=0.7)

        # Prepare a Prompt Process object
        prompt_process = FastSAMPrompt(input_image, everything_results, device=device)

        # Use text prompt to segment the specified object
        results = prompt_process.text_prompt(text=object_name)

        if not results:
            return input_image, f"Could not find '{object_name}' in the image."
        # Post-process the masks
        for ann in results:
            if ann.masks is not None:
                masks = ann.masks.data
                if isinstance(masks, torch.Tensor):
                    masks = masks.cpu().numpy()
                for i, mask in enumerate(masks):
                    # Apply more aggressive morphological operations to clean up each mask
                    kernel = np.ones((5, 5), np.uint8)
                    mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel)
                    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
                    masks[i] = cv2.dilate(mask, kernel, iterations=2)
                ann.masks.data = masks
        # Plot the results
        result_image = plot(annotations=results, prompt_process=prompt_process)

        return result_image, f"Segmented '{object_name}' in the image."
    except Exception as e:
        return None, f"An error occurred: {str(e)}"
# Create Gradio interface
iface = gr.Interface(
    fn=segment_image,
    inputs=[
        gr.Image(type="numpy", label="Upload an image"),
        gr.Textbox(label="Specify object to segment (e.g., dog, cat, grass)"),
    ],
    outputs=[
        gr.Image(type="pil", label="Segmented Image"),
        gr.Textbox(label="Status"),
    ],
    title="FastSAM Segmentation with Object Specification",
    description="Upload an image and specify an object to segment using FastSAM.",
)

# Launch the interface
iface.launch()