π [Merge] remote-tracking branch 'origin/DDP_BUGS' into Lightning
Browse files- yolo/tools/data_loader.py +33 -3
yolo/tools/data_loader.py
CHANGED
@@ -22,6 +22,34 @@ from yolo.utils.dataset_utils import (
|
|
22 |
from yolo.utils.logger import logger
|
23 |
|
24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
class YoloDataset(Dataset):
|
26 |
def __init__(self, data_cfg: DataConfig, dataset_cfg: DatasetConfig, phase: str = "train2017"):
|
27 |
augment_cfg = data_cfg.data_augment
|
@@ -31,7 +59,8 @@ class YoloDataset(Dataset):
|
|
31 |
transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
|
32 |
self.transform = AugmentationComposer(transforms, self.image_size)
|
33 |
self.transform.get_more_data = self.get_more_data
|
34 |
-
|
|
|
35 |
|
36 |
def load_data(self, dataset_path: Path, phase_name: str):
|
37 |
"""
|
@@ -132,9 +161,10 @@ class YoloDataset(Dataset):
|
|
132 |
|
133 |
def get_data(self, idx):
|
134 |
img_path, bboxes = self.data[idx]
|
|
|
135 |
with Image.open(img_path) as img:
|
136 |
img = img.convert("RGB")
|
137 |
-
return img, bboxes, img_path
|
138 |
|
139 |
def get_more_data(self, num: int = 1):
|
140 |
indices = torch.randint(0, len(self), (num,))
|
@@ -148,7 +178,7 @@ class YoloDataset(Dataset):
|
|
148 |
return img, bboxes, rev_tensor, img_path
|
149 |
|
150 |
def __len__(self) -> int:
|
151 |
-
return len(self.
|
152 |
|
153 |
|
154 |
def collate_fn(batch: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tensor]]:
|
|
|
22 |
from yolo.utils.logger import logger
|
23 |
|
24 |
|
25 |
+
def tensorlize(data):
|
26 |
+
# TODO Move Tensorlize to helper
|
27 |
+
img_paths, bboxes = zip(*data)
|
28 |
+
max_box = max(bbox.size(0) for bbox in bboxes)
|
29 |
+
padded_bbox_list = []
|
30 |
+
for bbox in bboxes:
|
31 |
+
padding = torch.full((max_box, 5), -1, dtype=torch.float32)
|
32 |
+
padding[: bbox.size(0)] = bbox
|
33 |
+
padded_bbox_list.append(padding)
|
34 |
+
bboxes = np.stack(padded_bbox_list)
|
35 |
+
img_paths = np.array(img_paths)
|
36 |
+
return img_paths, bboxes
|
37 |
+
|
38 |
+
|
39 |
+
def tensorlize(data):
|
40 |
+
# TODO Move Tensorlize to helper
|
41 |
+
img_paths, bboxes = zip(*data)
|
42 |
+
max_box = max(bbox.size(0) for bbox in bboxes)
|
43 |
+
padded_bbox_list = []
|
44 |
+
for bbox in bboxes:
|
45 |
+
padding = torch.full((max_box, 5), -1, dtype=torch.float32)
|
46 |
+
padding[: bbox.size(0)] = bbox
|
47 |
+
padded_bbox_list.append(padding)
|
48 |
+
bboxes = np.stack(padded_bbox_list)
|
49 |
+
img_paths = np.array(img_paths)
|
50 |
+
return img_paths, bboxes
|
51 |
+
|
52 |
+
|
53 |
class YoloDataset(Dataset):
|
54 |
def __init__(self, data_cfg: DataConfig, dataset_cfg: DatasetConfig, phase: str = "train2017"):
|
55 |
augment_cfg = data_cfg.data_augment
|
|
|
59 |
transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
|
60 |
self.transform = AugmentationComposer(transforms, self.image_size)
|
61 |
self.transform.get_more_data = self.get_more_data
|
62 |
+
img_paths, bboxes = tensorlize(self.load_data(Path(dataset_cfg.path), phase_name))
|
63 |
+
self.img_paths, self.bboxes = img_paths, bboxes
|
64 |
|
65 |
def load_data(self, dataset_path: Path, phase_name: str):
|
66 |
"""
|
|
|
161 |
|
162 |
def get_data(self, idx):
|
163 |
img_path, bboxes = self.data[idx]
|
164 |
+
valid_mask = bboxes[:, 0] != -1
|
165 |
with Image.open(img_path) as img:
|
166 |
img = img.convert("RGB")
|
167 |
+
return img, torch.from_numpy(bboxes[valid_mask]), img_path
|
168 |
|
169 |
def get_more_data(self, num: int = 1):
|
170 |
indices = torch.randint(0, len(self), (num,))
|
|
|
178 |
return img, bboxes, rev_tensor, img_path
|
179 |
|
180 |
def __len__(self) -> int:
|
181 |
+
return len(self.bboxes)
|
182 |
|
183 |
|
184 |
def collate_fn(batch: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tensor]]:
|