🔨 [Update] dataloader, return data augment info
- yolo/tools/data_augmentation.py +23 -18
- yolo/tools/data_loader.py +13 -9
- yolo/tools/solver.py +5 -4
yolo/tools/data_augmentation.py

```diff
@@ -10,7 +10,7 @@ class AugmentationComposer:
     def __init__(self, transforms, image_size: int = [640, 640]):
         self.transforms = transforms
         # TODO: handle List of image_size [640, 640]
-        self.image_size = image_size
+        self.image_size = image_size
         self.pad_resize = PadAndResize(self.image_size)
 
         for transform in self.transforms:
@@ -29,27 +29,32 @@ class AugmentationComposer:
 
 
 class PadAndResize:
-    def __init__(self, image_size):
+    def __init__(self, image_size, background_color=(128, 128, 128)):
         """Initialize the object with the target image size."""
-        self.
+        self.target_width, self.target_height = image_size
+        self.background_color = background_color
 
-    def __call__(self, image, boxes):
-
-        scale = self.
-
-
-
-        square_img.paste(image, (left, top))
+    def __call__(self, image: Image, boxes):
+        img_width, img_height = image.size
+        scale = min(self.target_width / img_width, self.target_height / img_height)
+        new_width, new_height = int(img_width * scale), int(img_height * scale)
+
+        resized_image = image.resize((new_width, new_height), Image.LANCZOS)
 
-
+        pad_left = (self.target_width - new_width) // 2
+        pad_top = (self.target_height - new_height) // 2
+        padded_image = Image.new("RGB", (self.target_width, self.target_height), self.background_color)
+        padded_image.paste(resized_image, (pad_left, pad_top))
 
-        boxes[:, 1]
-        boxes[:, 2]
-        boxes[:, 3]
-        boxes[:, 4]
+        boxes[:, 1] *= scale  # xmin
+        boxes[:, 2] *= scale  # ymin
+        boxes[:, 3] *= scale  # xmax
+        boxes[:, 4] *= scale  # ymax
+        boxes[:, [1, 3]] += pad_left
+        boxes[:, [2, 4]] += pad_top
 
-
-        return
+        transform_info = torch.tensor([scale, pad_left, pad_top, pad_left, pad_top])
+        return padded_image, boxes, transform_info
 
 
 class HorizontalFlip:
@@ -94,7 +99,7 @@ class Mosaic:
 
         assert self.parent is not None, "Parent is not set. Mosaic cannot retrieve image size."
 
-        img_sz = self.parent.image_size  # Assuming `image_size` is defined in parent
+        img_sz = self.parent.image_size[0]  # Assuming `image_size` is defined in parent
         more_data = self.parent.get_more_data(3)  # get 3 more images randomly
 
         data = [(image, boxes)] + more_data
```
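With the updated `PadAndResize`, the augmentation now reports how an image was letterboxed. A quick illustration of what `transform_info` ends up holding, using a made-up 1280×960 input (the sizes and box values are purely illustrative, and the snippet assumes the repository is importable as the `yolo` package):

```python
import torch
from PIL import Image

from yolo.tools.data_augmentation import PadAndResize  # the class changed above

pad_resize = PadAndResize((640, 640))                      # target (width, height)
image = Image.new("RGB", (1280, 960))                      # synthetic input image
boxes = torch.tensor([[0.0, 100.0, 200.0, 300.0, 400.0]])  # [class, xmin, ymin, xmax, ymax]

padded_image, boxes, transform_info = pad_resize(image, boxes)
print(padded_image.size)  # (640, 640)
print(transform_info)     # tensor([0.5, 0., 80., 0., 80.]) -> scale, pad_left, pad_top, pad_left, pad_top
```

This is the tensor that the rest of the commit threads through as `rev_tensor`, so callers can later undo the scale and padding.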
yolo/tools/data_loader.py

```diff
@@ -141,16 +141,16 @@ class YoloDataset(Dataset):
     def get_data(self, idx):
         img_path, bboxes = self.data[idx]
         img = Image.open(img_path).convert("RGB")
-        return img, bboxes
+        return img, bboxes, img_path
 
     def get_more_data(self, num: int = 1):
         indices = torch.randint(0, len(self), (num,))
-        return [self.get_data(idx) for idx in indices]
+        return [self.get_data(idx)[:2] for idx in indices]
 
     def __getitem__(self, idx) -> Union[Image.Image, torch.Tensor]:
-        img, bboxes = self.get_data(idx)
-        img, bboxes,
-        return img, bboxes
+        img, bboxes, img_path = self.get_data(idx)
+        img, bboxes, rev_tensor = self.transform(img, bboxes)
+        return img, bboxes, rev_tensor, img_path
 
     def __len__(self) -> int:
         return len(self.data)
@@ -195,9 +195,11 @@ class YoloDataLoader(DataLoader):
             batch_targets[idx, :target_size] = batch[idx][1]
         batch_targets[:, :, 1:] *= self.image_size
 
-        batch_images
+        batch_images, _, batch_reverse, batch_path = zip(*batch)
+        batch_images = torch.stack(batch_images)
+        batch_reverse = torch.stack(batch_reverse)
 
-        return batch_images, batch_targets
+        return batch_images, batch_targets, batch_reverse, batch_path
 
 
 def create_dataloader(data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: str = "train", use_ddp: bool = False):
@@ -261,12 +263,14 @@ class StreamDataLoader:
         if isinstance(frame, np.ndarray):
             frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
             frame = Image.fromarray(frame)
+        origin_frame = frame
         frame, _, rev_tensor = self.transform(frame, torch.zeros(0, 5))
         frame = frame[None]
+        rev_tensor = rev_tensor[None]
         if not self.is_stream:
-            self.queue.put(frame)
+            self.queue.put((frame, rev_tensor, origin_frame))
         else:
-            self.current_frame = frame
+            self.current_frame = (frame, rev_tensor, origin_frame)
 
     def __iter__(self) -> Generator[Tensor, None, None]:
         return self
```
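The new tail of `collate_fn` zips the extra per-sample fields out of the batch and stacks the tensors. A runnable toy version of that zip/stack pattern, using synthetic stand-ins for the `(img, bboxes, rev_tensor, img_path)` items that `__getitem__` now returns (shapes and values are illustrative only):

```python
import torch

# Two fake dataset items shaped like (img, bboxes, rev_tensor, img_path).
batch = [
    (torch.zeros(3, 640, 640), torch.zeros(2, 5), torch.tensor([0.5, 0.0, 80.0, 0.0, 80.0]), "img_0.jpg"),
    (torch.zeros(3, 640, 640), torch.zeros(1, 5), torch.tensor([1.0, 0.0, 0.0, 0.0, 0.0]), "img_1.jpg"),
]

batch_images, _, batch_reverse, batch_path = zip(*batch)  # same pattern as collate_fn above
batch_images = torch.stack(batch_images)                  # (2, 3, 640, 640)
batch_reverse = torch.stack(batch_reverse)                # (2, 5) rows of [scale, pad_left, pad_top, pad_left, pad_top]
print(batch_images.shape, batch_reverse.shape, batch_path)
```

Batches therefore come back as `(batch_images, batch_targets, batch_reverse, batch_path)`, and the stream path now queues `(frame, rev_tensor, origin_frame)` triples instead of bare frames.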
yolo/tools/solver.py

```diff
@@ -72,7 +72,7 @@ class ModelTrainer:
         self.model.train()
         total_loss = 0
 
-        for images, targets in dataloader:
+        for images, targets, *_ in dataloader:
             loss, loss_each = self.train_one_batch(images, targets)
 
             total_loss += loss
@@ -136,8 +136,9 @@ class ModelTester:
 
         last_time = time.time()
         try:
-            for idx, images in enumerate(dataloader):
+            for idx, (images, rev_tensor, origin_frame) in enumerate(dataloader):
                 images = images.to(self.device)
+                rev_tensor = rev_tensor.to(self.device)
                 with torch.no_grad():
                     predicts = self.model(images)
                     predicts = self.vec2box(predicts["Main"])
@@ -192,8 +193,8 @@ class ModelValidator:
         iou_thresholds = torch.arange(0.5, 1.0, 0.05)
         map_all = []
         self.progress.start_one_epoch(len(dataloader))
-        for images, targets in dataloader:
-            images, targets = images.to(self.device), targets.to(self.device)
+        for images, targets, rev_tensor, img_paths in dataloader:
+            images, targets, rev_tensor = images.to(self.device), targets.to(self.device), rev_tensor.to(self.device)
             with torch.no_grad():
                 predicts = self.model(images)
                 predicts = self.vec2box(predicts["Main"])
```