henry000 committed on
Commit
95520e9
·
1 Parent(s): 1d2b161

🗃️ [Update] dataloader, remove YOLOLoader allin1

Browse files
Files changed (1) hide show
  1. yolo/tools/data_loader.py +36 -44
yolo/tools/data_loader.py CHANGED
@@ -143,67 +143,59 @@ class YoloDataset(Dataset):
143
  def __getitem__(self, idx) -> Tuple[Image.Image, Tensor, Tensor, List[str]]:
144
  img, bboxes, img_path = self.get_data(idx)
145
  img, bboxes, rev_tensor = self.transform(img, bboxes)
 
 
146
  return img, bboxes, rev_tensor, img_path
147
 
148
  def __len__(self) -> int:
149
  return len(self.data)
150
 
151
 
152
- class YoloDataLoader(DataLoader):
153
- def __init__(self, data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: str = "train", use_ddp: bool = False):
154
- """Initializes the YoloDataLoader with hydra-config files."""
155
- dataset = YoloDataset(data_cfg, dataset_cfg, task)
156
- sampler = DistributedSampler(dataset, shuffle=data_cfg.shuffle) if use_ddp else None
157
- self.image_size = data_cfg.image_size[0]
158
- super().__init__(
159
- dataset,
160
- batch_size=data_cfg.batch_size,
161
- sampler=sampler,
162
- shuffle=data_cfg.shuffle and not use_ddp,
163
- num_workers=data_cfg.cpu_num,
164
- pin_memory=data_cfg.pin_memory,
165
- collate_fn=self.collate_fn,
166
- )
167
-
168
- def collate_fn(self, batch: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tensor]]:
169
- """
170
- A collate function to handle batching of images and their corresponding targets.
171
 
172
- Args:
173
- batch (list of tuples): Each tuple contains:
174
- - image (Tensor): The image tensor.
175
- - labels (Tensor): The tensor of labels for the image.
176
 
177
- Returns:
178
- Tuple[Tensor, List[Tensor]]: A tuple containing:
179
- - A tensor of batched images.
180
- - A list of tensors, each corresponding to bboxes for each image in the batch.
181
- """
182
- batch_size = len(batch)
183
- target_sizes = [item[1].size(0) for item in batch]
184
- # TODO: Improve readability of this process
185
- # TODO: remove maxBbox or reduce loss function memory usage
186
- batch_targets = torch.zeros(batch_size, min(max(target_sizes), 100), 5)
187
- batch_targets[:, :, 0] = -1
188
- for idx, target_size in enumerate(target_sizes):
189
- batch_targets[idx, : min(target_size, 100)] = batch[idx][1][:100]
190
- batch_targets[:, :, 1:] *= self.image_size
191
 
192
- batch_images, _, batch_reverse, batch_path = zip(*batch)
193
- batch_images = torch.stack(batch_images)
194
- batch_reverse = torch.stack(batch_reverse)
195
 
196
- return batch_size, batch_images, batch_targets, batch_reverse, batch_path
197
 
198
 
199
- def create_dataloader(data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: str = "train", use_ddp: bool = False):
200
  if task == "inference":
201
  return StreamDataLoader(data_cfg)
202
 
203
  if dataset_cfg.auto_download:
204
  prepare_dataset(dataset_cfg, task)
205
-
206
- return YoloDataLoader(data_cfg, dataset_cfg, task, use_ddp)
 
 
 
 
 
 
 
207
 
208
 
209
  class StreamDataLoader:
 
143
  def __getitem__(self, idx) -> Tuple[Image.Image, Tensor, Tensor, List[str]]:
144
  img, bboxes, img_path = self.get_data(idx)
145
  img, bboxes, rev_tensor = self.transform(img, bboxes)
146
+ bboxes[:, [1, 3]] *= self.image_size[0]
147
+ bboxes[:, [2, 4]] *= self.image_size[1]
148
  return img, bboxes, rev_tensor, img_path
149
 
150
  def __len__(self) -> int:
151
  return len(self.data)
152
 
153
 
154
+ def collate_fn(batch: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tensor]]:
155
+ """
156
+ A collate function to handle batching of images and their corresponding targets.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
+ Args:
159
+ batch (list of tuples): Each tuple contains:
160
+ - image (Tensor): The image tensor.
161
+ - labels (Tensor): The tensor of labels for the image.
162
 
163
+ Returns:
164
+ Tuple[Tensor, List[Tensor]]: A tuple containing:
165
+ - A tensor of batched images.
166
+ - A list of tensors, each corresponding to bboxes for each image in the batch.
167
+ """
168
+ batch_size = len(batch)
169
+ target_sizes = [item[1].size(0) for item in batch]
170
+ # TODO: Improve readability of this process
171
+ # TODO: remove maxBbox or reduce loss function memory usage
172
+ batch_targets = torch.zeros(batch_size, min(max(target_sizes), 100), 5)
173
+ batch_targets[:, :, 0] = -1
174
+ for idx, target_size in enumerate(target_sizes):
175
+ batch_targets[idx, : min(target_size, 100)] = batch[idx][1][:100]
 
176
 
177
+ batch_images, _, batch_reverse, batch_path = zip(*batch)
178
+ batch_images = torch.stack(batch_images)
179
+ batch_reverse = torch.stack(batch_reverse)
180
 
181
+ return batch_size, batch_images, batch_targets, batch_reverse, batch_path
182
 
183
 
184
+ def create_dataloader(data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: str = "train"):
185
  if task == "inference":
186
  return StreamDataLoader(data_cfg)
187
 
188
  if dataset_cfg.auto_download:
189
  prepare_dataset(dataset_cfg, task)
190
+ dataset = YoloDataset(data_cfg, dataset_cfg, task)
191
+
192
+ return DataLoader(
193
+ dataset,
194
+ batch_size=data_cfg.batch_size,
195
+ num_workers=data_cfg.cpu_num,
196
+ pin_memory=data_cfg.pin_memory,
197
+ collate_fn=collate_fn,
198
+ )
199
 
200
 
201
  class StreamDataLoader: