File size: 3,335 Bytes
918db92 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
from glob import glob
import argparse
import os
from typing import Tuple, List
from PIL import Image
from rich.progress import track
from vegseg.datasets import GrassDataset
def get_args() -> Tuple[str, str, int]:
"""
get args
return:
--dataset_path: dataset path.
--output_path: output path for saving.
--num: num of image to show. -1 means all.
"""
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_path", type=str, default="data/grass")
parser.add_argument("--output_path", type=str, default="all_dataset.png")
parser.add_argument("--num", default=-1, type=int, help="num of image to show")
args = parser.parse_args()
return args.dataset_path, args.output_path, args.num
def get_image_and_mask_paths(
dataset_path: str, num: int
) -> Tuple[List[str], List[str]]:
"""
get image and mask paths from dataset path.
return:
image_paths: list of image paths.
mask_paths: list of mask paths.
"""
image_paths = glob(os.path.join(dataset_path, "img_dir", "*", "*.tif"))
if num != -1:
image_paths = image_paths[:num]
mask_paths = [
filename.replace("tif", "png").replace("img_dir", "ann_dir")
for filename in image_paths
]
return image_paths, mask_paths
def get_palette() -> List[int]:
"""
get palette of dataset.
return:
palette: list of palette.
"""
palette = []
palette_list = GrassDataset.METAINFO["palette"]
for palette_item in palette_list:
palette.extend(palette_item)
return palette
def paste_image_mask(image_path: str, mask_path: str) -> Image.Image:
"""
paste image and mask together
Args:
image_path (str): path to image.
mask_path (str): path to mask.
return:
image_mask: image with mask,is Image.
"""
image = Image.open(image_path)
mask = Image.open(mask_path).convert("P")
palette = get_palette()
mask.putpalette(palette)
mask = mask.convert("RGB")
image_mask = Image.new("RGB", (image.width * 2, image.height))
image_mask.paste(image, (0, 0))
image_mask.paste(mask, (image.width, 0))
return image_mask
def paste_all_images(all_images: List[Image.Image], output_path: str) -> None:
"""
paste all images together and save it.
Args:
all_images (List[Image.Image]): list of image.
output_path (str): path to save.
Return:
None
"""
widths = [image.width for image in all_images]
heights = [image.height for image in all_images]
width = max(widths)
height = sum(heights)
all_image = Image.new("RGB", (width, height))
for i, image in enumerate(all_images):
all_image.paste(image, (0, sum(heights[:i])))
all_image.save(output_path)
def main():
dataset_path, output_path, num = get_args()
image_paths, mask_paths = get_image_and_mask_paths(dataset_path, num)
all_images = []
for image_path, mask_path in zip(image_paths, mask_paths):
image_mask = paste_image_mask(image_path, mask_path)
all_images.append(image_mask)
paste_all_images(all_images, output_path)
if __name__ == "__main__":
# example usage: python tools/dataset_tools/dataset_show.py --dataset_path data/grass --output_path all_dataset.png
main()
|