Werner Duvaud glenn-jocher pre-commit-ci[bot] commited on
Commit
09d1703
·
unverified ·
1 Parent(s): 7473f0f

Default DataLoader `shuffle=True` for training (#5623)

Browse files

* Fix shuffle DataLoader argument

* Add shuffle argument

* Disable shuffle when rect

* Cleanup, add rect warning

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Cleanup2

* Cleanup3

Co-authored-by: Glenn Jocher <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

Files changed (2) hide show
  1. train.py +1 -1
  2. utils/datasets.py +21 -20
train.py CHANGED
@@ -212,7 +212,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
212
  train_loader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
213
  hyp=hyp, augment=True, cache=opt.cache, rect=opt.rect, rank=LOCAL_RANK,
214
  workers=workers, image_weights=opt.image_weights, quad=opt.quad,
215
- prefix=colorstr('train: '))
216
  mlc = int(np.concatenate(dataset.labels, 0)[:, 0].max()) # max label class
217
  nb = len(train_loader) # number of batches
218
  assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'
 
212
  train_loader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
213
  hyp=hyp, augment=True, cache=opt.cache, rect=opt.rect, rank=LOCAL_RANK,
214
  workers=workers, image_weights=opt.image_weights, quad=opt.quad,
215
+ prefix=colorstr('train: '), shuffle=True)
216
  mlc = int(np.concatenate(dataset.labels, 0)[:, 0].max()) # max label class
217
  nb = len(train_loader) # number of batches
218
  assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'
utils/datasets.py CHANGED
@@ -22,7 +22,7 @@ import torch
22
  import torch.nn.functional as F
23
  import yaml
24
  from PIL import ExifTags, Image, ImageOps
25
- from torch.utils.data import Dataset
26
  from tqdm import tqdm
27
 
28
  from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective
@@ -93,13 +93,15 @@ def exif_transpose(image):
93
 
94
 
95
  def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
96
- rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix=''):
97
- # Make sure only the first process in DDP process the dataset first, and the following others can use the cache
98
- with torch_distributed_zero_first(rank):
 
 
99
  dataset = LoadImagesAndLabels(path, imgsz, batch_size,
100
- augment=augment, # augment images
101
- hyp=hyp, # augmentation hyperparameters
102
- rect=rect, # rectangular training
103
  cache_images=cache,
104
  single_cls=single_cls,
105
  stride=int(stride),
@@ -109,19 +111,18 @@ def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=Non
109
 
110
  batch_size = min(batch_size, len(dataset))
111
  nw = min([os.cpu_count() // WORLD_SIZE, batch_size if batch_size > 1 else 0, workers]) # number of workers
112
- sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
113
- loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
114
- # Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader()
115
- dataloader = loader(dataset,
116
- batch_size=batch_size,
117
- num_workers=nw,
118
- sampler=sampler,
119
- pin_memory=True,
120
- collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn)
121
- return dataloader, dataset
122
-
123
-
124
- class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
125
  """ Dataloader that reuses workers
126
 
127
  Uses same syntax as vanilla DataLoader
 
22
  import torch.nn.functional as F
23
  import yaml
24
  from PIL import ExifTags, Image, ImageOps
25
+ from torch.utils.data import DataLoader, Dataset, dataloader, distributed
26
  from tqdm import tqdm
27
 
28
  from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective
 
93
 
94
 
95
  def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
96
+ rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix='', shuffle=False):
97
+ if rect and shuffle:
98
+ LOGGER.warning('WARNING: --rect is incompatible with DataLoader shuffle, setting shuffle=False')
99
+ shuffle = False
100
+ with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
101
  dataset = LoadImagesAndLabels(path, imgsz, batch_size,
102
+ augment=augment, # augmentation
103
+ hyp=hyp, # hyperparameters
104
+ rect=rect, # rectangular batches
105
  cache_images=cache,
106
  single_cls=single_cls,
107
  stride=int(stride),
 
111
 
112
  batch_size = min(batch_size, len(dataset))
113
  nw = min([os.cpu_count() // WORLD_SIZE, batch_size if batch_size > 1 else 0, workers]) # number of workers
114
+ sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
115
+ loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
116
+ return loader(dataset,
117
+ batch_size=batch_size,
118
+ shuffle=shuffle and sampler is None,
119
+ num_workers=nw,
120
+ sampler=sampler,
121
+ pin_memory=True,
122
+ collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn), dataset
123
+
124
+
125
+ class InfiniteDataLoader(dataloader.DataLoader):
 
126
  """ Dataloader that reuses workers
127
 
128
  Uses same syntax as vanilla DataLoader