Commit
·
12499f1
1
Parent(s):
9728e2b
--image_weights bug fix (#1524)
Browse files- utils/datasets.py +8 -6
utils/datasets.py
CHANGED
@@ -72,12 +72,14 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa
|
|
72 |
batch_size = min(batch_size, len(dataset))
|
73 |
nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers]) # number of workers
|
74 |
sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
|
|
|
|
81 |
return dataloader, dataset
|
82 |
|
83 |
|
|
|
72 |
batch_size = min(batch_size, len(dataset))
|
73 |
nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers]) # number of workers
|
74 |
sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
|
75 |
+
loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
|
76 |
+
# Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader()
|
77 |
+
dataloader = loader(dataset,
|
78 |
+
batch_size=batch_size,
|
79 |
+
num_workers=nw,
|
80 |
+
sampler=sampler,
|
81 |
+
pin_memory=True,
|
82 |
+
collate_fn=LoadImagesAndLabels.collate_fn)
|
83 |
return dataloader, dataset
|
84 |
|
85 |
|