henry000 commited on
Commit
f789a66
Β·
2 Parent(s): 079a11c b9867bb

πŸ”€ [Merge] remote-tracking branch 'origin/DDP_BUGS' into Lightning

Browse files
Files changed (1) hide show
  1. yolo/tools/data_loader.py +33 -3
yolo/tools/data_loader.py CHANGED
@@ -22,6 +22,34 @@ from yolo.utils.dataset_utils import (
22
  from yolo.utils.logger import logger
23
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  class YoloDataset(Dataset):
26
  def __init__(self, data_cfg: DataConfig, dataset_cfg: DatasetConfig, phase: str = "train2017"):
27
  augment_cfg = data_cfg.data_augment
@@ -31,7 +59,8 @@ class YoloDataset(Dataset):
31
  transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
32
  self.transform = AugmentationComposer(transforms, self.image_size)
33
  self.transform.get_more_data = self.get_more_data
34
- self.data = self.load_data(Path(dataset_cfg.path), phase_name)
 
35
 
36
  def load_data(self, dataset_path: Path, phase_name: str):
37
  """
@@ -132,9 +161,10 @@ class YoloDataset(Dataset):
132
 
133
  def get_data(self, idx):
134
  img_path, bboxes = self.data[idx]
 
135
  with Image.open(img_path) as img:
136
  img = img.convert("RGB")
137
- return img, bboxes, img_path
138
 
139
  def get_more_data(self, num: int = 1):
140
  indices = torch.randint(0, len(self), (num,))
@@ -148,7 +178,7 @@ class YoloDataset(Dataset):
148
  return img, bboxes, rev_tensor, img_path
149
 
150
  def __len__(self) -> int:
151
- return len(self.data)
152
 
153
 
154
  def collate_fn(batch: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tensor]]:
 
22
  from yolo.utils.logger import logger
23
 
24
 
25
+ def tensorlize(data):
26
+ # TODO Move Tensorlize to helper
27
+ img_paths, bboxes = zip(*data)
28
+ max_box = max(bbox.size(0) for bbox in bboxes)
29
+ padded_bbox_list = []
30
+ for bbox in bboxes:
31
+ padding = torch.full((max_box, 5), -1, dtype=torch.float32)
32
+ padding[: bbox.size(0)] = bbox
33
+ padded_bbox_list.append(padding)
34
+ bboxes = np.stack(padded_bbox_list)
35
+ img_paths = np.array(img_paths)
36
+ return img_paths, bboxes
37
+
38
+
39
+ def tensorlize(data):
40
+ # TODO Move Tensorlize to helper
41
+ img_paths, bboxes = zip(*data)
42
+ max_box = max(bbox.size(0) for bbox in bboxes)
43
+ padded_bbox_list = []
44
+ for bbox in bboxes:
45
+ padding = torch.full((max_box, 5), -1, dtype=torch.float32)
46
+ padding[: bbox.size(0)] = bbox
47
+ padded_bbox_list.append(padding)
48
+ bboxes = np.stack(padded_bbox_list)
49
+ img_paths = np.array(img_paths)
50
+ return img_paths, bboxes
51
+
52
+
53
  class YoloDataset(Dataset):
54
  def __init__(self, data_cfg: DataConfig, dataset_cfg: DatasetConfig, phase: str = "train2017"):
55
  augment_cfg = data_cfg.data_augment
 
59
  transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
60
  self.transform = AugmentationComposer(transforms, self.image_size)
61
  self.transform.get_more_data = self.get_more_data
62
+ img_paths, bboxes = tensorlize(self.load_data(Path(dataset_cfg.path), phase_name))
63
+ self.img_paths, self.bboxes = img_paths, bboxes
64
 
65
  def load_data(self, dataset_path: Path, phase_name: str):
66
  """
 
161
 
162
  def get_data(self, idx):
163
  img_path, bboxes = self.data[idx]
164
+ valid_mask = bboxes[:, 0] != -1
165
  with Image.open(img_path) as img:
166
  img = img.convert("RGB")
167
+ return img, torch.from_numpy(bboxes[valid_mask]), img_path
168
 
169
  def get_more_data(self, num: int = 1):
170
  indices = torch.randint(0, len(self), (num,))
 
178
  return img, bboxes, rev_tensor, img_path
179
 
180
  def __len__(self) -> int:
181
+ return len(self.bboxes)
182
 
183
 
184
  def collate_fn(batch: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tensor]]: