Default DataLoader `shuffle=True` for training (#5623)
* Fix shuffle DataLoader argument
* Add shuffle argument
* Disable shuffle when rect
* Cleanup, add rect warning
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Cleanup2
* Cleanup3
Co-authored-by: Glenn Jocher <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- train.py +1 -1
- utils/datasets.py +21 -20
train.py CHANGED
@@ -212,7 +212,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     train_loader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
                                               hyp=hyp, augment=True, cache=opt.cache, rect=opt.rect, rank=LOCAL_RANK,
                                               workers=workers, image_weights=opt.image_weights, quad=opt.quad,
-                                              prefix=colorstr('train: '))
+                                              prefix=colorstr('train: '), shuffle=True)
     mlc = int(np.concatenate(dataset.labels, 0)[:, 0].max())  # max label class
     nb = len(train_loader)  # number of batches
     assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'
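Context for this change (a minimal standalone sketch, not code from this repo): before this PR no `shuffle` argument was passed at the call site, so the training loader fell back to the PyTorch default of `shuffle=False`. With `shuffle=True` and no explicit sampler, a `DataLoader` draws a fresh random permutation of the dataset each time it is iterated, so batch composition changes from epoch to epoch.

import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset of 8 items; the values stand in for image indices.
dataset = TensorDataset(torch.arange(8))

# shuffle=True re-draws the sample order every time the loader is iterated.
loader = DataLoader(dataset, batch_size=4, shuffle=True)

for epoch in range(2):
    order = [int(x) for batch in loader for x in batch[0]]
    print(f'epoch {epoch}: {order}')  # a different permutation each epoch (with high probability)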
utils/datasets.py CHANGED
@@ -22,7 +22,7 @@ import torch
 import torch.nn.functional as F
 import yaml
 from PIL import ExifTags, Image, ImageOps
-from torch.utils.data import Dataset
+from torch.utils.data import DataLoader, Dataset, dataloader, distributed
 from tqdm import tqdm
 
 from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective
@@ -93,13 +93,15 @@ def exif_transpose(image):
 
 
 def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
-                      rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix=''):
+                      rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix='', shuffle=False):
+    if rect and shuffle:
+        LOGGER.warning('WARNING: --rect is incompatible with DataLoader shuffle, setting shuffle=False')
+        shuffle = False
+    with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
         dataset = LoadImagesAndLabels(path, imgsz, batch_size,
-                                      augment=augment,  #
-                                      hyp=hyp,  #
-                                      rect=rect,  # rectangular
+                                      augment=augment,  # augmentation
+                                      hyp=hyp,  # hyperparameters
+                                      rect=rect,  # rectangular batches
                                       cache_images=cache,
                                       single_cls=single_cls,
                                       stride=int(stride),
@@ -109,19 +111,18 @@ def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None,
 
     batch_size = min(batch_size, len(dataset))
    nw = min([os.cpu_count() // WORLD_SIZE, batch_size if batch_size > 1 else 0, workers])  # number of workers
-    sampler =
-    loader =
-class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
+    sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
+    loader = DataLoader if image_weights else InfiniteDataLoader  # only DataLoader allows for attribute updates
+    return loader(dataset,
+                  batch_size=batch_size,
+                  shuffle=shuffle and sampler is None,
+                  num_workers=nw,
+                  sampler=sampler,
+                  pin_memory=True,
+                  collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn), dataset
+
+
+class InfiniteDataLoader(dataloader.DataLoader):
     """ Dataloader that reuses workers
 
     Uses same syntax as vanilla DataLoader
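Two notes on the new `create_dataloader` body, offered as context rather than as part of the diff. First, PyTorch rejects a `DataLoader` built with both `shuffle=True` and an explicit `sampler`, which is why the code passes `shuffle=shuffle and sampler is None` and instead forwards `shuffle` to `DistributedSampler` in the DDP case. Second, a `DistributedSampler` only produces a new permutation when it is told the epoch number, so the training loop is responsible for calling `set_epoch()` each epoch (not shown in this diff). A minimal sketch of the pattern with illustrative names (`make_loader` is hypothetical, not a function in this repo):

from torch.utils.data import DataLoader, distributed


def make_loader(dataset, batch_size, workers, shuffle=True, rank=-1):
    # rank == -1 means single-GPU/CPU training; otherwise each DDP process
    # reads a disjoint shard of the dataset through DistributedSampler.
    sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
    return DataLoader(dataset,
                      batch_size=batch_size,
                      shuffle=shuffle and sampler is None,  # DataLoader rejects shuffle together with a sampler
                      sampler=sampler,
                      num_workers=workers,
                      pin_memory=True)


# In the epoch loop (DDP only): reseed the sampler so each epoch uses a new permutation.
# if rank != -1:
#     loader.sampler.set_epoch(epoch)

The `loader = DataLoader if image_weights else InfiniteDataLoader` choice follows the same logic as its comment: `InfiniteDataLoader` (defined just below) reuses its workers by holding one persistent iterator across epochs, so per-epoch updates to dataset attributes, as the `--image-weights` option makes, would not be picked up, whereas the plain `DataLoader` rebuilds its iterator every epoch and does see them.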