SegmentVision / app.py
sagar007's picture
Update app.py
4f39124 verified
raw
history blame
2.3 kB
import gradio as gr
import torch
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import io
from ultralytics import FastSAM
from ultralytics.models.fastsam import FastSAMPrompt
# Pick the compute device: use CUDA when PyTorch reports it, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the FastSAM weights once at import time so every request reuses them.
model = FastSAM("FastSAM-s.pt")  # or FastSAM-x.pt
def fig2img(fig):
    """Render a Matplotlib figure into an in-memory PIL image.

    Args:
        fig: Object exposing ``savefig``; it is asked to write itself into an
            in-memory byte buffer.

    Returns:
        The PIL image decoded from the bytes the figure wrote.
    """
    buf = io.BytesIO()
    # Pin the format: without it, savefig on a file-like object falls back to
    # the global ``savefig.format`` rcParam, which may not be a raster format
    # that PIL can read.
    fig.savefig(buf, format="png")
    buf.seek(0)
    img = Image.open(buf)
    # Image.open is lazy; decode now so a corrupt render fails here rather
    # than at some later point far from the cause.
    img.load()
    return img
def plot_masks(annotations, output_shape):
    """Overlay every predicted mask on the original image.

    Args:
        annotations: Sequence of result objects. The first one supplies the
            background picture through ``orig_img``; each one supplies its
            masks through ``masks.data``. Results whose ``masks`` is ``None``
            are skipped.
        output_shape: Size the masks are resized to, as ``(height, width)``.
            It is reversed before being passed to ``cv2.resize``, whose
            ``dsize`` argument is ``(width, height)``.

    Returns:
        The rendered overlay as a PIL image (produced by ``fig2img``).

    Raises:
        IndexError: If ``annotations`` is empty, since there is no background
            image to draw.
    """
    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        # NOTE(review): Ultralytics appears to store orig_img in BGR channel
        # order; if so, this background is drawn with red/blue swapped.
        # Confirm against the installed version before flipping channels.
        ax.imshow(annotations[0].orig_img)

        dsize = tuple(output_shape[::-1])
        for ann in annotations:
            if ann.masks is None:
                # Nothing was detected for this result; nothing to overlay.
                continue
            for mask in ann.masks.data:
                mask = cv2.resize(mask.cpu().numpy().astype('uint8'), dsize)
                # Hide the background (zero) pixels so only the object is tinted.
                masked = np.ma.masked_where(mask == 0, mask)
                # Use the colormap name: ``matplotlib.cm.get_cmap`` was removed
                # in Matplotlib 3.9, and the string resolves to the same map.
                ax.imshow(masked, alpha=0.5, cmap='jet')

        ax.axis('off')
        return fig2img(fig)
    finally:
        # Close this specific figure (not pyplot's "current" one, which is
        # global state shared across requests) and only after rendering.
        plt.close(fig)
def segment_everything(input_image):
    """Segment every object in an uploaded image with FastSAM.

    Args:
        input_image: The uploaded picture as an array accepted by
            ``Image.fromarray``, or ``None`` when nothing was uploaded.

    Returns:
        A ``(image, status)`` tuple. ``image`` is the mask overlay as a PIL
        image, the unmodified RGB input when no objects were found, or
        ``None`` when there was no input or an error occurred. ``status`` is
        a human-readable message.
    """
    if input_image is None:
        return None, "Please upload an image before submitting."

    try:
        input_image = Image.fromarray(input_image).convert("RGB")

        # Run FastSAM model in "everything" mode
        everything_results = model(
            input_image,
            device=device,
            retina_masks=True,
            imgsz=1024,
            conf=0.25,
            iou=0.9,
            agnostic_nms=True,
        )

        # Prepare a Prompt Process object
        prompt_process = FastSAMPrompt(input_image, everything_results, device=device)

        # Get everything segmentation
        ann = prompt_process.everything_prompt()

        # ``masks`` is None when nothing was detected; report that cleanly
        # instead of failing on ``len(None)``.
        masks = ann[0].masks if len(ann) > 0 else None
        if masks is None:
            return input_image, "Segmented everything in the image. Found 0 objects."

        # PIL reports size as (width, height), but plot_masks expects
        # (height, width) because it reverses the pair for cv2.resize.
        # Passing ``.size`` directly distorts masks on non-square images.
        width, height = input_image.size
        result_image = plot_masks(ann, (height, width))

        return result_image, f"Segmented everything in the image. Found {len(masks)} objects."
    except Exception as e:
        # UI boundary: surface the failure in the status box rather than
        # crashing the request.
        return None, f"An error occurred: {str(e)}"
# Assemble the web UI: a single image upload goes in; the segmented overlay
# and a status message come out.
iface = gr.Interface(
    title="FastSAM Everything Segmentation",
    description="Upload an image to segment all objects using FastSAM.",
    fn=segment_everything,
    inputs=[gr.Image(label="Upload an image", type="numpy")],
    outputs=[
        gr.Image(label="Segmented Image", type="pil"),
        gr.Textbox(label="Status"),
    ],
)

# Start serving the interface.
iface.launch()