from mmseg.apis import MMSegInferencer from glob import glob from vegseg.datasets import L8BIOMEDataset import numpy as np from typing import List import os from PIL import Image from vegseg import models def get_palette() -> List[int]: """ get palette of dataset. return: palette: list of palette. """ palette = [] palette_list = L8BIOMEDataset.METAINFO["palette"] for palette_item in palette_list: palette.extend(palette_item) return palette def give_color_to_mask( mask: Image.Image | np.ndarray, palette: List[int] ) -> Image.Image: """ give color to mask. return: color_mask: color mask. """ color_mask = Image.fromarray(mask).convert("P") color_mask.putpalette(palette) return color_mask def main(): config_path = "work_dirs/experiment_p_l8/experiment_p_l8.py" weight_path = "work_dirs/experiment_p_l8/best_mIoU_iter_20000.pth" inference = MMSegInferencer( model=config_path, weights=weight_path, device="cuda:1", classes=L8BIOMEDataset.METAINFO["classes"], palette=L8BIOMEDataset.METAINFO["palette"], ) images = glob("data/vis/input/*.png") palette = get_palette() predictions = inference.__call__(images,batch_size=16)["predictions"] for image_path, prediction in zip(images, predictions): filename = os.path.basename(image_path) filename = os.path.join("data/vis/ktda",filename) prediction = prediction.astype(np.uint8) color_mask = give_color_to_mask(prediction, palette=palette) color_mask.save(filename) if __name__ == "__main__": main()