BioMike commited on
Commit
1bfa982
·
verified ·
1 Parent(s): 4875d48

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -0
app.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ import matplotlib.pyplot as plt
6
+ from src.model import ClipSegMultiClassModel
7
+ from src.config import ClipSegMultiClassConfig
8
+
9
+ # === Load model ===
10
+ class_labels = ["background", "Pig", "Horse", "Sheep"]
11
+ label2color = {
12
+ 0: [0, 0, 0],
13
+ 1: [255, 0, 0],
14
+ 2: [0, 255, 0],
15
+ 3: [0, 0, 255],
16
+ }
17
+
18
+ config = ClipSegMultiClassConfig(
19
+ class_labels=class_labels,
20
+ label2color=label2color,
21
+ model="CIDAS/clipseg-rd64-refined",
22
+ )
23
+
24
+ model = ClipSegMultiClassModel.from_pretrained("BioMike/clipsegmulticlass_v1")
25
+
26
+ model.eval()
27
+
28
+ def colorize_mask(mask_tensor, label2color):
29
+ mask = mask_tensor.squeeze().cpu().numpy()
30
+ h, w = mask.shape
31
+ color_mask = np.zeros((h, w, 3), dtype=np.uint8)
32
+ for class_id, color in label2color.items():
33
+ color_mask[mask == class_id] = color
34
+ return color_mask
35
+
36
+ def segment_with_legend(input_img):
37
+ if isinstance(input_img, str):
38
+ input_img = Image.open(input_img).convert("RGB")
39
+ elif isinstance(input_img, np.ndarray):
40
+ input_img = Image.fromarray(input_img).convert("RGB")
41
+
42
+ pred_mask = model.predict(input_img)
43
+ color_mask = colorize_mask(pred_mask, label2color)
44
+ overlay = Image.blend(input_img.resize((color_mask.shape[1], color_mask.shape[0])), Image.fromarray(color_mask), alpha=0.5)
45
+
46
+ fig, ax = plt.subplots(figsize=(8, 6))
47
+ ax.imshow(overlay)
48
+ ax.axis("off")
49
+
50
+ legend_patches = [
51
+ plt.Line2D([0], [0], marker='o', color='w',
52
+ label=label,
53
+ markerfacecolor=np.array(color) / 255.0,
54
+ markersize=10)
55
+ for label, color in zip(class_labels, label2color.values())
56
+ ]
57
+ ax.legend(handles=legend_patches, loc='lower right', framealpha=0.8)
58
+
59
+ return fig