|
from glob import glob |
|
import argparse |
|
import os |
|
from typing import Tuple, List |
|
from PIL import Image |
|
from rich.progress import track |
|
from vegseg.datasets import GrassDataset |
|
|
|
|
|
def get_args() -> Tuple[str, str, int]: |
|
""" |
|
get args |
|
return: |
|
--dataset_path: dataset path. |
|
--output_path: output path for saving. |
|
--num: num of image to show. -1 means all. |
|
""" |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument("--dataset_path", type=str, default="data/grass") |
|
parser.add_argument("--output_path", type=str, default="all_dataset.png") |
|
parser.add_argument("--num", default=-1, type=int, help="num of image to show") |
|
args = parser.parse_args() |
|
return args.dataset_path, args.output_path, args.num |
|
|
|
|
|
def get_image_and_mask_paths( |
|
dataset_path: str, num: int |
|
) -> Tuple[List[str], List[str]]: |
|
""" |
|
get image and mask paths from dataset path. |
|
return: |
|
image_paths: list of image paths. |
|
mask_paths: list of mask paths. |
|
""" |
|
image_paths = glob(os.path.join(dataset_path, "img_dir", "*", "*.tif")) |
|
if num != -1: |
|
image_paths = image_paths[:num] |
|
mask_paths = [ |
|
filename.replace("tif", "png").replace("img_dir", "ann_dir") |
|
for filename in image_paths |
|
] |
|
return image_paths, mask_paths |
|
|
|
|
|
def get_palette() -> List[int]: |
|
""" |
|
get palette of dataset. |
|
return: |
|
palette: list of palette. |
|
""" |
|
palette = [] |
|
palette_list = GrassDataset.METAINFO["palette"] |
|
for palette_item in palette_list: |
|
palette.extend(palette_item) |
|
return palette |
|
|
|
|
|
def paste_image_mask(image_path: str, mask_path: str) -> Image.Image: |
|
""" |
|
paste image and mask together |
|
Args: |
|
image_path (str): path to image. |
|
mask_path (str): path to mask. |
|
return: |
|
image_mask: image with mask,is Image. |
|
""" |
|
image = Image.open(image_path) |
|
mask = Image.open(mask_path).convert("P") |
|
palette = get_palette() |
|
mask.putpalette(palette) |
|
mask = mask.convert("RGB") |
|
image_mask = Image.new("RGB", (image.width * 2, image.height)) |
|
image_mask.paste(image, (0, 0)) |
|
image_mask.paste(mask, (image.width, 0)) |
|
return image_mask |
|
|
|
|
|
def paste_all_images(all_images: List[Image.Image], output_path: str) -> None: |
|
""" |
|
paste all images together and save it. |
|
Args: |
|
all_images (List[Image.Image]): list of image. |
|
output_path (str): path to save. |
|
Return: |
|
None |
|
""" |
|
widths = [image.width for image in all_images] |
|
heights = [image.height for image in all_images] |
|
width = max(widths) |
|
height = sum(heights) |
|
all_image = Image.new("RGB", (width, height)) |
|
for i, image in enumerate(all_images): |
|
all_image.paste(image, (0, sum(heights[:i]))) |
|
all_image.save(output_path) |
|
|
|
|
|
def main(): |
|
dataset_path, output_path, num = get_args() |
|
image_paths, mask_paths = get_image_and_mask_paths(dataset_path, num) |
|
all_images = [] |
|
for image_path, mask_path in zip(image_paths, mask_paths): |
|
image_mask = paste_image_mask(image_path, mask_path) |
|
all_images.append(image_mask) |
|
paste_all_images(all_images, output_path) |
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
main() |
|
|