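# Gradio demo: run Segment Anything (SAM) automatic mask generation on an
# uploaded image and return one masked segment of it as the output.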
import gradio as gr
import cv2
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
import segmentation_models_pytorch as smp
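

# Build the SAM backbone selected in the UI from a local checkpoint named
# sam_<model_type>_checkpoint.pth and wrap it in the automatic mask generator.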
def load_model(model_type):
    model = sam_model_registry[model_type](checkpoint=f"sam_{model_type}_checkpoint.pth")
    # Use the GPU when available, otherwise fall back to CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device=device)
    return SamAutomaticMaskGenerator(model)
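

# Gradio callback: segment the uploaded image with the selected SAM backbone
# and return one of the masked-out regions.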
def segment_and_classify(image, model_type):
    model = load_model(model_type)

    # SAM's automatic mask generator expects an HWC uint8 array in RGB order,
    # so the PIL input is used directly instead of being converted to BGR.
    image_rgb = np.array(image.convert("RGB"))

    masks = model.generate(image_rgb)

    segments = []
    for mask_data in masks:
        mask = mask_data['segmentation']
        # Zero out every pixel that falls outside the current mask.
        segment = image_rgb * np.tile(mask[:, :, None], [1, 1, 3])
        segments.append(segment)

    # Return the original image if no masks were found, otherwise the first segment.
    if not segments:
        return image
    return Image.fromarray(segments[0])
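

# Wire the callback into a simple Gradio UI: an image upload, a dropdown for
# choosing the SAM backbone, and an image output.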
iface = gr.Interface(
    fn=segment_and_classify,
    inputs=[gr.Image(type="pil"), gr.Dropdown(['vit_h', 'vit_b', 'vit_l'], label="Model Type")],
    outputs=gr.Image(type="pil"),
    title="SAM Model Segmentation and Classification",
    description="Upload an image, select a model type, and receive the segmented and classified parts."
)

iface.launch()
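
# Note: the checkpoint filenames above (sam_vit_h_checkpoint.pth, etc.) are the
# names this script expects; rename the downloaded SAM checkpoints to match, or
# point the f-string in load_model at the actual paths, before launching.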