# SegmentVision / app.py
# Author: sagar007
# Last commit: e9cd6fd ("Update app.py", verified)
import gradio as gr
import torch
import cv2
import numpy as np
from fastsam import FastSAM, FastSAMPrompt
# Build a single FastSAM instance from the "FastSAM-x.pt" checkpoint at import
# time; segment_image() reuses this one module-level object on every call.
model = FastSAM("FastSAM-x.pt")
def segment_image(input_image, points):
    """Highlight the FastSAM segment(s) under the user's prompt points.

    Args:
        input_image: Image to segment, as delivered by
            ``gr.Image(type="numpy")`` -- assumed to be an HxWx3 uint8 RGB
            array (TODO confirm for non-RGB uploads). ``None`` when nothing
            was uploaded.
        points: Foreground prompts, in any of these forms:
            * a sequence of ``[x, y]`` pixel coordinates;
            * the ``{"image": ..., "mask": ...}`` dict produced by a
              ``gr.Image(tool="sketch")`` input (one prompt per stroke);
            * ``None`` (no prompt).
            Coordinates outside the image are ignored.

    Returns:
        A copy of ``input_image`` (still RGB) with the selected region blended
        50/50 with red; ``input_image`` itself when there is no usable prompt
        or no mask comes back; ``None`` when ``input_image`` is ``None``.
    """
    if input_image is None:
        return None
    height, width = input_image.shape[:2]

    prompt_points = _prompt_points(points, height, width)
    if not prompt_points:
        return input_image

    # Gradio supplies RGB, while Ultralytics-based models read numpy input as
    # BGR (OpenCV order). Convert a *separate* array for the model so the image
    # we draw on and return stays RGB instead of coming back channel-swapped.
    model_input = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR)
    everything_results = model(
        model_input, device='cpu', retina_masks=True, imgsz=1024, conf=0.4, iou=0.9
    )
    prompt_process = FastSAMPrompt(model_input, everything_results, device='cpu')
    ann = prompt_process.point_prompt(
        points=prompt_points, pointlabel=[1] * len(prompt_points)
    )
    # NOTE(review): FastSAM's point_prompt may return an empty result (e.g. no
    # detections); guard before indexing ann[0].
    if len(ann) == 0:
        return input_image

    mask = np.asarray(ann[0]).astype(bool)
    if mask.shape != (height, width):
        # Defensive: bring the mask back to image resolution if it was
        # produced at a different size, so boolean indexing below lines up.
        mask = cv2.resize(
            mask.astype(np.uint8), (width, height), interpolation=cv2.INTER_NEAREST
        ).astype(bool)

    result_image = input_image.copy()
    red = np.array([255, 0, 0])
    # Blend in float, then cast explicitly back to the image dtype.
    result_image[mask] = (result_image[mask] * 0.5 + red * 0.5).astype(result_image.dtype)
    return result_image


def _prompt_points(points, height, width):
    """Normalize the ``points`` input into in-bounds ``[x, y]`` int pairs."""
    if points is None:
        return []
    if isinstance(points, dict):
        # Sketch inputs arrive as {"image": ..., "mask": ...}; strokes are in "mask".
        sketch = points.get("mask")
        if sketch is None:
            return []
        return _sketch_to_points(sketch, height, width)
    in_bounds = []
    for point in points:
        x, y = int(point[0]), int(point[1])
        # A coordinate outside the image cannot select any mask pixel; drop it
        # rather than letting downstream indexing fail (or wrap, if negative).
        if 0 <= x < width and 0 <= y < height:
            in_bounds.append([x, y])
    return in_bounds


def _sketch_to_points(sketch, height, width):
    """Reduce a sketch mask to one on-stroke ``[x, y]`` point per stroke."""
    strokes = np.asarray(sketch)
    if strokes.ndim == 3:
        # Collapse colour channels (ignoring any alpha) to a single 2-D mask.
        strokes = strokes[..., :3].any(axis=-1)
    strokes = (strokes > 0).astype(np.uint8)
    if not strokes.any():
        return []
    if strokes.shape != (height, width):
        # The sketch canvas may hold a differently sized image; map the
        # strokes onto the segmented image's pixel grid.
        strokes = cv2.resize(strokes, (width, height), interpolation=cv2.INTER_NEAREST)
    count, labels, _, centroids = cv2.connectedComponentsWithStats(strokes)
    points = []
    for label in range(1, count):  # label 0 is the background
        ys, xs = np.nonzero(labels == label)
        cx, cy = centroids[label]
        # A curved stroke's centroid can fall off the stroke itself, so use the
        # stroke pixel nearest to it.
        nearest = np.argmin((xs - cx) ** 2 + (ys - cy) ** 2)
        points.append([int(xs[nearest]), int(ys[nearest])])
    return points
# Assemble the Gradio UI. The first input is the photo handed to segment_image
# as `input_image`; the second (a sketch canvas) is handed over as `points`.
# NOTE(review): `tool=` and `brush_radius=` are gr.Image options from Gradio
# 3.x; newer Gradio moved sketching to a separate editor component -- confirm
# the pinned gradio version.
_image_input = gr.Image(type="numpy")
_sketch_input = gr.Image(
    type="numpy",
    tool="sketch",
    brush_radius=5,
    label="Click on objects to segment",
)
iface = gr.Interface(
    fn=segment_image,
    inputs=[_image_input, _sketch_input],
    outputs=gr.Image(type="numpy"),
    title="FastSAM Image Segmentation",
    description="Click on objects in the image to segment them using FastSAM.",
)

# Start serving the app.
iface.launch()