"""Run a fixed set of trained segmentation models over the grass dataset.

For every input image, each model's prediction is colorized with the dataset
palette and saved under ``data/visualization/grass/<model_name>/``; the input
image and its label mask are saved alongside under ``input/`` and ``label/``.
"""
import argparse
import os
from glob import glob
from typing import Dict, List, Tuple, Union

import numpy as np
from mmeval import MeanIoU  # noqa: F401  (unused here; kept from the original file)
from PIL import Image
from matplotlib import pyplot as plt  # noqa: F401  (unused here; kept from the original file)
from mmseg.apis import MMSegInferencer

from vegseg.datasets import GrassDataset
# NOTE(review): not referenced by name in this script; presumably imported for
# its side effects (model registration) -- confirm before removing.
from vegseg import models  # noqa: F401


def get_args() -> Tuple[str, str]:
    """
    Parse the command-line arguments.
    return:
        device: value of --device, the device to run inference on.
        dataset_path: value of --dataset_path, the dataset root directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=str, default="cuda:4")
    parser.add_argument("--dataset_path", type=str, default="data/grass")
    args = parser.parse_args()
    return args.device, args.dataset_path


def give_color_to_mask(
    mask: Union[Image.Image, np.ndarray], palette: List[int]
) -> Image.Image:
    """
    Turn a mask into a palettized ("P" mode) image.
    Args:
        mask: mask to color, numpy array or PIL Image.
        palette: flat palette list, as returned by get_palette().
    return:
        mask: PIL image in "P" mode with the palette attached.
    """
    if isinstance(mask, np.ndarray):
        mask = Image.fromarray(mask)
    mask = mask.convert("P")
    mask.putpalette(palette)
    return mask


def get_image_and_mask_paths(
    dataset_path: str, num: int
) -> Tuple[List[str], List[str]]:
    """
    Collect image paths and derive the matching mask paths.

    Images are read from ``<dataset_path>/img_dir/*/*.tif``; the mask for each
    image is ``<dataset_path>/ann_dir/<same sub dir>/<same stem>.png``.
    Args:
        dataset_path: dataset root directory.
        num: maximum number of images to return, or -1 for all of them.
    return:
        image_paths: list of image paths, sorted.
        mask_paths: list of mask paths, aligned with image_paths.
    """
    image_root = os.path.join(dataset_path, "img_dir")
    # glob order is unspecified; sort so that `num` selects a stable subset.
    image_paths = sorted(glob(os.path.join(image_root, "*", "*.tif")))
    if num != -1:
        image_paths = image_paths[:num]

    mask_root = os.path.join(dataset_path, "ann_dir")
    mask_paths = []
    for image_path in image_paths:
        # Swap only the root directory and the extension. A plain str.replace
        # would also rewrite "tif" / "img_dir" appearing elsewhere in the path.
        relative_stem, _ = os.path.splitext(os.path.relpath(image_path, image_root))
        mask_paths.append(os.path.join(mask_root, relative_stem + ".png"))
    return image_paths, mask_paths


def get_palette() -> List[int]:
    """
    Flatten the dataset palette.
    return:
        palette: the entries of GrassDataset.METAINFO["palette"] concatenated
            into one flat list.
    """
    palette: List[int] = []
    for color in GrassDataset.METAINFO["palette"]:
        palette.extend(color)
    return palette


def _first_match(pattern: str) -> str:
    """Return the first (sorted) path matching *pattern*, or raise FileNotFoundError."""
    matches = sorted(glob(pattern))
    if not matches:
        raise FileNotFoundError(f"no file matches {pattern!r}")
    return matches[0]


def init_all_models(
    models_paths: List[str], device: str
) -> Dict[str, MMSegInferencer]:
    """
    Build one inferencer per model directory.

    Each directory must contain a ``*.py`` config and a
    ``best_mIoU_iter_*.pth`` checkpoint; when several files match, the first
    one in sorted order is used.
    Args:
        models_paths (List[str]): model directories.
        device (str): device to use.
    Return:
        dict mapping each directory's base name to its MMSegInferencer, in the
        order of models_paths.
    Raises:
        FileNotFoundError: a directory has no config or no checkpoint.
    """
    inferencers: Dict[str, MMSegInferencer] = {}
    for model_path in models_paths:
        print(model_path)
        config_path = _first_match(os.path.join(model_path, "*.py"))
        weight_path = _first_match(os.path.join(model_path, "best_mIoU_iter_*.pth"))
        inferencer = MMSegInferencer(
            config_path,
            weight_path,
            device=device,
            classes=GrassDataset.METAINFO["classes"],
            palette=GrassDataset.METAINFO["palette"],
        )
        # normpath + basename copes with trailing separators and with "/"
        # separated paths on platforms whose os.path.sep is different.
        model_name = os.path.basename(os.path.normpath(model_path))
        inferencers[model_name] = inferencer
    return inferencers


def main() -> None:
    """Run all models over the dataset and save predictions, inputs and labels."""
    device, dataset_path = get_args()
    image_paths, mask_paths = get_image_and_mask_paths(dataset_path, -1)
    palette = get_palette()

    models_paths = [
        r"work_dirs/fcn_r50",
        r"work_dirs/pspnet_r101",
        r"work_dirs/deeplabv3plus_r101",
        r"work_dirs/unet-s5-d16_deeplabv3",
        r"work_dirs/segformer_mit-b5",
        r"work_dirs/mask2former_swin_b",
        r"work_dirs/dinov2_upernet",
        r"work_dirs/experiment_p",
    ]
    inferencers = init_all_models(models_paths, device)

    output_root = os.path.join("data", "visualization", "grass")
    # NOTE: nothing below writes into "vis_results"; it is still created so
    # the script's side effects match the original.
    os.makedirs("vis_results", exist_ok=True)
    # One directory per loaded model, derived from the inferencer names so the
    # save paths below always exist, plus the input and label directories.
    for sub_dir in (*inferencers, "input", "label"):
        os.makedirs(os.path.join(output_root, sub_dir), exist_ok=True)

    for image_path, mask_path in zip(image_paths, mask_paths):
        # NOTE(review): only the base name is kept, so images with the same
        # name in different sub directories overwrite each other.
        filename = os.path.basename(image_path)
        for model_name, inferencer in inferencers.items():
            prediction: np.ndarray = inferencer(image_path)["predictions"]
            colored = give_color_to_mask(prediction.astype(np.uint8), palette=palette)
            colored.save(os.path.join(output_root, model_name, filename))

        # Context managers close the underlying files after re-saving.
        with Image.open(image_path) as image:
            image.save(os.path.join(output_root, "input", filename))
        with Image.open(mask_path) as label:
            label.save(os.path.join(output_root, "label", filename))


if __name__ == "__main__":
    # example usage: python tools/vis_model.py --device cuda:0 --dataset_path data/grass
    main()