XavierJiezou commited on
Commit
28c6db0
·
verified ·
1 Parent(s): 5290f34

Create code/vis_cloud.py

Browse files
Files changed (1) hide show
  1. visualization/code/vis_cloud.py +57 -0
visualization/code/vis_cloud.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from mmseg.apis import MMSegInferencer
2
+ from glob import glob
3
+ from vegseg.datasets import L8BIOMEDataset
4
+ import numpy as np
5
+ from typing import List
6
+ import os
7
+ from PIL import Image
8
+ from vegseg import models
9
+
10
+ def get_palette() -> List[int]:
11
+ """
12
+ get palette of dataset.
13
+ return:
14
+ palette: list of palette.
15
+ """
16
+ palette = []
17
+ palette_list = L8BIOMEDataset.METAINFO["palette"]
18
+ for palette_item in palette_list:
19
+ palette.extend(palette_item)
20
+ return palette
21
+
22
+
23
+ def give_color_to_mask(
24
+ mask: Image.Image | np.ndarray, palette: List[int]
25
+ ) -> Image.Image:
26
+ """
27
+ give color to mask.
28
+ return:
29
+ color_mask: color mask.
30
+ """
31
+ color_mask = Image.fromarray(mask).convert("P")
32
+ color_mask.putpalette(palette)
33
+ return color_mask
34
+
35
+
36
+ def main():
37
+ config_path = "work_dirs/experiment_p_l8/experiment_p_l8.py"
38
+ weight_path = "work_dirs/experiment_p_l8/best_mIoU_iter_20000.pth"
39
+ inference = MMSegInferencer(
40
+ model=config_path,
41
+ weights=weight_path,
42
+ device="cuda:1",
43
+ classes=L8BIOMEDataset.METAINFO["classes"],
44
+ palette=L8BIOMEDataset.METAINFO["palette"],
45
+ )
46
+ images = glob("data/vis/input/*.png")
47
+ palette = get_palette()
48
+ predictions = inference.__call__(images,batch_size=16)["predictions"]
49
+ for image_path, prediction in zip(images, predictions):
50
+ filename = os.path.basename(image_path)
51
+ filename = os.path.join("data/vis/ktda",filename)
52
+ prediction = prediction.astype(np.uint8)
53
+ color_mask = give_color_to_mask(prediction, palette=palette)
54
+ color_mask.save(filename)
55
+
56
+ if __name__ == "__main__":
57
+ main()