XavierJiezou commited on
Commit
726f933
verified
1 Parent(s): 2dd4dfc

Create vis_model_plus.py

Browse files
Files changed (1) hide show
  1. visualization/code/vis_model_plus.py +183 -0
visualization/code/vis_model_plus.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from glob import glob
2
+ import argparse
3
+ import os
4
+ from typing import Tuple, List
5
+ import numpy as np
6
+ from mmeval import MeanIoU
7
+ from PIL import Image
8
+ from matplotlib import pyplot as plt
9
+ from mmseg.apis import MMSegInferencer
10
+ from vegseg.datasets import GrassDataset
11
+ from vegseg import models
12
+
13
+
14
+ def get_iou(pred: np.ndarray, gt: np.ndarray, num_classes=2):
15
+ pred = pred[np.newaxis]
16
+ gt = gt[np.newaxis]
17
+ miou = MeanIoU(num_classes=num_classes)
18
+ result = miou(pred, gt)
19
+ return result["mIoU"] * 100
20
+
21
+
22
+ def get_args() -> Tuple[str, str, int]:
23
+ """
24
+ get args
25
+ return:
26
+ --device: device to use.
27
+ --dataset_path: dataset path.
28
+ --output_path: output path for saving.
29
+ """
30
+ parser = argparse.ArgumentParser()
31
+ parser.add_argument("--device", type=str, default="cuda:4")
32
+ parser.add_argument("--dataset_path", type=str, default="data/grass")
33
+ args = parser.parse_args()
34
+ return args.device, args.dataset_path
35
+
36
+
37
+ def give_color_to_mask(
38
+ mask: Image.Image | np.ndarray, palette: List[int]
39
+ ) -> Image.Image:
40
+ """
41
+ Args:
42
+ mask: mask to color, numpy array or PIL Image.
43
+ palette: palette of dataset.
44
+ return:
45
+ mask: mask with color.
46
+ """
47
+ if isinstance(mask, np.ndarray):
48
+ mask = Image.fromarray(mask)
49
+ mask = mask.convert("P")
50
+ mask.putpalette(palette)
51
+ return mask
52
+
53
+
54
+ def get_image_and_mask_paths(
55
+ dataset_path: str, num: int
56
+ ) -> Tuple[List[str], List[str]]:
57
+ """
58
+ get image and mask paths from dataset path.
59
+ return:
60
+ image_paths: list of image paths.
61
+ mask_paths: list of mask paths.
62
+ """
63
+ image_paths = glob(os.path.join(dataset_path, "img_dir", "*", "*.tif"))
64
+ if num != -1:
65
+ image_paths = image_paths[:num]
66
+ mask_paths = [
67
+ filename.replace("tif", "png").replace("img_dir", "ann_dir")
68
+ for filename in image_paths
69
+ ]
70
+ return image_paths, mask_paths
71
+
72
+
73
+ def get_palette() -> List[int]:
74
+ """
75
+ get palette of dataset.
76
+ return:
77
+ palette: list of palette.
78
+ """
79
+ palette = []
80
+ palette_list = GrassDataset.METAINFO["palette"]
81
+ for palette_item in palette_list:
82
+ palette.extend(palette_item)
83
+ return palette
84
+
85
+
86
+ def init_all_models(models_paths: List[str], device: str):
87
+ """
88
+ init all models
89
+ Args:
90
+ models_path (str): path to all models.
91
+ device (str): device to use.
92
+ Return:
93
+ models (dict): dict of models.
94
+ """
95
+ models = {}
96
+ for model_path in models_paths:
97
+ print(model_path)
98
+ config_path = glob(os.path.join(model_path, "*.py"))[0]
99
+ weight_path = glob(os.path.join(model_path, "best_mIoU_iter_*.pth"))[0]
100
+ inference = MMSegInferencer(
101
+ config_path,
102
+ weight_path,
103
+ device=device,
104
+ classes=GrassDataset.METAINFO["classes"],
105
+ palette=GrassDataset.METAINFO["palette"],
106
+ )
107
+ model_name = model_path.split(os.path.sep)[-1]
108
+ models[model_name] = inference
109
+ return models
110
+
111
+
112
+ def main():
113
+ device, dataset_path = get_args()
114
+ image_paths, mask_paths = get_image_and_mask_paths(dataset_path, -1)
115
+ palette = get_palette()
116
+ models_paths = [
117
+ r"work_dirs/fcn_r50",
118
+ r"work_dirs/pspnet_r101",
119
+ r"work_dirs/deeplabv3plus_r101",
120
+ r"work_dirs/unet-s5-d16_deeplabv3",
121
+ r"work_dirs/segformer_mit-b5",
122
+ r"work_dirs/mask2former_swin_b",
123
+ r"work_dirs/dinov2_upernet",
124
+ r"work_dirs/experiment_p",
125
+ ]
126
+ models = init_all_models(models_paths, device)
127
+
128
+ model_order = [
129
+ "experiment_p",
130
+ "fcn_r50",
131
+ "pspnet_r101",
132
+ "deeplabv3plus_r101",
133
+ "unet-s5-d16_deeplabv3",
134
+ "segformer_mit-b5",
135
+ "mask2former_swin_b",
136
+ "dinov2_upernet"
137
+ ]
138
+
139
+ os.makedirs("vis_results", exist_ok=True)
140
+ for image_path, mask_path in zip(image_paths, mask_paths):
141
+ result_eval = {}
142
+ result_iou = {}
143
+ mask = Image.open(mask_path)
144
+ for model_name, inference in models.items():
145
+ predictions: np.ndarray = inference(image_path)["predictions"]
146
+ predictions = predictions.astype(np.uint8)
147
+ result_eval[model_name] = predictions
148
+ result_iou[model_name] = get_iou(predictions, np.array(mask), num_classes=5)
149
+
150
+ # 鏍规嵁iou 杩涜鎺掑簭
151
+ result_iou_sorted = sorted(result_iou.items(), key=lambda x: x[1], reverse=True)
152
+
153
+ if result_iou_sorted[0][0] != "experiment_p":
154
+ continue
155
+
156
+ plt.figure(figsize=(32, 8))
157
+ plt.subplots_adjust(wspace=0.01)
158
+ plt.subplot(1, 10, 1)
159
+ plt.imshow(Image.open(image_path))
160
+ plt.axis("off")
161
+
162
+ plt.subplot(1, 10, 2)
163
+ plt.imshow(give_color_to_mask(mask, palette=palette))
164
+ plt.axis("off")
165
+
166
+ for i, model_name in enumerate(model_order):
167
+ plt.subplot(1, 10, i + 3)
168
+ plt.imshow(give_color_to_mask(result_eval[model_name], palette))
169
+ plt.axis("off")
170
+
171
+ base_name = os.path.basename(image_path).split(".")[0]
172
+ diff_iou = result_iou_sorted[0][1] - result_iou_sorted[1][1]
173
+ plt.savefig(
174
+ f"vis_results/{diff_iou:.2f}_{base_name}.svg",
175
+ dpi=300,
176
+ bbox_inches="tight",
177
+ pad_inches=0,
178
+ )
179
+
180
+
181
+ if __name__ == "__main__":
182
+ # example usage: python tools/vis_model.py --models work_dirs --device cuda:0 --dataset_path data/grass
183
+ main()