# SegmentVision / app.py
# Source: Hugging Face Space file view (raw / history / blame), uploaded by sagar007.
# Commit 3ba1061 (verified): "Update app.py" - 2.64 kB.
import gradio as gr
import torch
import cv2
import numpy as np
from transformers import SamModel, SamProcessor
from PIL import Image
# Pick the compute device: the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# The network and its processor are loaded from the same pretrained
# checkpoint; the model is then moved onto the selected device.
_SAM_CHECKPOINT = "facebook/sam-vit-base"
model = SamModel.from_pretrained(_SAM_CHECKPOINT)
model = model.to(device)
processor = SamProcessor.from_pretrained(_SAM_CHECKPOINT)
def _collapse_masks(mask_stack, scores, segment_anything):
    """Reduce a stack of candidate masks to one 2-D boolean mask.

    Args:
        mask_stack: Array whose last two axes are (height, width); any
            leading axes enumerate candidate masks.
        scores: Per-candidate quality scores, in the same order as the
            flattened leading axes of ``mask_stack``.
        segment_anything: When true, return the union of all candidates;
            otherwise return the single highest-scoring candidate.

    Returns:
        A boolean array of shape (height, width).
    """
    # Flatten every leading axis so the stack is always (candidates, H, W),
    # whatever batch / prompt / proposal axes the processor returned.
    candidates = mask_stack.reshape(-1, *mask_stack.shape[-2:]) > 0.5
    if segment_anything:
        return candidates.any(axis=0)

    flat_scores = np.asarray(scores).reshape(-1)
    if flat_scores.size != len(candidates):
        # Scores do not line up with the candidates; fall back to the first.
        return candidates[0]
    return candidates[int(np.argmax(flat_scores))]


def segment_image(input_image, segment_anything):
    """Run SAM on an uploaded image and overlay the mask in red.

    Args:
        input_image: Image as a numpy array (as delivered by the Gradio
            ``Image`` input with ``type="numpy"``), or ``None`` when nothing
            was uploaded.
        segment_anything: When true, the processor is called without any
            point prompt and all returned masks are merged. When false, the
            image centre is used as a single point prompt and the
            best-scoring mask is kept.

    Returns:
        A ``(image, status)`` tuple: the overlaid image as a numpy array
        (``None`` if no image was supplied) and a status message.
    """
    if input_image is None:
        return None, "Please upload an image before submitting."

    # Normalise to 3-channel RGB so the red overlay below always matches the
    # image layout, even for grayscale or RGBA uploads.
    pil_image = Image.fromarray(input_image).convert("RGB")
    width, height = pil_image.size  # PIL reports (width, height)

    if segment_anything:
        # No prompt is supplied in this mode.
        # NOTE(review): a prompt-free forward pass may not cover every object
        # in the image - confirm this matches the "Segment Everything" label.
        inputs = processor(pil_image, return_tensors="pt").to(device)
    else:
        # Use the center of the image as a point prompt.
        center_point = [[width // 2, height // 2]]
        inputs = processor(
            pil_image, input_points=[center_point], return_tensors="pt"
        ).to(device)

    # Inference only: no gradients needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # Map the predicted masks back to the original image resolution.
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )

    combined_mask = _collapse_masks(
        masks[0].numpy(),
        outputs.iou_scores[0].cpu().numpy(),
        segment_anything,
    )

    # Only resize if the mask does not already match the image. cv2.resize
    # takes dsize as (width, height) - the reverse of numpy's (rows, cols) -
    # and nearest-neighbour keeps the mask strictly binary.
    if combined_mask.shape != (height, width):
        combined_mask = cv2.resize(
            combined_mask.astype(np.uint8),
            (width, height),
            interpolation=cv2.INTER_NEAREST,
        ) > 0

    # Overlay the mask on the original image.
    result_image = np.array(pil_image)
    mask_rgb = np.zeros_like(result_image)
    mask_rgb[combined_mask] = [255, 0, 0]  # Red color for the mask
    result_image = cv2.addWeighted(result_image, 1, mask_rgb, 0.5, 0)
    return result_image, "Segmentation completed successfully."
# Build the Gradio UI: an image upload plus a mode checkbox go in, and the
# overlaid image plus a status line come out.
_input_components = [
    gr.Image(type="numpy", label="Upload an image"),
    gr.Checkbox(label="Segment Everything"),
]
_output_components = [
    gr.Image(type="numpy", label="Segmented Image"),
    gr.Textbox(label="Status"),
]
iface = gr.Interface(
    fn=segment_image,
    inputs=_input_components,
    outputs=_output_components,
    title="Segment Anything Model (SAM) Image Segmentation",
    description="Upload an image and choose whether to segment everything or use a center point.",
)

# Start serving the interface.
iface.launch()