henry000 commited on
Commit
e09ff86
Β·
1 Parent(s): 5be4e0e

πŸš› [Move] Tensorlize out of YoloDataset

Browse files
Files changed (1) hide show
  1. yolo/tools/data_loader.py +18 -14
yolo/tools/data_loader.py CHANGED
@@ -23,6 +23,20 @@ from yolo.utils.dataset_utils import (
23
  )
24
 
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  class YoloDataset(Dataset):
27
  def __init__(self, data_cfg: DataConfig, dataset_cfg: DatasetConfig, phase: str = "train2017"):
28
  augment_cfg = data_cfg.data_augment
@@ -32,19 +46,8 @@ class YoloDataset(Dataset):
32
  transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
33
  self.transform = AugmentationComposer(transforms, self.image_size)
34
  self.transform.get_more_data = self.get_more_data
35
- self.img_paths, self.bboxes = self.tensorlize(self.load_data(Path(dataset_cfg.path), phase_name))
36
-
37
- def tensorlize(self, data):
38
- img_paths, bboxes = zip(*data)
39
- max_box = max(bbox.size(0) for bbox in bboxes)
40
- padded_bbox_list = []
41
- for bbox in bboxes:
42
- padding = torch.full((max_box, 5), -1, dtype=torch.float32)
43
- padding[: bbox.size(0)] = bbox
44
- padded_bbox_list.append(padding)
45
- bboxes = torch.stack(padded_bbox_list)
46
- img_paths = np.array(img_paths)
47
- return img_paths, bboxes
48
 
49
  def load_data(self, dataset_path: Path, phase_name: str):
50
  """
@@ -145,8 +148,9 @@ class YoloDataset(Dataset):
145
 
146
  def get_data(self, idx):
147
  img_path, bboxes = self.img_paths[idx], self.bboxes[idx]
 
148
  img = Image.open(img_path).convert("RGB")
149
- return img, bboxes, img_path
150
 
151
  def get_more_data(self, num: int = 1):
152
  indices = torch.randint(0, len(self), (num,))
 
23
  )
24
 
25
 
26
+ def tensorlize(data):
27
+ # TODO Move Tensorlize to helper
28
+ img_paths, bboxes = zip(*data)
29
+ max_box = max(bbox.size(0) for bbox in bboxes)
30
+ padded_bbox_list = []
31
+ for bbox in bboxes:
32
+ padding = torch.full((max_box, 5), -1, dtype=torch.float32)
33
+ padding[: bbox.size(0)] = bbox
34
+ padded_bbox_list.append(padding)
35
+ bboxes = np.stack(padded_bbox_list)
36
+ img_paths = np.array(img_paths)
37
+ return img_paths, bboxes
38
+
39
+
40
  class YoloDataset(Dataset):
41
  def __init__(self, data_cfg: DataConfig, dataset_cfg: DatasetConfig, phase: str = "train2017"):
42
  augment_cfg = data_cfg.data_augment
 
46
  transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
47
  self.transform = AugmentationComposer(transforms, self.image_size)
48
  self.transform.get_more_data = self.get_more_data
49
+ img_paths, bboxes = tensorlize(self.load_data(Path(dataset_cfg.path), phase_name))
50
+ self.img_paths, self.bboxes = img_paths, bboxes
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  def load_data(self, dataset_path: Path, phase_name: str):
53
  """
 
148
 
149
  def get_data(self, idx):
150
  img_path, bboxes = self.img_paths[idx], self.bboxes[idx]
151
+ valid_mask = bboxes[:, 0] != -1
152
  img = Image.open(img_path).convert("RGB")
153
+ return img, torch.from_numpy(bboxes[valid_mask]), img_path
154
 
155
  def get_more_data(self, num: int = 1):
156
  indices = torch.randint(0, len(self), (num,))