glenn-jocher commited on
Commit
d49c52e
·
1 Parent(s): bb8872e

_RepeatSampler outside of InfiniteDataLoader

Browse files
Files changed (1) hide show
  1. utils/datasets.py +13 -12
utils/datasets.py CHANGED
@@ -68,7 +68,7 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa
68
  num_workers=nw,
69
  sampler=sampler,
70
  pin_memory=True,
71
- collate_fn=LoadImagesAndLabels.collate_fn)
72
  return dataloader, dataset
73
 
74
 
@@ -80,7 +80,7 @@ class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
80
 
81
  def __init__(self, *args, **kwargs):
82
  super().__init__(*args, **kwargs)
83
- object.__setattr__(self, 'batch_sampler', self._RepeatSampler(self.batch_sampler))
84
  self.iterator = super().__iter__()
85
 
86
  def __len__(self):
@@ -90,19 +90,20 @@ class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
90
  for i in range(len(self)):
91
  yield next(self.iterator)
92
 
93
- class _RepeatSampler(object):
94
- """ Sampler that repeats forever.
95
 
96
- Args:
97
- sampler (Sampler)
98
- """
99
 
100
- def __init__(self, sampler):
101
- self.sampler = sampler
 
102
 
103
- def __iter__(self):
104
- while True:
105
- yield from iter(self.sampler)
 
 
 
106
 
107
 
108
  class LoadImages: # for inference
 
68
  num_workers=nw,
69
  sampler=sampler,
70
  pin_memory=True,
71
+ collate_fn=LoadImagesAndLabels.collate_fn) # torch.utils.data.DataLoader()
72
  return dataloader, dataset
73
 
74
 
 
80
 
81
  def __init__(self, *args, **kwargs):
82
  super().__init__(*args, **kwargs)
83
+ object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
84
  self.iterator = super().__iter__()
85
 
86
  def __len__(self):
 
90
  for i in range(len(self)):
91
  yield next(self.iterator)
92
 
 
 
93
 
94
+ class _RepeatSampler(object):
95
+ """ Sampler that repeats forever.
 
96
 
97
+ Args:
98
+ sampler (Sampler)
99
+ """
100
 
101
+ def __init__(self, sampler):
102
+ self.sampler = sampler
103
+
104
+ def __iter__(self):
105
+ while True:
106
+ yield from iter(self.sampler)
107
 
108
 
109
  class LoadImages: # for inference