eksemyashkina committed
Commit 61123b8 · verified · 1 Parent(s): 9b6e42b

Upload 4 files
Files changed (4)
  1. README.md +87 -0
  2. app.py +128 -0
  3. losses_config.json +14 -0
  4. requirements.txt +11 -0
README.md ADDED
@@ -0,0 +1,87 @@
+ # Clothes Segmentation
+
+ ![Sample Images and Segmentation Masks from Dataset](assets/dataset_examples.png)
+
+ This project provides a solution for segmenting clothes into 18 categories using DINO, ViT, and UNet models; a brief sketch of the backbone-plus-head idea behind the DINO and ViT variants follows the list below.
+
+ * DINO: Pretrained backbone with a segmentation head
+   * https://arxiv.org/abs/2104.14294
+   * https://huggingface.co/facebook/dinov2-small
+ * ViT: Pretrained vision transformer with a segmentation head
+   * https://arxiv.org/abs/2010.11929
+   * https://huggingface.co/google/vit-base-patch16-224
+ * UNet: Custom implementation
+   * https://arxiv.org/abs/1505.04597
+
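+ A minimal sketch of the backbone-plus-head pattern used by the DINO and ViT variants (illustrative only; the real classes live in src/models/ and are not reproduced here):
+
+ ```python
+ import torch
+ import torch.nn as nn
+ from transformers import AutoModel
+
+
+ class BackboneSegmenter(nn.Module):
+     """Pretrained ViT-style backbone with a 1x1-conv segmentation head (illustrative)."""
+
+     def __init__(self, backbone_name: str = "facebook/dinov2-small", num_classes: int = 18):
+         super().__init__()
+         self.backbone = AutoModel.from_pretrained(backbone_name)
+         for p in self.backbone.parameters():  # keep the pretrained backbone fixed; train only the head
+             p.requires_grad = False
+         self.head = nn.Conv2d(self.backbone.config.hidden_size, num_classes, kernel_size=1)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         b, _, h, w = x.shape
+         tokens = self.backbone(pixel_values=x).last_hidden_state[:, 1:]  # drop the CLS token
+         side = int(tokens.shape[1] ** 0.5)  # patch tokens form a square grid
+         feats = tokens.transpose(1, 2).reshape(b, -1, side, side)
+         return nn.functional.interpolate(self.head(feats), size=(h, w), mode="bilinear", align_corners=False)
+ ```
+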
+ Gradio is used for building the web interface, and Weights & Biases for experiment tracking.
+
+ ## Installation
+
+ 1. Clone the repository:
+ ```bash
+ git clone https://github.com/your-project/clothes-segmentation.git
+ cd clothes-segmentation
+ ```
+
+ 2. Create and activate a virtual environment:
+ ```bash
+ python -m venv venv
+ source venv/bin/activate
+ ```
+
+ 3. Install dependencies:
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ ## Usage
+
+ ### Training the Model
+ To train a model, specify one of the following with the --model argument: **dino**, **vit**, or **unet**.
+ ```bash
+ python src/train.py --model dino
+ python src/train.py --model vit
+ python src/train.py --model unet
+ ```
+
+ You can also adjust other parameters, such as the number of epochs, batch size, and learning rate, by adding additional arguments. For example:
+ ```bash
+ python src/train.py --model unet --num-epochs 20 --batch-size 16 --learning-rate 0.001
+ ```
+
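+ A hypothetical sketch of how these documented flags might be parsed inside src/train.py (the script itself is not shown in this upload, so names and defaults here are assumptions):
+
+ ```python
+ import argparse
+
+
+ def parse_args() -> argparse.Namespace:
+     parser = argparse.ArgumentParser(description="Train a clothes segmentation model")
+     parser.add_argument("--model", choices=["dino", "vit", "unet"], required=True)
+     parser.add_argument("--num-epochs", type=int, default=10)
+     parser.add_argument("--batch-size", type=int, default=16)
+     parser.add_argument("--learning-rate", type=float, default=1e-3)
+     return parser.parse_args()  # values land in args.num_epochs, args.batch_size, ...
+ ```
+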
+ ### Launching the Gradio Interface
+ ```bash
+ python app.py
+ ```
+
+ Once the interface is running, you can select a model, upload an image, and view the segmentation mask.
+
+ ![Web Interface Screen](assets/spaces_screen.jpg)
+
+ #### TODO: add link
+
+ ## Results
+
+ | Model | Test Micro Recall | Test Micro Precision | Test Macro Precision | Test Macro Recall | Test Accuracy | Test Loss | Train Micro Recall | Train Micro Precision | Train Macro Precision | Train Macro Recall | Train Accuracy | Train Loss |
+ |------------|-------------------|----------------------|----------------------|-------------------|---------------|-----------|--------------------|-----------------------|-----------------------|--------------------|----------------|------------|
+ | DINO | 0.94986 | 0.94986 | 0.71364 | 0.67052 | 0.94986 | 0.53124 | 0.97019 | 0.97019 | 0.78185 | 0.72336 | 0.97019 | 0.30441 |
+ | ViT | 0.9358 | 0.9358 | 0.63939 | 0.58365 | 0.9358 | 0.71193 | 0.96734 | 0.96734 | 0.74418 | 0.66295 | 0.96734 | 0.31166 |
+ | UNet | 0.95798 | 0.95798 | 0.76354 | 0.7289 | 0.95798 | 0.56764 | 0.98035 | 0.98035 | 0.82934 | 0.82688 | 0.98035 | 0.25301 |
+
+ The micro-averaged precision, recall, and accuracy columns coincide because every pixel receives exactly one predicted class, so the total counts of false positives and false negatives match; the macro-averaged columns weight all 18 classes equally and therefore better reflect the rarer categories.
+
+ ### Training Results of DINO
+
+ ![DINO_test](assets/dino_test_plots.png)
+
+ ![DINO_train](assets/dino_train_plots.png)
+
+ ### Training Results of ViT
+
+ ![ViT_test](assets/vit_test_plots.png)
+
+ ![ViT_train](assets/vit_train_plots.png)
+
+ ### Training Results of UNet
+
+ ![UNet_test](assets/unet_test_plots.png)
+
+ ![UNet_train](assets/unet_train_plots.png)
app.py ADDED
@@ -0,0 +1,128 @@
+ from typing import Dict
+ import gradio as gr
+ import json
+ import PIL.Image, PIL.ImageOps
+ import torch
+ import torchvision.transforms.functional as F
+ from matplotlib import colormaps
+ from matplotlib.colors import to_hex
+ import numpy as np
+
+ from src.models.dino import DINOSegmentationModel
+ from src.models.vit import ViTSegmentation
+ from src.models.unet import UNet
+ from src.utils import get_transform
+
+
+ device = torch.device("cpu")
+ model_weight1 = "weights/dino.pth"
+ model_weight2 = "weights/vit.pth"
+ model_weight3 = "weights/unet.pth"
+
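+ # Only the segmentation heads are loaded from disk for the DINO and ViT models;
+ # the custom UNet loads a full state dict.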
+ model1 = DINOSegmentationModel()
+ model1.segmentation_head.load_state_dict(torch.load(model_weight1, map_location=device))
+ model1.eval()
+ model2 = ViTSegmentation()
+ model2.segmentation_head.load_state_dict(torch.load(model_weight2, map_location=device))
+ model2.eval()
+ model3 = UNet()
+ model3.load_state_dict(torch.load(model_weight3, map_location=device))
+ model3.eval()
+
+ mask_labels = {
+     "0": "Background", "1": "Hat", "2": "Hair", "3": "Sunglasses", "4": "Upper-clothes",
+     "5": "Skirt", "6": "Pants", "7": "Dress", "8": "Belt", "9": "Right-shoe",
+     "10": "Left-shoe", "11": "Face", "12": "Right-leg", "13": "Left-leg",
+     "14": "Right-arm", "15": "Left-arm", "16": "Bag", "17": "Scarf"
+ }
+
+ color_map = colormaps["tab20"].resampled(18)
+ label_colors = {label: to_hex(color_map(idx / len(mask_labels))[:3]) for idx, label in enumerate(mask_labels)}
+ fixed_colors = np.array([color_map(i)[:3] for i in range(18)]) * 255
+
+
+ def mask_to_color(mask: np.ndarray) -> np.ndarray:
+     h, w = mask.shape
+     color_mask = np.zeros((h, w, 3), dtype=np.uint8)
+     for class_idx in range(18):
+         color_mask[mask == class_idx] = fixed_colors[class_idx]
+     return color_mask
+
+
+ def segment_image(image, model_name: str) -> PIL.Image.Image:
+     if model_name == "DINO":
+         model = model1
+     elif model_name == "ViT":
+         model = model2
+     else:
+         model = model3
+
+     original_width, original_height = image.size
+     transform = get_transform(model.mean, model.std)
+     input_tensor = transform(image).unsqueeze(0)
+
+     with torch.no_grad():
+         mask = model(input_tensor)
+         mask = torch.argmax(mask.squeeze(), dim=0).cpu().numpy()
+
+     mask_image = mask_to_color(mask)
+
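+     # Rescale the colour mask to the original image height (keeping its aspect ratio)
+     # and centre it horizontally on a canvas matching the uploaded image's size.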
+     mask_image = PIL.Image.fromarray(mask_image)
+     mask_aspect_ratio = mask_image.width / mask_image.height
+
+     new_height = original_height
+     new_width = int(new_height * mask_aspect_ratio)
+     mask_image = mask_image.resize((new_width, new_height), PIL.Image.Resampling.NEAREST)
+
+     final_mask = PIL.Image.new("RGB", (original_width, original_height))
+     offset = ((original_width - new_width) // 2, 0)
+     final_mask.paste(mask_image, offset)
+
+     return final_mask
+
+
+ def generate_legend_html_compact() -> str:
+     legend_html = """
+     <div style='display: flex; flex-wrap: wrap; gap: 10px; justify-content: center;'>
+     """
+     for idx, (label, color) in enumerate(label_colors.items()):
+         legend_html += f"""
+         <div style='display: flex; align-items: center; justify-content: center;
+                     padding: 5px 10px; border: 1px solid {color};
+                     background-color: {color}; border-radius: 5px;
+                     color: white; font-size: 12px; text-align: center;'>
+             {mask_labels[label]}
+         </div>
+         """
+     legend_html += "</div>"
+     return legend_html
+
+
+ examples = [
+     ["assets/images_examples/image1.jpg"],
+     ["assets/images_examples/image2.jpg"],
+     ["assets/images_examples/image3.jpg"]
+ ]
+
+
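+ # Two-column layout: image upload, model picker, and buttons on the left;
+ # the predicted mask and colour legend on the right.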
+ with gr.Blocks() as demo:
+     gr.Markdown("## Clothes Segmentation")
+     with gr.Row():
+         with gr.Column():
+             pic = gr.Image(label="Upload Human Image", type="pil", height=300, width=300)
+             model_choice = gr.Dropdown(choices=["DINO", "ViT", "UNet"], label="Select Model", value="DINO")
+             with gr.Row():
+                 with gr.Column(scale=1):
+                     predict_btn = gr.Button("Predict")
+                 with gr.Column(scale=1):
+                     clear_btn = gr.Button("Clear")
+
+         with gr.Column():
+             output = gr.Image(label="Mask", type="pil", height=300, width=300)
+             legend = gr.HTML(label="Legend", value=generate_legend_html_compact())
+
+     predict_btn.click(fn=segment_image, inputs=[pic, model_choice], outputs=output, api_name="predict")
+     clear_btn.click(lambda: (None, None), outputs=[pic, output])
+     gr.Examples(examples=examples, inputs=[pic])
+
+ demo.launch()
losses_config.json ADDED
@@ -0,0 +1,14 @@
+ {
+     "cross_entropy": 0.33,
+     "SSLoss_v2": 0.0,
+     "ExpLog_loss": 0.0,
+     "LovaszSoftmax": 0.0,
+     "TopKLoss": 0.33,
+     "WeightedCrossEntropyLoss": 0.0,
+     "SoftDiceLoss_v2": 0.0,
+     "IoULoss_v2": 0.0,
+     "TverskyLoss_v2": 0.0,
+     "FocalTversky_loss_v2": 0.0,
+     "AsymLoss_v2": 0.0,
+     "FocalLoss": 0.33
+ }
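The weights above appear to mix several segmentation losses into a single training objective (cross-entropy, TopK, and focal losses each weighted 0.33 here). A minimal sketch of how such a config might be consumed, assuming a registry of loss callables; the `LOSSES` registry and the `combined_loss` helper below are illustrative, not the repository's actual code:

```python
import json

import torch
import torch.nn as nn

# Hypothetical registry mapping config keys to callables taking (logits, target).
# Only the entries with non-zero weight need concrete implementations.
LOSSES = {"cross_entropy": nn.CrossEntropyLoss()}

with open("losses_config.json") as f:
    LOSS_WEIGHTS = json.load(f)


def combined_loss(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Weighted sum of the enabled losses; logits: (N, C, H, W), target: (N, H, W)."""
    total = logits.new_zeros(())
    for name, weight in LOSS_WEIGHTS.items():
        if weight > 0 and name in LOSSES:
            total = total + weight * LOSSES[name](logits, target)
    return total
```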
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ torch==2.4.1
+ torchvision==0.19.1
+ kaggle==1.6.17
+ wandb==0.18.5
+ gradio==5.4.0
+ datasets==3.1.0
+ accelerate==1.1.0
+ opencv-python==4.10.0.84
+ scipy==1.14.1
+ transformers==4.46.2
+ matplotlib==3.10.0