Create `dataset_stats()` for HUB
Browse files- utils/datasets.py +34 -2
utils/datasets.py
CHANGED
@@ -17,12 +17,13 @@ import cv2
|
|
17 |
import numpy as np
|
18 |
import torch
|
19 |
import torch.nn.functional as F
|
|
|
20 |
from PIL import Image, ExifTags
|
21 |
from torch.utils.data import Dataset
|
22 |
from tqdm import tqdm
|
23 |
|
24 |
-
from utils.general import check_requirements, xyxy2xywh, xywh2xyxy, xywhn2xyxy, xyn2xy,
|
25 |
-
resample_segments, clean_str
|
26 |
from utils.torch_utils import torch_distributed_zero_first
|
27 |
|
28 |
# Parameters
|
@@ -1083,3 +1084,34 @@ def verify_image_label(params):
|
|
1083 |
nc = 1
|
1084 |
logging.info(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')
|
1085 |
return [None] * 4 + [nm, nf, ne, nc]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
import numpy as np
|
18 |
import torch
|
19 |
import torch.nn.functional as F
|
20 |
+
import yaml
|
21 |
from PIL import Image, ExifTags
|
22 |
from torch.utils.data import Dataset
|
23 |
from tqdm import tqdm
|
24 |
|
25 |
+
from utils.general import check_requirements, check_file, check_dataset, xyxy2xywh, xywh2xyxy, xywhn2xyxy, xyn2xy, \
|
26 |
+
segment2box, segments2boxes, resample_segments, clean_str
|
27 |
from utils.torch_utils import torch_distributed_zero_first
|
28 |
|
29 |
# Parameters
|
|
|
1084 |
nc = 1
|
1085 |
logging.info(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')
|
1086 |
return [None] * 4 + [nm, nf, ne, nc]
|
1087 |
+
|
1088 |
+
|
1089 |
+
def dataset_stats(path='data/coco128.yaml', verbose=False):
|
1090 |
+
""" Return dataset statistics dictionary with images and instances counts per split per class
|
1091 |
+
Usage: from utils.datasets import *; dataset_stats('data/coco128.yaml')
|
1092 |
+
Arguments
|
1093 |
+
path: Path to data.yaml
|
1094 |
+
verbose: Print stats dictionary
|
1095 |
+
"""
|
1096 |
+
path = check_file(Path(path))
|
1097 |
+
with open(path) as f:
|
1098 |
+
data = yaml.safe_load(f) # data dict
|
1099 |
+
check_dataset(data) # download dataset if missing
|
1100 |
+
|
1101 |
+
nc = data['nc'] # number of classes
|
1102 |
+
stats = {'nc': nc, 'names': data['names']} # statistics dictionary
|
1103 |
+
for split in 'train', 'val', 'test':
|
1104 |
+
if split not in data:
|
1105 |
+
stats[split] = None # i.e. no test set
|
1106 |
+
continue
|
1107 |
+
x = []
|
1108 |
+
dataset = LoadImagesAndLabels(data[split], augment=False, rect=True) # load dataset
|
1109 |
+
for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics'):
|
1110 |
+
x.append(np.bincount(label[:, 0].astype(int), minlength=nc))
|
1111 |
+
x = np.array(x) # shape(128x80)
|
1112 |
+
stats[split] = {'instances': {'total': int(x.sum()), 'per_class': x.sum(0).tolist()},
|
1113 |
+
'images': {'total': dataset.n, 'unlabelled': int(np.all(x == 0, 1).sum()),
|
1114 |
+
'per_class': (x > 0).sum(0).tolist()}}
|
1115 |
+
if verbose:
|
1116 |
+
print(yaml.dump([stats], sort_keys=False, default_flow_style=False))
|
1117 |
+
return stats
|