glenn-jocher committed on
Commit
22fb2b0
·
1 Parent(s): 97b5186

refactor dataloader

Browse files
Files changed (3) hide show
  1. test.py +3 -19
  2. train.py +6 -29
  3. utils/datasets.py +21 -1
test.py CHANGED
@@ -1,8 +1,6 @@
1
  import argparse
2
  import json
3
 
4
- from torch.utils.data import DataLoader
5
-
6
  from utils import google_utils
7
  from utils.datasets import *
8
  from utils.utils import *
@@ -56,30 +54,16 @@ def test(data,
56
  data = yaml.load(f, Loader=yaml.FullLoader) # model dict
57
  nc = 1 if single_cls else int(data['nc']) # number of classes
58
  iouv = torch.linspace(0.5, 0.95, 10).to(device) # iou vector for mAP@0.5:0.95
59
- # iouv = iouv[0].view(1) # comment for mAP@0.5:0.95
60
  niou = iouv.numel()
61
 
62
  # Dataloader
63
  if dataloader is None: # not training
 
64
  img = torch.zeros((1, 3, imgsz, imgsz), device=device) # init img
65
  _ = model(img.half() if half else img) if device.type != 'cpu' else None # run once
66
-
67
- merge = opt.merge # use Merge NMS
68
  path = data['test'] if opt.task == 'test' else data['val'] # path to val/test images
69
- dataset = LoadImagesAndLabels(path,
70
- imgsz,
71
- batch_size,
72
- rect=True, # rectangular inference
73
- single_cls=opt.single_cls, # single class mode
74
- stride=int(max(model.stride)), # model stride
75
- pad=0.5) # padding
76
- batch_size = min(batch_size, len(dataset))
77
- nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
78
- dataloader = DataLoader(dataset,
79
- batch_size=batch_size,
80
- num_workers=nw,
81
- pin_memory=True,
82
- collate_fn=dataset.collate_fn)
83
 
84
  seen = 0
85
  names = model.names if hasattr(model, 'names') else model.module.names
 
1
  import argparse
2
  import json
3
 
 
 
4
  from utils import google_utils
5
  from utils.datasets import *
6
  from utils.utils import *
 
54
  data = yaml.load(f, Loader=yaml.FullLoader) # model dict
55
  nc = 1 if single_cls else int(data['nc']) # number of classes
56
  iouv = torch.linspace(0.5, 0.95, 10).to(device) # iou vector for mAP@0.5:0.95
 
57
  niou = iouv.numel()
58
 
59
  # Dataloader
60
  if dataloader is None: # not training
61
+ merge = opt.merge # use Merge NMS
62
  img = torch.zeros((1, 3, imgsz, imgsz), device=device) # init img
63
  _ = model(img.half() if half else img) if device.type != 'cpu' else None # run once
 
 
64
  path = data['test'] if opt.task == 'test' else data['val'] # path to val/test images
65
+ dataloader = create_dataloader(path, imgsz, batch_size, int(max(model.stride)), opt,
66
+ hyp=None, augment=False, cache=False, pad=0.5, rect=True)[0]
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
  seen = 0
69
  names = model.names if hasattr(model, 'names') else model.module.names
train.py CHANGED
@@ -155,38 +155,15 @@ def train(hyp):
155
  model = torch.nn.parallel.DistributedDataParallel(model)
156
  # pip install torch==1.4.0+cu100 torchvision==0.5.0+cu100 -f https://download.pytorch.org/whl/torch_stable.html
157
 
158
- # Dataset
159
- dataset = LoadImagesAndLabels(train_path, imgsz, batch_size,
160
- augment=True,
161
- hyp=hyp, # augmentation hyperparameters
162
- rect=opt.rect, # rectangular training
163
- cache_images=opt.cache_images,
164
- single_cls=opt.single_cls,
165
- stride=gs)
166
  mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
167
  assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Correct your labels or your model.' % (mlc, nc, opt.cfg)
168
 
169
- # Dataloader
170
- batch_size = min(batch_size, len(dataset))
171
- nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
172
- dataloader = torch.utils.data.DataLoader(dataset,
173
- batch_size=batch_size,
174
- num_workers=nw,
175
- shuffle=not opt.rect, # Shuffle=True unless rectangular training is used
176
- pin_memory=True,
177
- collate_fn=dataset.collate_fn)
178
-
179
  # Testloader
180
- testloader = torch.utils.data.DataLoader(LoadImagesAndLabels(test_path, imgsz_test, batch_size,
181
- hyp=hyp,
182
- rect=True,
183
- cache_images=opt.cache_images,
184
- single_cls=opt.single_cls,
185
- stride=gs),
186
- batch_size=batch_size,
187
- num_workers=nw,
188
- pin_memory=True,
189
- collate_fn=dataset.collate_fn)
190
 
191
  # Model parameters
192
  hyp['cls'] *= nc / 80. # scale coco-tuned hyp['cls'] to current dataset
@@ -218,7 +195,7 @@ def train(hyp):
218
  maps = np.zeros(nc) # mAP per class
219
  results = (0, 0, 0, 0, 0, 0, 0) # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
220
  print('Image sizes %g train, %g test' % (imgsz, imgsz_test))
221
- print('Using %g dataloader workers' % nw)
222
  print('Starting training for %g epochs...' % epochs)
223
  # torch.autograd.set_detect_anomaly(True)
224
  for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
 
155
  model = torch.nn.parallel.DistributedDataParallel(model)
156
  # pip install torch==1.4.0+cu100 torchvision==0.5.0+cu100 -f https://download.pytorch.org/whl/torch_stable.html
157
 
158
+ # Trainloader
159
+ dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
160
+ hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect)
 
 
 
 
 
161
  mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
162
  assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Correct your labels or your model.' % (mlc, nc, opt.cfg)
163
 
 
 
 
 
 
 
 
 
 
 
164
  # Testloader
165
+ testloader = create_dataloader(test_path, imgsz_test, batch_size, gs, opt,
166
+ hyp=hyp, augment=False, cache=opt.cache_images, rect=True)[0]
 
 
 
 
 
 
 
 
167
 
168
  # Model parameters
169
  hyp['cls'] *= nc / 80. # scale coco-tuned hyp['cls'] to current dataset
 
195
  maps = np.zeros(nc) # mAP per class
196
  results = (0, 0, 0, 0, 0, 0, 0) # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
197
  print('Image sizes %g train, %g test' % (imgsz, imgsz_test))
198
+ print('Using %g dataloader workers' % dataloader.num_workers)
199
  print('Starting training for %g epochs...' % epochs)
200
  # torch.autograd.set_detect_anomaly(True)
201
  for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
utils/datasets.py CHANGED
@@ -41,6 +41,26 @@ def exif_size(img):
41
  return s
42
 
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  class LoadImages: # for inference
45
  def __init__(self, path, img_size=416):
46
  path = str(Path(path)) # os-agnostic
@@ -712,7 +732,7 @@ def random_affine(img, targets=(), degrees=10, translate=.1, scale=.1, shear=10,
712
  area = w * h
713
  area0 = (targets[:, 3] - targets[:, 1]) * (targets[:, 4] - targets[:, 2])
714
  ar = np.maximum(w / (h + 1e-16), h / (w + 1e-16)) # aspect ratio
715
- i = (w > 4) & (h > 4) & (area / (area0 * s + 1e-16) > 0.2) & (ar < 10)
716
 
717
  targets = targets[i]
718
  targets[:, 1:5] = xy[i]
 
41
  return s
42
 
43
 
44
+ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False):
45
+ dataset = LoadImagesAndLabels(path, imgsz, batch_size,
46
+ augment=augment, # augment images
47
+ hyp=hyp, # augmentation hyperparameters
48
+ rect=rect, # rectangular training
49
+ cache_images=cache,
50
+ single_cls=opt.single_cls,
51
+ stride=stride,
52
+ pad=pad)
53
+
54
+ batch_size = min(batch_size, len(dataset))
55
+ nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
56
+ dataloader = torch.utils.data.DataLoader(dataset,
57
+ batch_size=batch_size,
58
+ num_workers=nw,
59
+ pin_memory=True,
60
+ collate_fn=LoadImagesAndLabels.collate_fn)
61
+ return dataloader, dataset
62
+
63
+
64
  class LoadImages: # for inference
65
  def __init__(self, path, img_size=416):
66
  path = str(Path(path)) # os-agnostic
 
732
  area = w * h
733
  area0 = (targets[:, 3] - targets[:, 1]) * (targets[:, 4] - targets[:, 2])
734
  ar = np.maximum(w / (h + 1e-16), h / (w + 1e-16)) # aspect ratio
735
+ i = (w > 2) & (h > 2) & (area / (area0 * s + 1e-16) > 0.2) & (ar < 20)
736
 
737
  targets = targets[i]
738
  targets[:, 1:5] = xy[i]