from glob import glob
from typing import Dict, List, Tuple
import os
import argparse
import json

from matplotlib import pyplot as plt
import numpy as np
from PIL import Image


def get_args() -> Tuple[str, str]:
    """
    Parse the command-line arguments.

    Return:
        dataset_dir: value of ``--dataset_dir`` (default ``data/grass``).
        save_dir: value of ``--save_dir``. Despite its name this is the path
            of the output image file passed to ``plt.savefig`` (default
            ``dataset_num_analysis.png``), not a directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_dir', type=str, default='data/grass')
    parser.add_argument('--save_dir', type=str,
                        default='dataset_num_analysis.png')
    args = parser.parse_args()
    return args.dataset_dir, args.save_dir


def get_mask_files(dataset_dir: str) -> List[str]:
    """
    Collect the mask files of a dataset.

    Args:
        dataset_dir: dataset root; masks are matched by the pattern
            ``<dataset_dir>/ann_dir/*/*.png``.
    Return:
        mask_filenames: sorted list of matching mask paths (empty when
            nothing matches).
    """
    pattern = os.path.join(dataset_dir, "ann_dir", "*", "*.png")
    # glob() yields files in arbitrary order; sort so runs are reproducible.
    return sorted(glob(pattern))


def _count_mask_pixels(mask_filenames: List[str]) -> Dict[int, int]:
    """Sum, over all mask files, how many array elements hold each value."""
    statistic: Dict[int, int] = {}
    for mask_filename in mask_filenames:
        # The context manager closes the file handle once the pixel data has
        # been copied into the array.
        with Image.open(mask_filename) as image:
            mask = np.array(image)
        # One pass over the mask yields every distinct value and its count.
        values, counts = np.unique(mask, return_counts=True)
        for value, count in zip(values.tolist(), counts.tolist()):
            key = int(value)
            statistic[key] = statistic.get(key, 0) + int(count)
    return statistic


def _plot_statistic(statistic: Dict[int, int], save_path: str) -> None:
    """Draw a labelled bar chart of the per-class counts and save it."""
    classes = sorted(statistic)
    classes_num = [statistic[class_] for class_ in classes]

    plt.title("Dataset Analysis")
    bars = plt.bar(classes, classes_num)
    # Label each bar with the exact integer count rather than the patch
    # height, so the text never depends on how matplotlib stores the height.
    for bar, count in zip(bars, classes_num):
        plt.text(bar.get_x() + bar.get_width() / 2, count + 5, str(count),
                 ha='center', va='bottom')
    plt.savefig(save_path, dpi=300)
    # Release the implicit pyplot figure now that it has been written.
    plt.close()


def main():
    """
    Count how often each class value occurs across all masks of the dataset
    and save the totals as a bar chart.

    Raises:
        FileNotFoundError: no mask file matched under the dataset directory
            (previously an empty chart was written without any notice).
    """
    dataset_dir, save_dir = get_args()
    mask_filenames = get_mask_files(dataset_dir)
    if not mask_filenames:
        raise FileNotFoundError(
            f"no mask files found under "
            f"{os.path.join(dataset_dir, 'ann_dir', '*', '*.png')}"
        )

    statistic = _count_mask_pixels(mask_filenames)
    _plot_statistic(statistic, save_dir)


if __name__ == "__main__":
    main()