ktda-models / tools /vis_cloud.py
XavierJiezou's picture
Add files using upload-large-folder tool
918db92 verified
from mmseg.apis import MMSegInferencer
from glob import glob
from vegseg.datasets import L8BIOMEDataset
import numpy as np
from typing import List
import os
from PIL import Image
from vegseg import models
def get_palette() -> List[int]:
    """
    Flatten the dataset palette into a single list.

    return:
        palette: the elements of every entry in
            L8BIOMEDataset.METAINFO["palette"], concatenated in order.
    """
    return [
        channel
        for color in L8BIOMEDataset.METAINFO["palette"]
        for channel in color
    ]
def give_color_to_mask(
    mask: Image.Image | np.ndarray, palette: List[int]
) -> Image.Image:
    """
    Turn a mask into a palette ("P" mode) image.

    args:
        mask: mask handed to Image.fromarray.
        palette: palette list handed to putpalette.
    return:
        color_mask: the "P" mode image with the given palette attached.
    """
    indexed = Image.fromarray(mask)
    indexed = indexed.convert("P")
    # putpalette works in place, so the same image object is returned.
    indexed.putpalette(palette)
    return indexed
def main():
    """
    Segment every PNG in data/vis/input and save the colorized masks.

    Each prediction is cast to uint8, colorized with the dataset palette and
    written to data/vis/ktda under the same file name as its input image.
    Nothing is done (apart from printing a notice) when no input image is
    found.
    """
    config_path = "work_dirs/experiment_p_l8/experiment_p_l8.py"
    weight_path = "work_dirs/experiment_p_l8/best_mIoU_iter_20000.pth"
    output_dir = "data/vis/ktda"

    # Sort so that the processing order is the same on every run; glob itself
    # gives no ordering guarantee.
    images = sorted(glob("data/vis/input/*.png"))
    if not images:
        # Bail out before the (expensive) model load when there is no work.
        print("No input images found in data/vis/input")
        return

    # Image.save does not create missing parent directories.
    os.makedirs(output_dir, exist_ok=True)

    inference = MMSegInferencer(
        model=config_path,
        weights=weight_path,
        device="cuda:1",
        classes=L8BIOMEDataset.METAINFO["classes"],
        palette=L8BIOMEDataset.METAINFO["palette"],
    )
    palette = get_palette()

    predictions = inference(images, batch_size=16)["predictions"]
    # NOTE(review): for a single input the inferencer appears to return the
    # bare mask array instead of a one-element list; wrap it so that zip()
    # pairs the image with the whole mask rather than with its first row.
    if isinstance(predictions, np.ndarray):
        predictions = [predictions]

    for image_path, prediction in zip(images, predictions):
        filename = os.path.join(output_dir, os.path.basename(image_path))
        # uint8 is required for Image.fromarray to build an 8-bit image.
        prediction = prediction.astype(np.uint8)
        color_mask = give_color_to_mask(prediction, palette=palette)
        color_mask.save(filename)
# Run the visualization only when this file is executed as a script.
if __name__ == "__main__":
    main()