import argparse
import json
import os
from glob import glob
from typing import Dict, List, Optional, Sequence, Tuple

import numpy as np
from matplotlib import pyplot as plt
from PIL import Image


def get_args(argv: Optional[Sequence[str]] = None) -> Tuple[str, str]:
    """Parse the command-line options of the dataset analysis script.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, exactly as before; passing an explicit
            list lets other code and tests call this without touching
            ``sys.argv``.

    Returns:
        A ``(dataset_dir, save_dir)`` tuple:
        - dataset_dir: dataset root directory (default ``data/grass``).
        - save_dir: path of the output image file (default
          ``dataset_num_analysis.png``); despite its name it is a file path,
          not a directory.
    """
    parser = argparse.ArgumentParser(
        description="Count pixels per class value in segmentation masks and plot them."
    )
    parser.add_argument(
        '--dataset_dir',
        type=str,
        default='data/grass',
        help="dataset root directory; masks are read from <dataset_dir>/ann_dir/*/*.png",
    )
    parser.add_argument(
        '--save_dir',
        type=str,
        default='dataset_num_analysis.png',
        help="output image file path for the bar chart",
    )
    args = parser.parse_args(argv)
    return args.dataset_dir, args.save_dir


def get_mask_files(dataset_dir: str) -> List[str]:
    """Collect the mask image paths of a dataset.

    Masks are matched by the pattern ``<dataset_dir>/ann_dir/*/*.png``, i.e.
    ``.png`` files exactly one sub-directory below ``ann_dir`` (presumably one
    sub-directory per split such as train/val -- confirm against the dataset
    layout).

    Args:
        dataset_dir: Dataset root directory.

    Returns:
        Lexicographically sorted list of matching mask file paths. The list is
        empty when nothing matches, including when ``dataset_dir`` does not
        exist.
    """
    pattern = os.path.join(dataset_dir, "ann_dir", "*", "*.png")
    # glob() yields paths in filesystem-dependent order; sorting makes runs
    # reproducible across machines.
    return sorted(glob(pattern))


def _count_class_pixels(mask_filenames: List[str]) -> Dict[int, int]:
    """Return ``{class value: total element count}`` summed over all masks."""
    statistic: Dict[int, int] = {}
    for mask_filename in mask_filenames:
        # The context manager releases the file handle as soon as the pixels
        # are copied into the array.
        with Image.open(mask_filename) as img:
            mask = np.array(img)
        # One pass per mask returns every distinct value and its count. The old
        # code instead rescanned the whole image once per class.
        # NOTE(review): assumes single-channel (L/P mode) masks; for RGB masks
        # the counts are per channel element, as in the original code.
        values, counts = np.unique(mask, return_counts=True)
        for value, count in zip(values, counts):
            class_id = int(value)
            statistic[class_id] = statistic.get(class_id, 0) + int(count)
    return statistic


def _plot_class_counts(statistic: Dict[int, int], save_path: str) -> None:
    """Draw one labelled bar per class value and save the figure to ``save_path``."""
    classes = sorted(statistic)
    class_counts = [statistic[class_id] for class_id in classes]

    fig, ax = plt.subplots()
    ax.set_title("Dataset Analysis")
    ax.set_xlabel("Class")
    ax.set_ylabel("Pixel count")
    # Categorical x positions keep bars evenly spaced and visible even when the
    # class values are sparse (e.g. 0, 1 and 255).
    bars = ax.bar([str(class_id) for class_id in classes], class_counts)
    for bar, count in zip(bars, class_counts):
        # Offset the label in points, not data units, so the gap stays visible
        # regardless of how large the counts are.
        ax.annotate(
            str(count),
            xy=(bar.get_x() + bar.get_width() / 2, bar.get_height()),
            xytext=(0, 3),
            textcoords="offset points",
            ha='center',
            va='bottom',
        )
    fig.savefig(save_path, dpi=300)
    plt.close(fig)


def main():
    """Count pixels per class value over every dataset mask and save a bar chart.

    Raises:
        SystemExit: if no mask files are found under the dataset directory.
    """
    dataset_dir, save_path = get_args()
    mask_filenames = get_mask_files(dataset_dir)
    if not mask_filenames:
        # An empty match almost always means a wrong --dataset_dir; fail loudly
        # instead of silently writing an empty chart.
        raise SystemExit(
            f"No mask files found under {os.path.join(dataset_dir, 'ann_dir')}"
        )
    statistic = _count_class_pixels(mask_filenames)
    _plot_class_counts(statistic, save_path)


if __name__ == '__main__':
    # Run the analysis only when executed as a script, not on import.
    main()