import functools

import gradio as gr
import cv2
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
import segmentation_models_pytorch as smp

# SAM backbones offered in the UI.
MODEL_TYPES = ["vit_h", "vit_b", "vit_l"]


@functools.lru_cache(maxsize=1)
def load_model(model_type):
    """Build a SAM automatic mask generator for ``model_type``.

    The result is cached so the (large) checkpoint is read from disk and moved
    to the device once, not on every request. ``maxsize=1`` keeps only the most
    recently used backbone cached; choosing another model type evicts it.

    Args:
        model_type: Key into ``sam_model_registry`` (e.g. ``"vit_b"``).

    Returns:
        A ``SamAutomaticMaskGenerator`` wrapping the loaded model.
    """
    # Model loading simplified for clarity. NOTE(review): assumes checkpoints
    # are named ``sam_<model_type>_checkpoint.pth`` in the working directory.
    model = sam_model_registry[model_type](
        checkpoint=f"sam_{model_type}_checkpoint.pth"
    )
    # Fall back to CPU instead of crashing on machines without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device=device)
    return SamAutomaticMaskGenerator(model)


def _apply_mask(image_rgb, mask):
    """Return a copy of ``image_rgb`` with every pixel outside ``mask`` set to 0."""
    # Broadcast the (H, W) mask across the channel axis of the (H, W, 3) image.
    return np.where(mask[:, :, None], image_rgb, np.uint8(0))


def segment_and_classify(image, model_type):
    """Segment ``image`` with SAM and return one masked segment for display.

    Args:
        image: Uploaded PIL image (any mode; it is converted to RGB).
        model_type: SAM backbone key chosen in the dropdown.

    Returns:
        A PIL RGB image containing only the first generated segment.

    Raises:
        gr.Error: If no image is supplied, the model type is unknown, or SAM
            produces no masks. Gradio shows the message to the user instead
            of an opaque traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    if model_type not in sam_model_registry:
        raise gr.Error(f"Unknown model type: {model_type!r}")

    model = load_model(model_type)

    # SamAutomaticMaskGenerator.generate expects an HWC uint8 *RGB* array.
    # Feeding BGR (as before) degraded segmentation and swapped the red/blue
    # channels of the returned image. convert("RGB") also normalises RGBA,
    # greyscale and palette images to three channels.
    image_rgb = np.array(image.convert("RGB"))

    # Generate masks.
    masks = model.generate(image_rgb)
    if not masks:
        raise gr.Error("SAM did not find any segments in this image.")

    # Keep every segment so a classification model (e.g. CLIP) can be
    # applied to them later.
    segments = [_apply_mask(image_rgb, m["segmentation"]) for m in masks]

    # TODO: classify the segments; for now return the first for visualisation.
    return Image.fromarray(segments[0])


iface = gr.Interface(
    fn=segment_and_classify,
    # gr.inputs / gr.outputs were removed in Gradio 4; the top-level
    # components work on Gradio 3.x and 4.x.
    inputs=[
        gr.Image(type="pil"),
        gr.Dropdown(MODEL_TYPES, value="vit_b", label="Model Type"),
    ],
    outputs=gr.Image(type="pil"),
    title="SAM Model Segmentation and Classification",
    description="Upload an image, select a model type, and receive the segmented and classified parts.",
)

iface.launch()