glenn-jocher commited on
Commit
b6fdd2e
·
unverified ·
1 Parent(s): ac8691e

Create `dataset_stats()` for HUB

Browse files
Files changed (1) hide show
  1. utils/datasets.py +34 -2
utils/datasets.py CHANGED
@@ -17,12 +17,13 @@ import cv2
17
  import numpy as np
18
  import torch
19
  import torch.nn.functional as F
 
20
  from PIL import Image, ExifTags
21
  from torch.utils.data import Dataset
22
  from tqdm import tqdm
23
 
24
- from utils.general import check_requirements, xyxy2xywh, xywh2xyxy, xywhn2xyxy, xyn2xy, segment2box, segments2boxes, \
25
- resample_segments, clean_str
26
  from utils.torch_utils import torch_distributed_zero_first
27
 
28
  # Parameters
@@ -1083,3 +1084,34 @@ def verify_image_label(params):
1083
  nc = 1
1084
  logging.info(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')
1085
  return [None] * 4 + [nm, nf, ne, nc]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  import numpy as np
18
  import torch
19
  import torch.nn.functional as F
20
+ import yaml
21
  from PIL import Image, ExifTags
22
  from torch.utils.data import Dataset
23
  from tqdm import tqdm
24
 
25
+ from utils.general import check_requirements, check_file, check_dataset, xyxy2xywh, xywh2xyxy, xywhn2xyxy, xyn2xy, \
26
+ segment2box, segments2boxes, resample_segments, clean_str
27
  from utils.torch_utils import torch_distributed_zero_first
28
 
29
  # Parameters
 
1084
  nc = 1
1085
  logging.info(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')
1086
  return [None] * 4 + [nm, nf, ne, nc]
1087
+
1088
+
1089
+ def dataset_stats(path='data/coco128.yaml', verbose=False):
1090
+ """ Return dataset statistics dictionary with images and instances counts per split per class
1091
+ Usage: from utils.datasets import *; dataset_stats('data/coco128.yaml')
1092
+ Arguments
1093
+ path: Path to data.yaml
1094
+ verbose: Print stats dictionary
1095
+ """
1096
+ path = check_file(Path(path))
1097
+ with open(path) as f:
1098
+ data = yaml.safe_load(f) # data dict
1099
+ check_dataset(data) # download dataset if missing
1100
+
1101
+ nc = data['nc'] # number of classes
1102
+ stats = {'nc': nc, 'names': data['names']} # statistics dictionary
1103
+ for split in 'train', 'val', 'test':
1104
+ if split not in data:
1105
+ stats[split] = None # i.e. no test set
1106
+ continue
1107
+ x = []
1108
+ dataset = LoadImagesAndLabels(data[split], augment=False, rect=True) # load dataset
1109
+ for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics'):
1110
+ x.append(np.bincount(label[:, 0].astype(int), minlength=nc))
1111
+ x = np.array(x) # shape(128x80)
1112
+ stats[split] = {'instances': {'total': int(x.sum()), 'per_class': x.sum(0).tolist()},
1113
+ 'images': {'total': dataset.n, 'unlabelled': int(np.all(x == 0, 1).sum()),
1114
+ 'per_class': (x > 0).sum(0).tolist()}}
1115
+ if verbose:
1116
+ print(yaml.dump([stats], sort_keys=False, default_flow_style=False))
1117
+ return stats