XavierJiezou commited on
Commit
8300cd5
verified
1 Parent(s): 726f933

Create vis_model.py

Browse files
Files changed (1) hide show
  1. visualization/code/vis_model.py +160 -0
visualization/code/vis_model.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from glob import glob
2
+ import argparse
3
+ import os
4
+ from typing import Tuple, List
5
+ import numpy as np
6
+ from mmeval import MeanIoU
7
+ from PIL import Image
8
+ from matplotlib import pyplot as plt
9
+ from mmseg.apis import MMSegInferencer
10
+ from vegseg.datasets import GrassDataset
11
+
12
+
13
+ def get_iou(pred: np.ndarray, gt: np.ndarray, num_classes=2):
14
+ pred = pred[np.newaxis]
15
+ gt = gt[np.newaxis]
16
+ miou = MeanIoU(num_classes=num_classes)
17
+ result = miou(pred, gt)
18
+ return result["mIoU"] * 100
19
+
20
+
21
+ def get_args() -> Tuple[str, str, int]:
22
+ """
23
+ get args
24
+ return:
25
+ --models: all_models path.
26
+ --device: device to use.
27
+ --dataset_path: dataset path.
28
+ --output_path: output path for saving.
29
+ """
30
+ parser = argparse.ArgumentParser()
31
+ parser.add_argument("--models", type=str, default="work_dirs")
32
+ parser.add_argument("--device", type=str, default="cuda:0")
33
+ parser.add_argument("--dataset_path", type=str, default="data/grass")
34
+ args = parser.parse_args()
35
+ return args.models, args.device, args.dataset_path
36
+
37
+
38
+ def give_color_to_mask(
39
+ mask: Image.Image | np.ndarray, palette: List[int]
40
+ ) -> Image.Image:
41
+ """
42
+ Args:
43
+ mask: mask to color, numpy array or PIL Image.
44
+ palette: palette of dataset.
45
+ return:
46
+ mask: mask with color.
47
+ """
48
+ if isinstance(mask, np.ndarray):
49
+ mask = Image.fromarray(mask)
50
+ mask = mask.convert("P")
51
+ mask.putpalette(palette)
52
+ return mask
53
+
54
+
55
+ def get_image_and_mask_paths(
56
+ dataset_path: str, num: int
57
+ ) -> Tuple[List[str], List[str]]:
58
+ """
59
+ get image and mask paths from dataset path.
60
+ return:
61
+ image_paths: list of image paths.
62
+ mask_paths: list of mask paths.
63
+ """
64
+ image_paths = glob(os.path.join(dataset_path, "img_dir", "val", "*.tif"))
65
+ if num != -1:
66
+ image_paths = image_paths[:num]
67
+ mask_paths = [
68
+ filename.replace("tif", "png").replace("img_dir", "ann_dir")
69
+ for filename in image_paths
70
+ ]
71
+ return image_paths, mask_paths
72
+
73
+
74
+ def get_palette() -> List[int]:
75
+ """
76
+ get palette of dataset.
77
+ return:
78
+ palette: list of palette.
79
+ """
80
+ palette = []
81
+ palette_list = GrassDataset.METAINFO["palette"]
82
+ for palette_item in palette_list:
83
+ palette.extend(palette_item)
84
+ return palette
85
+
86
+
87
+ def init_all_models(models_path: str, device: str):
88
+ """
89
+ init all models
90
+ Args:
91
+ models_path (str): path to all models.
92
+ device (str): device to use.
93
+ Return:
94
+ models (dict): dict of models.
95
+ """
96
+ models = {}
97
+ all_models = os.listdir(models_path)
98
+ for model_path in all_models:
99
+ model_name = model_path
100
+ model_path = os.path.join(models_path, model_path)
101
+ config_path = glob(os.path.join(model_path, "*.py"))[0]
102
+ weight_path = glob(os.path.join(model_path, "best_mIoU_iter_*.pth"))[0]
103
+ inference = MMSegInferencer(
104
+ config_path,
105
+ weight_path,
106
+ device=device,
107
+ classes=GrassDataset.METAINFO["classes"],
108
+ palette=GrassDataset.METAINFO["palette"],
109
+ )
110
+ models[model_name] = inference
111
+ return models
112
+
113
+
114
+ def main():
115
+ models_path, device, dataset_path = get_args()
116
+ image_paths, mask_paths = get_image_and_mask_paths(dataset_path, -1)
117
+ palette = get_palette()
118
+ models = init_all_models(models_path, device)
119
+ os.makedirs("vis_results", exist_ok=True)
120
+ for image_path, mask_path in zip(image_paths, mask_paths):
121
+ result_eval = {}
122
+ result_iou = {}
123
+ mask = Image.open(mask_path)
124
+ for model_name, inference in models.items():
125
+ predictions: np.ndarray = inference(image_path)["predictions"]
126
+ predictions = predictions.astype(np.uint8)
127
+ result_eval[model_name] = predictions
128
+ result_iou[model_name] = get_iou(predictions, np.array(mask), num_classes=5)
129
+
130
+ # 鏍规嵁iou 杩涜鎺掑簭
131
+ result_iou_sorted = sorted(result_iou.items(), key=lambda x: x[1], reverse=True)
132
+ plt.figure(figsize=(36, 3))
133
+ plt.subplot(1, len(models) + 2, 1)
134
+ plt.imshow(Image.open(image_path))
135
+ plt.axis("off")
136
+ plt.title("Input")
137
+
138
+ plt.subplot(1, len(models) + 2, 2)
139
+ plt.imshow(give_color_to_mask(mask, palette=palette))
140
+ plt.axis("off")
141
+ plt.title("Label")
142
+
143
+ for i, (model_name, _) in enumerate(result_iou_sorted):
144
+ plt.subplot(1, len(models) + 2, i + 3)
145
+ plt.imshow(give_color_to_mask(result_eval[model_name], palette))
146
+ plt.axis("off")
147
+ plt.title(f"{model_name}: {result_iou[model_name]:.2f}")
148
+
149
+ base_name = os.path.basename(image_path).split(".")[0]
150
+ plt.savefig(
151
+ f"vis_results/{base_name}.png",
152
+ dpi=300,
153
+ bbox_inches="tight",
154
+ pad_inches=0,
155
+ )
156
+
157
+
158
+ if __name__ == "__main__":
159
+ # example usage: python tools/vis_model.py --models work_dirs --device cuda:0 --dataset_path data/grass
160
+ main()