taher30 committed on
Commit
429f61d
·
verified ·
1 Parent(s): b7b5866
Files changed (1) hide show
  1. app.py +44 -0
app.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import cv2
3
+ import torch
4
+ import numpy as np
5
+ from PIL import Image
6
+ from torchvision import transforms
7
+ from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
8
+ import segmentation_models_pytorch as smp
9
+
10
def load_model(model_type):
    """Load a SAM checkpoint and wrap it in an automatic mask generator.

    Args:
        model_type: SAM registry key ('vit_h', 'vit_l' or 'vit_b'); also
            selects the checkpoint file ``sam_<model_type>_checkpoint.pth``.

    Returns:
        A SamAutomaticMaskGenerator built on the loaded model.

    Raises:
        KeyError: if ``model_type`` is not a key of ``sam_model_registry``.
    """
    model = sam_model_registry[model_type](checkpoint=f"sam_{model_type}_checkpoint.pth")
    # Fall back to CPU when CUDA is unavailable instead of crashing on
    # a hard-coded .to(device='cuda') call on GPU-less hosts.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device=device)
    return SamAutomaticMaskGenerator(model)
15
+
16
def segment_and_classify(image, model_type):
    """Segment an image with SAM and return the first segment for display.

    Args:
        image: input PIL image (RGB, as delivered by the Gradio component).
        model_type: SAM registry key, forwarded to ``load_model``.

    Returns:
        A PIL image of the first generated segment with the background
        zeroed out.

    Raises:
        ValueError: if SAM produces no masks for the image.
    """
    model = load_model(model_type)
    # OpenCV works in BGR channel order, so convert the RGB PIL input.
    image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    # Generate masks
    masks = model.generate(image_cv)
    if not masks:
        # Guard against an IndexError on segments[0] below when SAM
        # finds nothing to segment.
        raise ValueError("No segments were found in the image.")

    # Apply each boolean mask to the image, zeroing the background,
    # and keep every segment for later classification.
    segments = []
    for mask_data in masks:
        mask = mask_data['segmentation']
        segment = image_cv * np.tile(mask[:, :, None], [1, 1, 3])
        segments.append(segment)

    # TODO: call the classification model (e.g. CLIP) on each segment.
    # For now return only the first segment for visualization.
    # image_cv is BGR; PIL expects RGB, so convert back before wrapping —
    # otherwise the displayed colors are channel-swapped.
    return Image.fromarray(cv2.cvtColor(segments[0], cv2.COLOR_BGR2RGB))
35
+
36
# Build and launch the Gradio UI.
# NOTE: the gr.inputs / gr.outputs namespaces are deprecated in Gradio 3.x
# and removed in 4.x — use the top-level components directly.
iface = gr.Interface(
    fn=segment_and_classify,
    inputs=[
        gr.Image(type="pil"),
        gr.Dropdown(['vit_h', 'vit_b', 'vit_l'], label="Model Type"),
    ],
    outputs=gr.Image(type="pil"),
    title="SAM Model Segmentation and Classification",
    description="Upload an image, select a model type, and receive the segmented and classified parts.",
)

iface.launch()