# Hugging Face Spaces status: Sleeping
import io
import logging

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from ultralytics import FastSAM
from ultralytics.models.fastsam import FastSAMPrompt
# Run inference on the GPU when PyTorch can see one; otherwise use the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the FastSAM checkpoint once, at import time; segment_everything reuses
# this module-level model for every request. "FastSAM-x.pt" is the alternative
# checkpoint that can be swapped in here.
model = FastSAM("FastSAM-s.pt")
def fig2img(fig):
    """Rasterise a Matplotlib figure and return it as a PIL image.

    The figure is written with ``savefig``'s default settings into an
    in-memory buffer. The buffer is rewound and passed to ``Image.open``, so
    nothing touches the filesystem.
    """
    stream = io.BytesIO()
    fig.savefig(stream)
    stream.seek(0)  # Image.open reads from the current position
    return Image.open(stream)
def plot_masks(annotations, output_shape):
    """Overlay every FastSAM mask on the source image and return a PIL image.

    Args:
        annotations: Sequence of ultralytics ``Results`` as returned by
            ``FastSAMPrompt.everything_prompt()``. ``annotations[0].orig_img``
            is drawn as the background. Results whose ``masks`` is ``None``
            (nothing detected) are skipped.
        output_shape: ``(width, height)`` of the source image. This is the
            order ``PIL.Image.size`` uses and the order ``cv2.resize`` expects
            for ``dsize``. Every mask is resized to this size.

    Returns:
        The rendered figure as a ``PIL.Image.Image`` (via ``fig2img``).

    Raises:
        IndexError: if ``annotations`` is empty.
    """
    # Gather all masks up front so each one can be given its own colour.
    masks = [
        mask
        for ann in annotations
        if ann.masks is not None
        for mask in ann.masks.data
    ]
    # Mask i is drawn with colour i of the colormap. vmax >= 1 keeps the
    # normalisation non-degenerate when there is a single mask.
    vmax = max(len(masks) - 1, 1)
    # cv2.resize takes dsize as (width, height), the same order as PIL's size.
    # The previous output_shape[::-1] transposed every non-square mask.
    dsize = (int(output_shape[0]), int(output_shape[1]))

    background = annotations[0].orig_img
    if background.ndim == 3:
        # ultralytics keeps orig_img in BGR order; matplotlib draws RGB.
        background = background[..., ::-1]

    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        ax.imshow(background)
        for idx, mask in enumerate(masks):
            binary = cv2.resize(
                mask.cpu().numpy().astype("uint8"),
                dsize,
                interpolation=cv2.INTER_NEAREST,  # keep the mask strictly 0/1
            )
            colour_index = np.full(binary.shape, idx, dtype=float)
            masked = np.ma.masked_where(binary == 0, colour_index)
            # Pass the colormap by name: plt.cm.get_cmap was removed in Matplotlib 3.9.
            ax.imshow(masked, alpha=0.5, cmap="jet", vmin=0, vmax=vmax)
        ax.axis("off")
        return fig2img(fig)
    finally:
        # Close this specific figure. plt.close() with no argument closes the
        # "current" figure, which may not be this one if callbacks overlap.
        plt.close(fig)
def segment_everything(input_image):
    """Run FastSAM "everything" segmentation on an uploaded image.

    This is the Gradio callback wired to ``iface``.

    Args:
        input_image: Image array from ``gr.Image(type="numpy")``, or ``None``
            when the user submits without uploading anything.

    Returns:
        A ``(result_image, status)`` tuple. ``result_image`` is the PIL image
        with masks overlaid, the unmodified input when nothing was found, or
        ``None`` on failure. ``status`` is the message shown in the Status
        textbox. Errors are reported through ``status`` rather than raised,
        so the UI always receives a response.
    """
    if input_image is None:
        return None, "Please upload an image before submitting."
    try:
        pil_image = Image.fromarray(input_image).convert("RGB")
        # Run FastSAM in "everything" mode; retina_masks=True requests
        # full-resolution masks.
        everything_results = model(
            pil_image,
            device=device,
            retina_masks=True,
            imgsz=1024,
            conf=0.25,
            iou=0.9,
            agnostic_nms=True,
        )
        prompt_process = FastSAMPrompt(pil_image, everything_results, device=device)
        ann = prompt_process.everything_prompt()

        # masks is None when FastSAM detects nothing. The old len(ann[0].masks)
        # then raised TypeError and surfaced as a confusing generic error.
        num_objects = sum(len(result.masks) for result in ann if result.masks is not None)
        if num_objects == 0:
            return pil_image, "No objects were found in the image."

        result_image = plot_masks(ann, pil_image.size)
        return result_image, f"Segmented everything in the image. Found {num_objects} objects."
    except Exception as e:
        # UI boundary: keep the full traceback in the server log and show the
        # user a short message instead of crashing the callback.
        logging.getLogger(__name__).exception("FastSAM segmentation failed")
        return None, f"An error occurred: {str(e)}"
# Create Gradio interface: one uploaded image goes in; the annotated image and
# a status line come out, matching the (image, status) tuple that
# segment_everything returns.
_ui_inputs = [gr.Image(type="numpy", label="Upload an image")]
_ui_outputs = [
    gr.Image(type="pil", label="Segmented Image"),
    gr.Textbox(label="Status"),
]
iface = gr.Interface(
    fn=segment_everything,
    inputs=_ui_inputs,
    outputs=_ui_outputs,
    title="FastSAM Everything Segmentation",
    description="Upload an image to segment all objects using FastSAM.",
)

# Launch the interface. This call is unconditional, so executing this module
# starts the server.
iface.launch()