"""Gradio demo: segment a user-named object in an uploaded image with FastSAM."""

import io
import logging

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from ultralytics import FastSAM
from ultralytics.models.fastsam import FastSAMPrompt

logger = logging.getLogger(__name__)

# Set up device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load FastSAM model
model = FastSAM("FastSAM-s.pt")  # or FastSAM-x.pt

# Structuring element shared by every mask clean-up pass in segment_image.
_REFINE_KERNEL = np.ones((5, 5), np.uint8)


def fig2img(fig):
    """Render a Matplotlib figure to an in-memory buffer and open it with PIL.

    Args:
        fig: Matplotlib figure to serialise with ``fig.savefig``.

    Returns:
        A ``PIL.Image.Image`` backed by the in-memory buffer.
    """
    buf = io.BytesIO()
    fig.savefig(buf)
    buf.seek(0)
    img = Image.open(buf)
    return img


def _masks_to_uint8(masks, height, width):
    """Return masks as a list of ``height`` x ``width`` uint8 arrays (0/1)."""
    if isinstance(masks, torch.Tensor):
        masks = masks.cpu().numpy()
    prepared = []
    for mask in np.asarray(masks):
        mask = (mask > 0).astype(np.uint8)
        if mask.shape[:2] != (height, width):
            # Masks produced without retina_masks may be at model resolution.
            mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
        prepared.append(mask)
    return prepared


def _mask_overlay(masks, height, width, random_color):
    """Build an RGBA float overlay that tints every mask region."""
    overlay = np.zeros((height, width, 4), dtype=np.float32)
    # Paint the largest masks first so smaller ones stay visible on top.
    for mask in sorted(masks, key=lambda m: int(m.sum()), reverse=True):
        if random_color:
            rgb = np.random.random(3)
        else:
            rgb = np.array([30 / 255, 144 / 255, 1.0])
        overlay[mask > 0] = (rgb[0], rgb[1], rgb[2], 0.6)
    return overlay


def _contour_overlay(masks, height, width):
    """Build an RGBA float overlay that outlines every mask in blue."""
    canvas = np.zeros((height, width), dtype=np.uint8)
    for mask in masks:
        # [-2] selects the contour list under both OpenCV 3.x and 4.x
        # return conventions.
        contours = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
        cv2.drawContours(canvas, contours, -1, 255, 2)
    overlay = np.zeros((height, width, 4), dtype=np.float32)
    overlay[canvas > 0] = (0.0, 0.0, 1.0, 0.8)
    return overlay


def plot(
    annotations,
    prompt_process,
    mask_random_color=True,
    better_quality=True,
    retina=True,
    with_contours=True,
):
    """Draw segmentation masks over the source image and return a PIL image.

    NOTE(review): the body of this function was missing from the file (only a
    "keep the existing plot function as is" placeholder remained). This is a
    reconstruction modelled on the upstream FastSAM plotting logic -- confirm
    against the previous revision before relying on exact visual output.

    Args:
        annotations: Sequence of result objects exposing ``orig_img`` and
            ``masks`` (with ``masks.data``); the first one supplies the
            background image.
        prompt_process: Accepted for interface compatibility; not used here.
        mask_random_color: Tint each mask with a random colour instead of a
            fixed blue.
        better_quality: Apply a morphological close/open pass to each mask
            before drawing.
        retina: Accepted for interface compatibility; masks are resized to
            the image size whenever their shape differs, regardless of this
            flag.
        with_contours: Outline each mask.

    Returns:
        A ``PIL.Image.Image`` of the rendered figure.

    Raises:
        ValueError: If ``annotations`` is empty.
    """
    if not annotations:
        raise ValueError("plot() needs at least one annotation to draw.")

    # Assumes orig_img is a BGR array as stored by ultralytics -- TODO confirm.
    image = annotations[0].orig_img[..., ::-1]
    height, width = image.shape[:2]

    # 100 dpi with a (w/100, h/100) inch canvas keeps the output at w x h px.
    fig = plt.figure(figsize=(width / 100, height / 100), dpi=100)
    try:
        fig.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        ax = fig.gca()
        ax.axis("off")
        ax.imshow(image)

        for ann in annotations:
            if ann.masks is None:
                continue
            masks = _masks_to_uint8(ann.masks.data, height, width)
            if not masks:
                continue
            if better_quality:
                masks = [
                    cv2.morphologyEx(
                        cv2.morphologyEx(m, cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8)),
                        cv2.MORPH_OPEN,
                        np.ones((8, 8), np.uint8),
                    )
                    for m in masks
                ]
            ax.imshow(_mask_overlay(masks, height, width, mask_random_color))
            if with_contours:
                ax.imshow(_contour_overlay(masks, height, width))

        return fig2img(fig)
    finally:
        # The rendered bytes already live in fig2img's buffer; release the
        # figure so repeated requests do not accumulate open figures.
        plt.close(fig)


def _refine_masks(masks):
    """Close, open and dilate each mask in place; return the same array."""
    for i, mask in enumerate(masks):
        # Apply more aggressive morphological operations
        mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, _REFINE_KERNEL)
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, _REFINE_KERNEL)
        masks[i] = cv2.dilate(mask, _REFINE_KERNEL, iterations=2)
    return masks


def segment_image(input_image, object_name):
    """Segment the object named by ``object_name`` in ``input_image``.

    Args:
        input_image: Image as a NumPy array (as delivered by the Gradio image
            input), or ``None`` when nothing was uploaded.
        object_name: Text prompt naming the object to segment.

    Returns:
        A ``(image, status)`` tuple. ``image`` is the rendered segmentation,
        the unmodified input when nothing could be segmented, or ``None``
        when no image was supplied or an error occurred. ``status`` is a
        human-readable message for the UI.
    """
    try:
        if input_image is None:
            return None, "Please upload an image before submitting."

        input_image = Image.fromarray(input_image).convert("RGB")

        object_name = (object_name or "").strip()
        if not object_name:
            return input_image, "Please specify an object to segment."

        # Run FastSAM model with adjusted parameters
        everything_results = model(
            input_image, retina_masks=True, imgsz=1024, conf=0.25, iou=0.7
        )

        # Prepare a Prompt Process object
        prompt_process = FastSAMPrompt(input_image, everything_results, device=device)

        # Use text prompt to segment the specified object
        results = prompt_process.text_prompt(text=object_name)

        if not results:
            return input_image, f"Could not find '{object_name}' in the image."

        # Post-process the masks
        for ann in results:
            if ann.masks is None:
                continue
            masks = ann.masks.data
            if isinstance(masks, torch.Tensor):
                masks = np.array(masks.cpu())
            if len(masks) == 0:
                # Nothing to refine; avoids indexing into an empty mask set.
                continue
            ann.masks.data = _refine_masks(masks)

        # Plot the results
        result_image = plot(annotations=results, prompt_process=prompt_process)

        return result_image, f"Segmented '{object_name}' in the image."

    except Exception as e:
        # UI boundary: report the failure to the user, but keep the traceback
        # in the logs instead of discarding it.
        logger.exception("Segmentation failed")
        return None, f"An error occurred: {str(e)}"


# Create Gradio interface
iface = gr.Interface(
    fn=segment_image,
    inputs=[
        gr.Image(type="numpy", label="Upload an image"),
        gr.Textbox(label="Specify object to segment (e.g., dog, cat, grass)"),
    ],
    outputs=[
        gr.Image(type="pil", label="Segmented Image"),
        gr.Textbox(label="Status"),
    ],
    title="FastSAM Segmentation with Object Specification",
    description="Upload an image and specify an object to segment using FastSAM.",
)

# Launch the interface
iface.launch()