import gradio as gr
import torch
import cv2
import numpy as np
from transformers import SamModel, SamProcessor
from PIL import Image

# Set up device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load model and processor
model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")


def segment_image(input_image, segment_anything):
    try:
        if input_image is None:
            return None, "Please upload an image before submitting."

        # Convert the input array to a PIL Image and ensure it is RGB
        input_image = Image.fromarray(input_image).convert("RGB")

        # Store the original size as (width, height)
        original_size = input_image.size
        if not original_size or 0 in original_size:
            return None, "Invalid image size. Please upload a different image."

        # Prepare model inputs. Note: without input_points, SAM returns its
        # promptless default masks rather than a true automatic segmentation
        # of every object (see the pipeline sketch at the end of this file).
        if segment_anything:
            inputs = processor(input_image, return_tensors="pt").to(device)
        else:
            width, height = original_size
            center_point = [[width // 2, height // 2]]
            inputs = processor(input_image, input_points=[center_point],
                               return_tensors="pt").to(device)

        # Generate masks
        with torch.no_grad():
            outputs = model(**inputs)

        # Upscale the low-resolution predicted masks back to the original image size
        masks = processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu()
        )

        # masks[0] has shape (point_batch, num_mask_candidates, H, W)
        if segment_anything:
            # Union of all candidate masks, collapsed to a single 2D mask
            combined_mask = np.any(masks[0].numpy() > 0.5, axis=(0, 1))
        else:
            # Keep only the candidate mask with the highest predicted IoU score
            best_idx = outputs.iou_scores[0][0].argmax().item()
            combined_mask = masks[0][0][best_idx].numpy() > 0.5

        # Defensive guard: drop any leftover singleton dimensions
        if combined_mask.ndim > 2:
            combined_mask = combined_mask.squeeze()

        # Resize the mask to the original image size using PIL
        # (nearest-neighbour keeps the mask binary)
        mask_image = Image.fromarray((combined_mask * 255).astype(np.uint8))
        mask_image = mask_image.resize(original_size, Image.NEAREST)
        combined_mask = np.array(mask_image) > 0

        # Overlay the mask on the original image
        result_image = np.array(input_image)
        mask_rgb = np.zeros_like(result_image)
        mask_rgb[combined_mask] = [255, 0, 0]  # Red (RGB) for the masked region
        result_image = cv2.addWeighted(result_image, 1, mask_rgb, 0.5, 0)

        return result_image, "Segmentation completed successfully."
    except Exception as e:
        return None, f"An error occurred: {str(e)}"


# Create Gradio interface
iface = gr.Interface(
    fn=segment_image,
    inputs=[
        gr.Image(type="numpy", label="Upload an image"),
        gr.Checkbox(label="Segment Everything")
    ],
    outputs=[
        gr.Image(type="numpy", label="Segmented Image"),
        gr.Textbox(label="Status")
    ],
    title="Segment Anything Model (SAM) Image Segmentation",
    description="Upload an image and choose whether to segment everything or use a center point."
)

# Launch the interface
iface.launch()
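
# --------------------------------------------------------------------------
# Optional: a minimal sketch of true "segment everything" behaviour, using
# the transformers "mask-generation" pipeline, which prompts SAM with a grid
# of points and returns one mask per detected object. This is illustrative
# only and is not wired into the Gradio app above (iface.launch() blocks, so
# run this in a separate script or before launching). The function name and
# points_per_batch value are assumptions; adjust to taste.

from transformers import pipeline


def segment_everything_sketch(image: Image.Image) -> np.ndarray:
    """Return a single 2D boolean mask covering every object SAM finds."""
    generator = pipeline(
        "mask-generation",
        model="facebook/sam-vit-base",
        device=0 if torch.cuda.is_available() else -1,
    )
    # points_per_batch trades GPU memory for speed (assumed value)
    outputs = generator(image, points_per_batch=32)
    # outputs["masks"] is a list of 2D boolean arrays, one per object;
    # union them into a single mask for display
    return np.any(np.stack(outputs["masks"]), axis=0)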