# SegmentVision / app.py
import gradio as gr
import torch
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import io
from ultralytics import FastSAM
from ultralytics.models.fastsam import FastSAMPrompt
# Set up device
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load FastSAM model
model = FastSAM("FastSAM-s.pt") # or FastSAM-x.pt
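# NOTE: if the weight file is not already on disk, ultralytics should fetch FastSAM-s.pt
# automatically on first use; FastSAM-x.pt is the larger, more accurate (but slower) variant.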
def fig2img(fig):
    # Render a Matplotlib figure into an in-memory buffer and return it as a PIL Image.
    buf = io.BytesIO()
    fig.savefig(buf)
    buf.seek(0)
    img = Image.open(buf)
    return img
def plot(annotations, prompt_process, mask_random_color=True, better_quality=True, retina=True, with_contours=True):
    # The original plot() body was elided in this revision ("keep the existing plot function
    # as is"); a minimal stand-in is sketched below so the app remains runnable.
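    # Assumptions for this sketch: `annotations` is the list of ultralytics Results returned
    # by FastSAMPrompt.text_prompt(), where each Results exposes .orig_img (assumed BGR) and
    # .masks.data (a tensor or numpy array of HxW masks); `prompt_process`, `better_quality`
    # and `retina` are accepted but unused here. Each mask is drawn as a translucent,
    # randomly coloured overlay and the figure is returned as a PIL image via fig2img().
    if not annotations:
        return None
    image = np.asarray(annotations[0].orig_img)
    if image.ndim == 3:
        image = image[..., ::-1]  # BGR -> RGB for display
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(image)
    for ann in annotations:
        if ann.masks is None:
            continue
        masks = ann.masks.data
        if isinstance(masks, torch.Tensor):
            masks = masks.cpu().numpy()
        for mask in masks:
            color = np.random.rand(3) if mask_random_color else np.array([0.2, 0.5, 0.9])
            overlay = np.zeros((*mask.shape, 4), dtype=np.float32)
            overlay[mask > 0] = [*color, 0.5]  # RGBA overlay: translucent fill where the mask is set
            ax.imshow(overlay)
            if with_contours:
                ax.contour(mask, levels=[0.5], colors=[tuple(color)], linewidths=1.5)
    ax.axis("off")
    img = fig2img(fig)
    plt.close(fig)  # release the figure so repeated requests do not leak memory
    return img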
def segment_image(input_image, object_name):
    try:
        if input_image is None:
            return None, "Please upload an image before submitting."
        input_image = Image.fromarray(input_image).convert("RGB")
        # Run FastSAM model with adjusted parameters
        everything_results = model(input_image, retina_masks=True, imgsz=1024, conf=0.25, iou=0.7)
        # Prepare a Prompt Process object
        prompt_process = FastSAMPrompt(input_image, everything_results, device=device)
        # Use text prompt to segment the specified object
        results = prompt_process.text_prompt(text=object_name)
        if not results:
            return input_image, f"Could not find '{object_name}' in the image."
        # Post-process the masks: close small holes, remove specks, then slightly dilate
        for ann in results:
            if ann.masks is not None:
                masks = ann.masks.data
                if isinstance(masks[0], torch.Tensor):
                    masks = np.array(masks.cpu())
                for i, mask in enumerate(masks):
                    # Apply more aggressive morphological operations
                    kernel = np.ones((5, 5), np.uint8)
                    mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel)
                    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
                    masks[i] = cv2.dilate(mask, kernel, iterations=2)
                ann.masks.data = masks
        # Plot the results
        result_image = plot(annotations=results, prompt_process=prompt_process)
        return result_image, f"Segmented '{object_name}' in the image."
    except Exception as e:
        return None, f"An error occurred: {str(e)}"
# Create Gradio interface
iface = gr.Interface(
    fn=segment_image,
    inputs=[
        gr.Image(type="numpy", label="Upload an image"),
        gr.Textbox(label="Specify object to segment (e.g., dog, cat, grass)"),
    ],
    outputs=[
        gr.Image(type="pil", label="Segmented Image"),
        gr.Textbox(label="Status"),
    ],
    title="FastSAM Segmentation with Object Specification",
    description="Upload an image and specify an object to segment using FastSAM.",
)
# Launch the interface
iface.launch()
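# launch() with no arguments is what a hosted Space needs; when running locally,
# iface.launch(share=True) can be used instead to expose a temporary public URL.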