# Hugging Face Space (status: Sleeping)
import gradio as gr | |
import torch | |
import cv2 | |
import numpy as np | |
from transformers import SamModel, SamProcessor | |
from PIL import Image | |
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Load the SAM processor and weights once at import time; the weights are
# moved to the selected device.
_SAM_CHECKPOINT = "facebook/sam-vit-base"
processor = SamProcessor.from_pretrained(_SAM_CHECKPOINT)
model = SamModel.from_pretrained(_SAM_CHECKPOINT)
model = model.to(device)
def segment_image(input_image, segment_anything):
    """Run SAM on an uploaded image and overlay the predicted mask in red.

    Args:
        input_image: The uploaded image as a NumPy array, or ``None`` when
            nothing was uploaded.
        segment_anything: When truthy, run the model without a prompt and
            merge every predicted mask. Otherwise prompt the model with the
            image centre and keep only the first predicted mask.

    Returns:
        A ``(image, status)`` tuple. ``image`` is the overlay as a NumPy
        array, or ``None`` when no image was supplied; ``status`` is a
        human-readable message.
    """
    if input_image is None:
        return None, "Please upload an image before submitting."

    # Force three channels so the red overlay below always has an RGB target,
    # even if the caller hands over a grayscale or RGBA array.
    pil_image = Image.fromarray(input_image).convert("RGB")
    width, height = pil_image.size  # PIL reports (width, height)

    prompt_kwargs = {}
    if not segment_anything:
        # Single point prompt at the image centre, nested as
        # [image][point][x, y].
        prompt_kwargs["input_points"] = [[[width // 2, height // 2]]]
    inputs = processor(pil_image, return_tensors="pt", **prompt_kwargs).to(device)

    # Inference only: no gradients needed.
    with torch.no_grad():
        outputs = model(**inputs)

    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )

    # masks[0] carries leading (point batch, candidate) axes in front of the
    # spatial ones; flatten them so each entry is a single 2-D mask.
    mask_stack = masks[0].numpy() > 0.5
    mask_stack = mask_stack.reshape(-1, *mask_stack.shape[-2:])
    if segment_anything:
        # Union of every predicted mask.
        combined_mask = mask_stack.any(axis=0)
    else:
        # Keep only the first candidate mask.
        combined_mask = mask_stack[0]

    # post_process_masks is expected to return masks at the original image
    # size already; resize only as a fallback. cv2.resize takes dsize as
    # (width, height), and nearest-neighbour keeps the mask strictly binary.
    if combined_mask.shape != (height, width):
        combined_mask = cv2.resize(
            combined_mask.astype(np.uint8),
            (width, height),
            interpolation=cv2.INTER_NEAREST,
        ) > 0

    # Blend a half-opacity red layer over the masked pixels.
    result_image = np.array(pil_image)
    mask_rgb = np.zeros_like(result_image)
    mask_rgb[combined_mask] = [255, 0, 0]
    result_image = cv2.addWeighted(result_image, 1, mask_rgb, 0.5, 0)
    return result_image, "Segmentation completed successfully."
# Build the Gradio UI: an image upload plus a "Segment Everything" checkbox
# feed segment_image, whose two return values fill the image and status
# outputs.
_input_components = [
    gr.Image(type="numpy", label="Upload an image"),
    gr.Checkbox(label="Segment Everything"),
]
_output_components = [
    gr.Image(type="numpy", label="Segmented Image"),
    gr.Textbox(label="Status"),
]
iface = gr.Interface(
    fn=segment_image,
    inputs=_input_components,
    outputs=_output_components,
    title="Segment Anything Model (SAM) Image Segmentation",
    description="Upload an image and choose whether to segment everything or use a center point.",
)

# Start serving the interface.
iface.launch()