Nanobit committed on
Commit
1e15aad
·
unverified ·
1 Parent(s): 69ff781

Add InfiniteDataLoader class (#876)

Browse files

* Add InfiniteDataLoader

The dataloader workers are initialized only once, at the first epoch, which saves time on every later epoch.

* Moved class to a better location

Files changed (1) hide show
  1. utils/datasets.py +42 -6
utils/datasets.py CHANGED
@@ -63,15 +63,51 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa
63
  batch_size = min(batch_size, len(dataset))
64
  nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers]) # number of workers
65
  train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
66
- dataloader = torch.utils.data.DataLoader(dataset,
67
- batch_size=batch_size,
68
- num_workers=nw,
69
- sampler=train_sampler,
70
- pin_memory=True,
71
- collate_fn=LoadImagesAndLabels.collate_fn)
72
  return dataloader, dataset
73
 
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  class LoadImages: # for inference
76
  def __init__(self, path, img_size=640):
77
  p = str(Path(path)) # os-agnostic
 
63
  batch_size = min(batch_size, len(dataset))
64
  nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers]) # number of workers
65
  train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
66
+ dataloader = InfiniteDataLoader (dataset,
67
+ batch_size=batch_size,
68
+ num_workers=nw,
69
+ sampler=train_sampler,
70
+ pin_memory=True,
71
+ collate_fn=LoadImagesAndLabels.collate_fn)
72
  return dataloader, dataset
73
 
74
 
75
class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
    """DataLoader that reuses its workers across epochs.

    Accepts exactly the same arguments as the vanilla DataLoader. A single
    underlying iterator is created at construction time and every later
    ``__iter__`` call keeps drawing batches from that same iterator.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        repeating_sampler = _RepeatSampler(self.batch_sampler)
        # object.__setattr__ sidesteps DataLoader's own __setattr__
        # (presumably it guards `batch_sampler` after init -- confirm against
        # the installed PyTorch version).
        object.__setattr__(self, 'batch_sampler', repeating_sampler)
        # Built once; reused by every pass over this loader.
        self.iterator = super().__iter__()

    def __len__(self):
        # `batch_sampler` is the _RepeatSampler wrapper; its `.sampler` is the
        # original batch sampler, whose length is reported here.
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        # Hand out exactly len(self) batches per pass from the shared iterator.
        remaining = len(self)
        while remaining > 0:
            yield next(self.iterator)
            remaining -= 1
93
+
94
+
95
+ class _RepeatSampler(object):
96
+ '''
97
+ Sampler that repeats forever.
98
+
99
+ Args:
100
+ sampler (Sampler)
101
+ '''
102
+
103
+ def __init__(self, sampler):
104
+ self.sampler = sampler
105
+
106
+ def __iter__(self):
107
+ while True:
108
+ yield from iter(self.sampler)
109
+
110
+
111
  class LoadImages: # for inference
112
  def __init__(self, path, img_size=640):
113
  p = str(Path(path)) # os-agnostic