"""Gradio demo that overlays Segment Anything (SAM) masks on an uploaded image."""

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import SamModel, SamProcessor

# Run on the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load model and processor once at import time so every request reuses them.
model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")


def _overlay_mask(image, mask, color=(255, 0, 0), alpha=0.5):
    """Return *image* with *color* added (weight *alpha*) wherever *mask* is True.

    Args:
        image: (H, W, 3) uint8 RGB array.
        mask: (H, W) boolean array selecting the pixels to tint.
    """
    overlay = np.zeros_like(image)
    overlay[mask] = color
    # addWeighted saturates at 255, so masked pixels are pushed toward `color`
    # while unmasked pixels are left unchanged (overlay is zero there).
    return cv2.addWeighted(image, 1, overlay, alpha, 0)


def segment_image(input_image, segment_anything):
    """Segment *input_image* with SAM and return it with the mask overlaid in red.

    Args:
        input_image: Image array from the Gradio ``Image(type="numpy")`` input,
            or ``None`` when the user submitted without uploading an image.
        segment_anything: If true, run SAM without any prompt and take the union
            of every mask it predicts. Otherwise prompt SAM with a single point at
            the image center and keep the candidate mask with the highest
            predicted IoU score.

    Returns:
        (H, W, 3) uint8 RGB array with the selected mask blended in red.

    Raises:
        gr.Error: If no image was provided.
    """
    if input_image is None:
        raise gr.Error("Please upload an image.")

    # Normalize grayscale / RGBA uploads to 3-channel RGB so the overlay below
    # always has a shape matching the image.
    image = Image.fromarray(input_image).convert("RGB")

    if segment_anything:
        # NOTE(review): with no prompt SAM only returns a handful of masks for the
        # whole image; true automatic "segment everything" needs a grid of point
        # prompts (e.g. the transformers "mask-generation" pipeline).
        inputs = processor(image, return_tensors="pt").to(device)
    else:
        # PIL reports size as (width, height); SAM point prompts are (x, y).
        width, height = image.size
        center_point = [[width // 2, height // 2]]
        inputs = processor(
            image, input_points=[center_point], return_tensors="pt"
        ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # One entry per input image, upscaled back to the original resolution.
    # Expected shape per entry: (point_batch, num_masks, H, W) — see the
    # transformers SamImageProcessor.post_process_masks docs.
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    image_masks = masks[0].numpy() > 0.5

    if segment_anything:
        # Flatten every (point, candidate) mask into one stack and union them,
        # leaving a single (H, W) mask.
        combined_mask = image_masks.reshape(-1, *image_masks.shape[-2:]).any(axis=0)
    else:
        # SAM returns several candidate masks per prompt; keep the one the model
        # scores highest rather than an arbitrary first one.
        best = int(outputs.iou_scores[0, 0].cpu().argmax())
        combined_mask = image_masks[0, best]

    return _overlay_mask(np.array(image), combined_mask)


# Create Gradio interface
iface = gr.Interface(
    fn=segment_image,
    inputs=[
        gr.Image(type="numpy"),
        gr.Checkbox(label="Segment Everything"),
    ],
    outputs=gr.Image(type="numpy"),
    title="Segment Anything Model (SAM) Image Segmentation",
    description="Upload an image and choose whether to segment everything or use a center point.",
)

if __name__ == "__main__":
    iface.launch()