File size: 1,724 Bytes
31c0288
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from glob import glob
from typing import Tuple,List
import os
import argparse
import json
from matplotlib import pyplot as plt
import numpy as np
from PIL import Image

def get_args() -> Tuple[str, str]:
    """
    Parse the command-line options of this script.

    Return:
        dataset_dir: root directory of the dataset (default 'data/grass').
        save_dir: output path of the saved figure (default 'dataset_num_analysis.png').
    """
    cli = argparse.ArgumentParser()
    # Registered in the same order as before so --help output is unchanged.
    for flag, fallback in (
        ('--dataset_dir', 'data/grass'),
        ('--save_dir', 'dataset_num_analysis.png'),
    ):
        cli.add_argument(flag, type=str, default=fallback)
    parsed = cli.parse_args()
    return parsed.dataset_dir, parsed.save_dir

def get_mask_files(dataset_dir: str) -> List[str]:
    """
    Collect the mask images that sit exactly one folder below `<dataset_dir>/ann_dir`.

    Args:
        dataset_dir: dataset root directory.
    Return:
        list of matching `*.png` paths, in whatever order `glob` yields them
        (not sorted). Empty if nothing matches.
    """
    pattern = os.path.join(dataset_dir, "ann_dir", "*", "*.png")
    return glob(pattern)

def _count_pixels_per_class(mask_filenames: List[str]) -> dict:
    """
    Count how many pixels carry each label value across all masks.

    Args:
        mask_filenames: paths of mask images readable by PIL.
    Return:
        statistic: mapping from label value (int) to total pixel count (int).
    """
    statistic = {}
    for mask_filename in mask_filenames:
        # The context manager releases the file handle right away instead of
        # keeping one open per mask until garbage collection.
        with Image.open(mask_filename) as img:
            mask = np.array(img)
        # np.unique already returns per-value counts, so one pass per mask is
        # enough (no need to re-scan the mask with `mask == class_` per class).
        # NOTE(review): for multi-channel (e.g. RGB) masks this counts channel
        # values rather than pixels; presumably masks are single-channel label maps.
        values, counts = np.unique(mask, return_counts=True)
        for value, count in zip(values, counts):
            key = int(value)
            statistic[key] = statistic.get(key, 0) + int(count)
    return statistic


def _plot_class_counts(statistic: dict, save_dir: str) -> None:
    """
    Draw a bar chart of pixel counts per label value and save it.

    Args:
        statistic: mapping from label value to pixel count.
        save_dir: output image path passed to `plt.savefig`.
    """
    classes = sorted(statistic)
    classes_num = [statistic[class_] for class_ in classes]

    fig = plt.figure()
    plt.title("Dataset Analysis")
    bars = plt.bar(classes, classes_num)
    # Put one tick at each label value so every bar is identifiable.
    plt.xticks(classes)
    for bar, count in zip(bars, classes_num):
        # Use the integer count for the label (no float formatting). Anchor it
        # at the bar top with va='bottom' instead of a fixed data-unit offset,
        # which is meaningless at pixel-count scale.
        plt.text(bar.get_x() + bar.get_width() / 2, bar.get_height(), str(count),
                 ha='center', va='bottom')
    plt.savefig(save_dir, dpi=300)
    plt.close(fig)


def main():
    """
    Script entry point: count pixels per label over all masks and save a bar chart.

    Raises:
        SystemExit: when no mask files are found under `<dataset_dir>/ann_dir/*/`.
    """
    dataset_dir, save_dir = get_args()
    mask_filenames = get_mask_files(dataset_dir)
    if not mask_filenames:
        # Fail loudly rather than silently saving an empty chart for a wrong path.
        raise SystemExit(
            f"No mask files found matching {os.path.join(dataset_dir, 'ann_dir', '*', '*.png')}"
        )
    statistic = _count_pixels_per_class(mask_filenames)
    _plot_class_counts(statistic, save_dir)
    

if __name__ == "__main__":
    # Run the analysis only when executed as a script, not when imported.
    main()