henry000 commited on
Commit
70a1547
Β·
1 Parent(s): a44a5c1

πŸ› [Add] Tensorlize @ dataloader

Browse files
Files changed (1) hide show
  1. yolo/tools/data_loader.py +15 -3
yolo/tools/data_loader.py CHANGED
@@ -32,7 +32,19 @@ class YoloDataset(Dataset):
32
  transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
33
  self.transform = AugmentationComposer(transforms, self.image_size)
34
  self.transform.get_more_data = self.get_more_data
35
- self.data = self.load_data(Path(dataset_cfg.path), phase_name)
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  def load_data(self, dataset_path: Path, phase_name: str):
38
  """
@@ -132,7 +144,7 @@ class YoloDataset(Dataset):
132
  return torch.zeros((0, 5))
133
 
134
  def get_data(self, idx):
135
- img_path, bboxes = self.data[idx]
136
  img = Image.open(img_path).convert("RGB")
137
  return img, bboxes, img_path
138
 
@@ -146,7 +158,7 @@ class YoloDataset(Dataset):
146
  return img, bboxes, rev_tensor, img_path
147
 
148
  def __len__(self) -> int:
149
- return len(self.data)
150
 
151
 
152
  class YoloDataLoader(DataLoader):
 
32
  transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
33
  self.transform = AugmentationComposer(transforms, self.image_size)
34
  self.transform.get_more_data = self.get_more_data
35
+ self.img_paths, self.bboxes = self.tensorlize(self.load_data(Path(dataset_cfg.path), phase_name))
36
+
37
+ def tensorlize(self, data):
38
+ img_paths, bboxes = zip(*data)
39
+ max_box = max(bbox.size(0) for bbox in bboxes)
40
+ padded_bbox_list = []
41
+ for bbox in bboxes:
42
+ padding = torch.full((max_box, 5), -1, dtype=torch.float32)
43
+ padding[: bbox.size(0)] = bbox
44
+ padded_bbox_list.append(padding)
45
+ bboxes = torch.stack(padded_bbox_list)
46
+ img_paths = np.array(img_paths)
47
+ return img_paths, bboxes
48
 
49
  def load_data(self, dataset_path: Path, phase_name: str):
50
  """
 
144
  return torch.zeros((0, 5))
145
 
146
  def get_data(self, idx):
147
+ img_path, bboxes = self.img_paths[idx], self.bboxes[idx]
148
  img = Image.open(img_path).convert("RGB")
149
  return img, bboxes, img_path
150
 
 
158
  return img, bboxes, rev_tensor, img_path
159
 
160
  def __len__(self) -> int:
161
+ return len(self.bboxes)
162
 
163
 
164
  class YoloDataLoader(DataLoader):