File size: 1,660 Bytes
918db92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from mmseg.apis import MMSegInferencer
from glob import glob
from vegseg.datasets import L8BIOMEDataset
import numpy as np
from typing import List
import os
from PIL import Image
from vegseg import models

def get_palette() -> List[int]:
    """
    get palette of dataset.
    return:
        palette: list of palette.
    """
    palette = []
    palette_list = L8BIOMEDataset.METAINFO["palette"]
    for palette_item in palette_list:
        palette.extend(palette_item)
    return palette


def give_color_to_mask(
    mask: Image.Image | np.ndarray, palette: List[int]
) -> Image.Image:
    """
    give color to mask.
    return:
        color_mask: color mask.
    """
    color_mask = Image.fromarray(mask).convert("P")
    color_mask.putpalette(palette)
    return color_mask


def main():
    config_path = "work_dirs/experiment_p_l8/experiment_p_l8.py"
    weight_path = "work_dirs/experiment_p_l8/best_mIoU_iter_20000.pth"
    inference = MMSegInferencer(
        model=config_path,
        weights=weight_path,
        device="cuda:1",
        classes=L8BIOMEDataset.METAINFO["classes"],
        palette=L8BIOMEDataset.METAINFO["palette"],
    )
    images = glob("data/vis/input/*.png")
    palette = get_palette()
    predictions = inference.__call__(images,batch_size=16)["predictions"]
    for image_path, prediction in zip(images, predictions):
        filename = os.path.basename(image_path)
        filename = os.path.join("data/vis/ktda",filename)
        prediction = prediction.astype(np.uint8)
        color_mask = give_color_to_mask(prediction, palette=palette)
        color_mask.save(filename)

if __name__ == "__main__":
    main()