# AutoSeg / app.py
# taher30's picture
# New App
# 429f61d verified
# raw
# history blame
# 1.58 kB
import gradio as gr
import cv2
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
import segmentation_models_pytorch as smp
def load_model(model_type):
    """Build a SAM automatic mask generator for the requested model type.

    Args:
        model_type: Key looked up in ``sam_model_registry``. The checkpoint
            is read from the relative path
            ``sam_<model_type>_checkpoint.pth``.

    Returns:
        A ``SamAutomaticMaskGenerator`` wrapping the loaded model.
    """
    # Model loading simplified for clarity
    model = sam_model_registry[model_type](checkpoint=f"sam_{model_type}_checkpoint.pth")
    # Use the GPU when one is present, but fall back to CPU so the app still
    # works on hosts without CUDA instead of crashing in ``model.to``.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device=device)
    return SamAutomaticMaskGenerator(model)
def segment_and_classify(image, model_type):
    """Segment ``image`` with SAM and return the first segment for display.

    Args:
        image: Input PIL image (the UI passes images with ``type="pil"``).
        model_type: Model key forwarded to ``load_model``.

    Returns:
        A PIL image equal to the input with every pixel outside the first
        generated mask set to zero.

    Raises:
        ValueError: If no image was supplied or no masks were generated.
    """
    # Validate before paying for the model load.
    if image is None:
        raise ValueError("No image provided.")

    model = load_model(model_type)

    # Stay in RGB end to end: the mask generator is fed an RGB array and the
    # result is handed back to PIL, which also interprets arrays as RGB.
    # Converting to BGR here swapped red and blue in the returned image.
    image_rgb = np.array(image.convert("RGB"))

    # Generate masks
    masks = model.generate(image_rgb)
    if not masks:
        raise ValueError("No segments were found in the image.")

    # Apply each mask to the image; the trailing axis added to the mask lets
    # it broadcast across the three colour channels.
    segments = [
        image_rgb * mask_data['segmentation'][:, :, None]
        for mask_data in masks
    ]

    # Here you would call the classification model (e.g., CLIP)
    # For now, let's just return the first segment for visualization
    return Image.fromarray(segments[0])
# Web UI: one image upload plus a model-type selector in, one image out.
# Components are taken from the top-level ``gr`` namespace because the
# ``gr.inputs`` / ``gr.outputs`` namespaces are deprecated and no longer
# exist in current Gradio releases.
iface = gr.Interface(
    fn=segment_and_classify,
    inputs=[
        gr.Image(type="pil"),
        # Explicit default so ``model_type`` is never None when the user
        # submits without touching the dropdown.
        gr.Dropdown(['vit_h', 'vit_b', 'vit_l'], value='vit_h', label="Model Type"),
    ],
    outputs=gr.Image(type="pil"),
    title="SAM Model Segmentation and Classification",
    description="Upload an image, select a model type, and receive the segmented and classified parts.",
)
iface.launch()