🗃️ [Update] data loader: remove the all-in-one YoloDataLoader class
Browse files- yolo/tools/data_loader.py +36 -44
yolo/tools/data_loader.py
CHANGED
@@ -143,67 +143,59 @@ class YoloDataset(Dataset):
|
|
143 |
def __getitem__(self, idx) -> Tuple[Image.Image, Tensor, Tensor, List[str]]:
|
144 |
img, bboxes, img_path = self.get_data(idx)
|
145 |
img, bboxes, rev_tensor = self.transform(img, bboxes)
|
|
|
|
|
146 |
return img, bboxes, rev_tensor, img_path
|
147 |
|
148 |
def __len__(self) -> int:
|
149 |
return len(self.data)
|
150 |
|
151 |
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
dataset = YoloDataset(data_cfg, dataset_cfg, task)
|
156 |
-
sampler = DistributedSampler(dataset, shuffle=data_cfg.shuffle) if use_ddp else None
|
157 |
-
self.image_size = data_cfg.image_size[0]
|
158 |
-
super().__init__(
|
159 |
-
dataset,
|
160 |
-
batch_size=data_cfg.batch_size,
|
161 |
-
sampler=sampler,
|
162 |
-
shuffle=data_cfg.shuffle and not use_ddp,
|
163 |
-
num_workers=data_cfg.cpu_num,
|
164 |
-
pin_memory=data_cfg.pin_memory,
|
165 |
-
collate_fn=self.collate_fn,
|
166 |
-
)
|
167 |
-
|
168 |
-
def collate_fn(self, batch: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tensor]]:
|
169 |
-
"""
|
170 |
-
A collate function to handle batching of images and their corresponding targets.
|
171 |
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
batch_targets[:, :, 1:] *= self.image_size
|
191 |
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
|
196 |
-
|
197 |
|
198 |
|
199 |
-
def create_dataloader(data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: str = "train"
|
200 |
if task == "inference":
|
201 |
return StreamDataLoader(data_cfg)
|
202 |
|
203 |
if dataset_cfg.auto_download:
|
204 |
prepare_dataset(dataset_cfg, task)
|
205 |
-
|
206 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
207 |
|
208 |
|
209 |
class StreamDataLoader:
|
|
|
143 |
def __getitem__(self, idx) -> Tuple[Image.Image, Tensor, Tensor, List[str]]:
|
144 |
img, bboxes, img_path = self.get_data(idx)
|
145 |
img, bboxes, rev_tensor = self.transform(img, bboxes)
|
146 |
+
bboxes[:, [1, 3]] *= self.image_size[0]
|
147 |
+
bboxes[:, [2, 4]] *= self.image_size[1]
|
148 |
return img, bboxes, rev_tensor, img_path
|
149 |
|
150 |
def __len__(self) -> int:
|
151 |
return len(self.data)
|
152 |
|
153 |
|
154 |
+
def collate_fn(batch: List[Tuple[Tensor, Tensor]]) -> Tuple[int, Tensor, Tensor, Tensor, Tuple[str, ...]]:
    """
    A collate function to handle batching of images and their corresponding targets.

    Args:
        batch (list of tuples): Each tuple contains:
            - image (Tensor): The image tensor.
            - labels (Tensor): The (n_boxes, 5) tensor of labels for the image.
            - rev_tensor (Tensor): The reverse-transform tensor for the image.
            - path (str): The path the image was loaded from.

    Returns:
        Tuple[int, Tensor, Tensor, Tensor, Tuple[str, ...]]: A tuple containing:
            - The batch size.
            - A tensor of batched images.
            - A (batch, padded_len, 5) tensor of targets, padded per image;
              padding rows are marked with class id -1.
            - A tensor of batched reverse-transform tensors.
            - The image paths.
    """
    # TODO: remove the max-bbox cap or reduce loss function memory usage
    max_boxes = 100  # hard cap on targets kept per image to bound memory
    batch_size = len(batch)
    target_sizes = [item[1].size(0) for item in batch]
    # Pad every image's targets to the largest count in the batch, capped.
    padded_len = min(max(target_sizes), max_boxes)
    batch_targets = torch.zeros(batch_size, padded_len, 5)
    batch_targets[:, :, 0] = -1  # class id -1 marks padding rows
    for idx, target_size in enumerate(target_sizes):
        batch_targets[idx, : min(target_size, max_boxes)] = batch[idx][1][:max_boxes]

    batch_images, _, batch_reverse, batch_path = zip(*batch)
    batch_images = torch.stack(batch_images)
    batch_reverse = torch.stack(batch_reverse)

    return batch_size, batch_images, batch_targets, batch_reverse, batch_path
|
182 |
|
183 |
|
184 |
+
def create_dataloader(data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: str = "train"):
    """
    Build the data loader for *task*.

    Returns a StreamDataLoader for inference; otherwise a DataLoader over a
    YoloDataset, downloading the dataset first when auto_download is set.
    """
    if task == "inference":
        return StreamDataLoader(data_cfg)

    if dataset_cfg.auto_download:
        prepare_dataset(dataset_cfg, task)

    dataset = YoloDataset(data_cfg, dataset_cfg, task)
    loader_options = {
        "batch_size": data_cfg.batch_size,
        "num_workers": data_cfg.cpu_num,
        "pin_memory": data_cfg.pin_memory,
        "collate_fn": collate_fn,
    }
    return DataLoader(dataset, **loader_options)
|
199 |
|
200 |
|
201 |
class StreamDataLoader:
|