eksemyashkina committed (verified)
Commit f2c8754 · 1 Parent(s): b108d0c

Upload 6 files

Files changed (6)
  1. app.py +121 -0
  2. losses_config.json +14 -0
  3. requirements.txt +11 -0
  4. weights/dino.pth +3 -0
  5. weights/unet.pth +3 -0
  6. weights/vit.pth +3 -0
app.py ADDED
@@ -0,0 +1,121 @@
+ import gradio as gr
+ import PIL.Image
+ import torch
+ from matplotlib import colormaps
+ from matplotlib.colors import to_hex
+ import numpy as np
+
+ from src.models.dino import DINOSegmentationModel
+ from src.models.vit import ViTSegmentation
+ from src.models.unet import UNet
+ from src.utils import get_transform
+
+
+ device = torch.device("cpu")
+ model_weight1 = "weights/dino.pth"
+ model_weight2 = "weights/vit.pth"
+ model_weight3 = "weights/unet.pth"
+
+ # DINO and ViT checkpoints hold only the segmentation heads;
+ # the UNet checkpoint holds the full model state dict.
+ model1 = DINOSegmentationModel()
+ model1.segmentation_head.load_state_dict(torch.load(model_weight1, map_location=device))
+ model1.eval()
+ model2 = ViTSegmentation()
+ model2.segmentation_head.load_state_dict(torch.load(model_weight2, map_location=device))
+ model2.eval()
+ model3 = UNet()
+ model3.load_state_dict(torch.load(model_weight3, map_location=device))
+ model3.eval()
+
+ mask_labels = {
+     "0": "Background", "1": "Person", "2": "Skin", "3": "Left-brow", "4": "Right-brow",
+     "5": "Left-eye", "6": "Right-eye", "7": "Lips", "8": "Teeth"
+ }
+
+ # cm.get_cmap was removed in matplotlib 3.9; with the pinned
+ # matplotlib==3.10.0, use the colormaps registry instead.
+ color_map = colormaps["tab20"].resampled(9)
+ label_colors = {label: to_hex(color_map(idx / len(mask_labels))[:3]) for idx, label in enumerate(mask_labels)}
+ fixed_colors = np.array([color_map(i)[:3] for i in range(9)]) * 255
+
+
+ def mask_to_color(mask: np.ndarray) -> np.ndarray:
+     """Map each class index in the mask to its fixed RGB color."""
+     h, w = mask.shape
+     color_mask = np.zeros((h, w, 3), dtype=np.uint8)
+     for class_idx in range(9):
+         color_mask[mask == class_idx] = fixed_colors[class_idx]
+     return color_mask
+
+
+ def segment_image(image, model_name: str) -> PIL.Image.Image:
+     if model_name == "DINO":
+         model = model1
+     elif model_name == "ViT":
+         model = model2
+     else:
+         model = model3
+
+     original_width, original_height = image.size
+     transform = get_transform(model.mean, model.std)
+     input_tensor = transform(image).unsqueeze(0)
+
+     with torch.no_grad():
+         mask = model(input_tensor)
+         mask = torch.argmax(mask.squeeze(), dim=0).cpu().numpy()
+
+     mask_image = mask_to_color(mask)
+
+     mask_image = PIL.Image.fromarray(mask_image)
+     mask_aspect_ratio = mask_image.width / mask_image.height
+
+     # Resize the predicted mask back to the input height, then center it
+     # on a canvas matching the original image dimensions.
+     new_height = original_height
+     new_width = int(new_height * mask_aspect_ratio)
+     mask_image = mask_image.resize((new_width, new_height), PIL.Image.Resampling.NEAREST)
+
+     final_mask = PIL.Image.new("RGB", (original_width, original_height))
+     offset = ((original_width - new_width) // 2, 0)
+     final_mask.paste(mask_image, offset)
+
+     return final_mask
+
+
+ def generate_legend_html_compact() -> str:
+     """Render a color swatch for each class label as compact HTML."""
+     legend_html = """
+     <div style='display: flex; flex-wrap: wrap; gap: 10px; justify-content: center;'>
+     """
+     for label, color in label_colors.items():
+         legend_html += f"""
+         <div style='display: flex; align-items: center; justify-content: center;
+                     padding: 5px 10px; border: 1px solid {color};
+                     background-color: {color}; border-radius: 5px;
+                     color: white; font-size: 12px; text-align: center;'>
+             {mask_labels[label]}
+         </div>
+         """
+     legend_html += "</div>"
+     return legend_html
+
+
+ examples = [
+     ["assets/images_examples/image1.jpg"],
+     ["assets/images_examples/image2.jpg"],
+     ["assets/images_examples/image3.jpg"]
+ ]
+
+ with gr.Blocks() as demo:
+     gr.Markdown("## Face Segmentation")
+     with gr.Row():
+         with gr.Column():
+             pic = gr.Image(label="Upload Human Image", type="pil", height=400, width=400)
+             model_choice = gr.Dropdown(choices=["DINO", "ViT", "UNet"], label="Select Model", value="DINO")
+             with gr.Row():
+                 with gr.Column(scale=1):
+                     predict_btn = gr.Button("Predict")
+                 with gr.Column(scale=1):
+                     clear_btn = gr.Button("Clear")
+
+         with gr.Column():
+             output = gr.Image(label="Mask", type="pil", height=400, width=400)
+             legend = gr.HTML(label="Legend", value=generate_legend_html_compact())
+
+     predict_btn.click(fn=segment_image, inputs=[pic, model_choice], outputs=output, api_name="predict")
+     clear_btn.click(lambda: (None, None), outputs=[pic, output])
+     gr.Examples(examples=examples, inputs=[pic])
+
+ demo.launch()
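
The Predict button is registered with api_name="predict", so the endpoint can also be called programmatically. Below is a minimal client-side sketch, assuming a local demo.launch() at the default address; the URL and example path are placeholders:

    # Call the /predict endpoint exposed by api_name="predict".
    from gradio_client import Client, handle_file

    client = Client("http://127.0.0.1:7860")  # placeholder address
    result = client.predict(
        handle_file("assets/images_examples/image1.jpg"),  # image input
        "DINO",                                            # model choice
        api_name="/predict",
    )
    print(result)  # local path of the rendered mask image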
losses_config.json ADDED
@@ -0,0 +1,14 @@
+ {
+     "cross_entropy": 0.33,
+     "SSLoss_v2": 0.0,
+     "ExpLog_loss": 0.0,
+     "LovaszSoftmax": 0.0,
+     "TopKLoss": 0.33,
+     "WeightedCrossEntropyLoss": 0.0,
+     "SoftDiceLoss_v2": 0.0,
+     "IoULoss_v2": 0.0,
+     "TverskyLoss_v2": 0.0,
+     "FocalTversky_loss_v2": 0.0,
+     "AsymLoss_v2": 0.0,
+     "FocalLoss": 0.33
+ }
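
The config assigns a mixing weight to each candidate segmentation loss; here the total is split evenly (0.33 each) across cross-entropy, top-k, and focal terms, with the rest disabled. A minimal sketch of how a training loop might consume such a file follows; the registry is a hypothetical stand-in, not the repo's actual loss modules:

    import json
    import torch.nn as nn

    # Hypothetical mapping from config keys to loss callables; the real
    # project defines these losses in its own training code.
    LOSS_REGISTRY = {"cross_entropy": nn.CrossEntropyLoss()}

    def build_combined_loss(config_path="losses_config.json"):
        with open(config_path) as f:
            weights = json.load(f)
        # Keep only losses with a positive weight and a known implementation.
        active = [(LOSS_REGISTRY[name], w) for name, w in weights.items()
                  if w > 0 and name in LOSS_REGISTRY]

        def combined(logits, target):
            # Weighted sum over the active loss terms.
            return sum(w * fn(logits, target) for fn, w in active)

        return combined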
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ torch==2.4.1
+ torchvision==0.19.1
+ kaggle==1.6.17
+ wandb==0.18.5
+ gradio==5.4.0
+ datasets==3.1.0
+ accelerate==1.1.0
+ opencv-python==4.10.0.84
+ scipy==1.14.1
+ transformers==4.46.2
+ matplotlib==3.10.0
weights/dino.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f7bbb568a7aaa755f68da7cbeb493cc1a7c002e1659ccda808564a8a1bd075fb
+ size 8269400
weights/unet.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7c0699b44ac999ebe77a236bd20a33a71f2507271aa2de2d1559f305ecf27c8e
+ size 200940690
weights/vit.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2539e8dbf9029370ba143e39ac9355215dfa613876b0c1ea9c09d03ae5e1ec09
+ size 11808344
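
The three .pth entries above are Git LFS pointer files (spec version, SHA-256 OID, byte size) rather than the binaries themselves; the actual checkpoints are resolved at checkout or download time. A minimal sketch of fetching a resolved weight file with huggingface_hub, where the repo_id is a placeholder:

    from huggingface_hub import hf_hub_download

    path = hf_hub_download(
        repo_id="user/face-segmentation",  # placeholder repo id
        filename="weights/dino.pth",
        repo_type="space",
    )
    print(path)  # local cache path of the ~8.3 MB DINO head checkpoint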