File size: 3,335 Bytes
918db92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from glob import glob
import argparse
import os
from typing import Tuple, List
from PIL import Image
from rich.progress import track
from vegseg.datasets import GrassDataset


def get_args() -> Tuple[str, str, int]:
    """
    Parse the command-line options of this script.

    return:
        A ``(dataset_path, output_path, num)`` tuple:
        dataset_path: root directory of the dataset (default ``data/grass``).
        output_path: file the combined preview is written to
            (default ``all_dataset.png``).
        num: how many image/mask pairs to render; -1 means all of them.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--dataset_path", type=str, default="data/grass")
    cli.add_argument("--output_path", type=str, default="all_dataset.png")
    cli.add_argument("--num", type=int, default=-1, help="num of image to show")
    namespace = cli.parse_args()
    dataset_path = namespace.dataset_path
    output_path = namespace.output_path
    num = namespace.num
    return dataset_path, output_path, num


def get_image_and_mask_paths(
    dataset_path: str, num: int
) -> Tuple[List[str], List[str]]:
    """
    Get image and mask paths from the dataset path.

    Images are looked up as ``<dataset_path>/img_dir/<split>/*.tif``. Each
    mask path is ``<dataset_path>/ann_dir/<split>/<same stem>.png``.

    Args:
        dataset_path (str): dataset root directory.
        num (int): maximum number of pairs to return. Any negative value
            (conventionally -1) means all.
    return:
        image_paths: sorted list of image paths.
        mask_paths: list of mask paths, index-aligned with image_paths.
    """
    image_root = os.path.join(dataset_path, "img_dir")
    mask_root = os.path.join(dataset_path, "ann_dir")
    # glob() returns entries in arbitrary filesystem order. Sort so that the
    # preview (and which files `num` selects) is reproducible.
    image_paths = sorted(glob(os.path.join(image_root, "*", "*.tif")))
    if num >= 0:
        image_paths = image_paths[:num]

    mask_paths = []
    for image_path in image_paths:
        # Rebuild the mask path from its components instead of doing a blind
        # str.replace over the whole path. Otherwise "b_tif.tif" would become
        # "b_png.png", and any directory containing "tif"/"img_dir" would be
        # corrupted.
        split_dir = os.path.basename(os.path.dirname(image_path))
        stem = os.path.splitext(os.path.basename(image_path))[0]
        mask_paths.append(os.path.join(mask_root, split_dir, stem + ".png"))
    return image_paths, mask_paths


def get_palette() -> List[int]:
    """
    Flatten the dataset palette into a single list of channel values.

    Each entry of ``GrassDataset.METAINFO["palette"]`` is a color sequence.
    Their values are concatenated in order, which is the flat layout that
    ``Image.putpalette`` accepts.
    return:
        palette: flat list of palette channel values.
    """
    colors = GrassDataset.METAINFO["palette"]
    return [channel for color in colors for channel in color]


def paste_image_mask(image_path: str, mask_path: str) -> Image.Image:
    """
    Paste an image and its colorized mask side by side.

    The mask is converted to palette mode and given the dataset palette from
    ``get_palette()``, then converted to RGB. The image is placed on the left
    and the mask on the right. Any area not covered by the smaller of the two
    stays black.

    Args:
        image_path (str): path to image.
        mask_path (str): path to mask.
    return:
        image_mask: new RGB Image holding image and mask side by side.
    """
    palette = get_palette()
    # Use context managers so both source files are closed. Without them,
    # Image.open keeps a handle open per call, and rendering a whole dataset
    # can exhaust file descriptors.
    with Image.open(image_path) as image, Image.open(mask_path) as raw_mask:
        mask = raw_mask.convert("P")
        mask.putpalette(palette)
        mask = mask.convert("RGB")
        # Size the canvas from both pictures rather than assuming the mask
        # has exactly the image's dimensions.
        image_mask = Image.new(
            "RGB",
            (image.width + mask.width, max(image.height, mask.height)),
        )
        image_mask.paste(image, (0, 0))
        image_mask.paste(mask, (image.width, 0))
    return image_mask


def paste_all_images(all_images: List[Image.Image], output_path: str) -> None:
    """
    Stack all images vertically (in list order, left-aligned) and save the result.
    Args:
        all_images (List[Image.Image]): images to stack.
        output_path (str): path passed to ``Image.save``.
    Raises:
        ValueError: if ``all_images`` is empty.
    Return:
        None
    """
    if not all_images:
        raise ValueError("paste_all_images() needs at least one image")
    width = max(image.width for image in all_images)
    height = sum(image.height for image in all_images)
    all_image = Image.new("RGB", (width, height))
    # Keep a running offset. Re-summing the previous heights on every
    # iteration made this loop quadratic in the number of images.
    top = 0
    for image in all_images:
        all_image.paste(image, (0, top))
        top += image.height
    all_image.save(output_path)


def main():
    """Render each selected image next to its colorized mask and save one stacked sheet."""
    dataset_path, output_path, num = get_args()
    image_paths, mask_paths = get_image_and_mask_paths(dataset_path, num)
    if not image_paths:
        # Exit with an actionable message instead of failing later inside
        # paste_all_images on an empty list.
        raise SystemExit(
            f"no images to show under {os.path.join(dataset_path, 'img_dir')} "
            f"(num={num})"
        )
    # Use rich's progress bar (already imported at file level) because
    # rendering a full dataset can take a while. zip() has no len(), so
    # pass `total` explicitly.
    all_images = [
        paste_image_mask(image_path, mask_path)
        for image_path, mask_path in track(
            zip(image_paths, mask_paths),
            description="Rendering",
            total=len(image_paths),
        )
    ]
    paste_all_images(all_images, output_path)


if __name__ == "__main__":
    # Script entry point. Example usage (add `--num N` to render only the first N pairs):
    # python tools/dataset_tools/dataset_show.py --dataset_path data/grass --output_path all_dataset.png
    main()