henry000 committed
Commit 8b1b21f · 1 Parent(s): 2b2044d

🔨 [Update] dataloader, return data augment info

yolo/tools/data_augmentation.py CHANGED

@@ -10,7 +10,7 @@ class AugmentationComposer:
     def __init__(self, transforms, image_size: int = [640, 640]):
         self.transforms = transforms
         # TODO: handle List of image_size [640, 640]
-        self.image_size = image_size[0]
+        self.image_size = image_size
         self.pad_resize = PadAndResize(self.image_size)
 
         for transform in self.transforms:
@@ -29,27 +29,32 @@ class AugmentationComposer:
 
 
 class PadAndResize:
-    def __init__(self, image_size):
+    def __init__(self, image_size, background_color=(128, 128, 128)):
         """Initialize the object with the target image size."""
-        self.image_size = image_size
+        self.target_width, self.target_height = image_size
+        self.background_color = background_color
 
-    def __call__(self, image, boxes):
-        original_size = max(image.size)
-        scale = self.image_size / original_size
-        square_img = Image.new("RGB", (original_size, original_size), (128, 128, 128))
-        left = (original_size - image.width) // 2
-        top = (original_size - image.height) // 2
-        square_img.paste(image, (left, top))
-
-        resized_img = square_img.resize((self.image_size, self.image_size))
-
-        boxes[:, 1] = (boxes[:, 1] * image.width + left) / self.image_size * scale
-        boxes[:, 2] = (boxes[:, 2] * image.height + top) / self.image_size * scale
-        boxes[:, 3] = (boxes[:, 3] * image.width + left) / self.image_size * scale
-        boxes[:, 4] = (boxes[:, 4] * image.height + top) / self.image_size * scale
-
-        rev_tensor = torch.tensor([scale, left, top, left, top])
-        return resized_img, boxes, rev_tensor
+    def __call__(self, image: Image, boxes):
+        img_width, img_height = image.size
+        scale = min(self.target_width / img_width, self.target_height / img_height)
+        new_width, new_height = int(img_width * scale), int(img_height * scale)
+
+        resized_image = image.resize((new_width, new_height), Image.LANCZOS)
+
+        pad_left = (self.target_width - new_width) // 2
+        pad_top = (self.target_height - new_height) // 2
+        padded_image = Image.new("RGB", (self.target_width, self.target_height), self.background_color)
+        padded_image.paste(resized_image, (pad_left, pad_top))
+
+        boxes[:, 1] *= scale  # xmin
+        boxes[:, 2] *= scale  # ymin
+        boxes[:, 3] *= scale  # xmax
+        boxes[:, 4] *= scale  # ymax
+        boxes[:, [1, 3]] += pad_left
+        boxes[:, [2, 4]] += pad_top
+
+        transform_info = torch.tensor([scale, pad_left, pad_top, pad_left, pad_top])
+        return padded_image, boxes, transform_info
 
 
 class HorizontalFlip:
@@ -94,7 +99,7 @@ class Mosaic:
 
         assert self.parent is not None, "Parent is not set. Mosaic cannot retrieve image size."
 
-        img_sz = self.parent.image_size  # Assuming `image_size` is defined in parent
+        img_sz = self.parent.image_size[0]  # Assuming `image_size` is defined in parent
         more_data = self.parent.get_more_data(3)  # get 3 more images randomly
 
         data = [(image, boxes)] + more_data
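Note: the rewritten PadAndResize letterboxes rather than square-padding at the original resolution. It scales the image to fit the target size, pastes it onto a background-colored canvas, and records the forward transform in transform_info. A minimal sketch of the new behavior, assuming the class is importable from the file above (the 1280x720 input and the box values are made up for illustration):

import torch
from PIL import Image

from yolo.tools.data_augmentation import PadAndResize

pad_resize = PadAndResize((640, 640))
image = Image.new("RGB", (1280, 720))                    # dummy 16:9 frame
boxes = torch.tensor([[0, 100.0, 100.0, 400.0, 300.0]])  # [class, xmin, ymin, xmax, ymax]
padded, boxes, info = pad_resize(image, boxes)

# scale = min(640/1280, 640/720) = 0.5, so pad_left = 0 and pad_top = (640 - 360) // 2 = 140
print(info)   # tensor([0.5, 0., 140., 0., 140.])
print(boxes)  # tensor([[0., 50., 190., 200., 290.]])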
yolo/tools/data_loader.py CHANGED

@@ -141,16 +141,16 @@ class YoloDataset(Dataset):
     def get_data(self, idx):
         img_path, bboxes = self.data[idx]
         img = Image.open(img_path).convert("RGB")
-        return img, bboxes
+        return img, bboxes, img_path
 
     def get_more_data(self, num: int = 1):
         indices = torch.randint(0, len(self), (num,))
-        return [self.get_data(idx) for idx in indices]
+        return [self.get_data(idx)[:2] for idx in indices]
 
     def __getitem__(self, idx) -> Union[Image.Image, torch.Tensor]:
-        img, bboxes = self.get_data(idx)
-        img, bboxes, _ = self.transform(img, bboxes)
-        return img, bboxes
+        img, bboxes, img_path = self.get_data(idx)
+        img, bboxes, rev_tensor = self.transform(img, bboxes)
+        return img, bboxes, rev_tensor, img_path
 
     def __len__(self) -> int:
         return len(self.data)
@@ -195,9 +195,11 @@ class YoloDataLoader(DataLoader):
             batch_targets[idx, :target_size] = batch[idx][1]
         batch_targets[:, :, 1:] *= self.image_size
 
-        batch_images = torch.stack([item[0] for item in batch])
+        batch_images, _, batch_reverse, batch_path = zip(*batch)
+        batch_images = torch.stack(batch_images)
+        batch_reverse = torch.stack(batch_reverse)
 
-        return batch_images, batch_targets
+        return batch_images, batch_targets, batch_reverse, batch_path
 
 
 def create_dataloader(data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: str = "train", use_ddp: bool = False):
@@ -261,12 +263,14 @@ class StreamDataLoader:
         if isinstance(frame, np.ndarray):
             frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
             frame = Image.fromarray(frame)
+        origin_frame = frame
         frame, _, rev_tensor = self.transform(frame, torch.zeros(0, 5))
         frame = frame[None]
+        rev_tensor = rev_tensor[None]
        if not self.is_stream:
-            self.queue.put(frame)
+            self.queue.put((frame, rev_tensor, origin_frame))
         else:
-            self.current_frame = frame
+            self.current_frame = (frame, rev_tensor, origin_frame)
 
     def __iter__(self) -> Generator[Tensor, None, None]:
         return self
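With these changes each dataset item is a 4-tuple, so the collate function now emits the reverse tensors and image paths alongside images and targets. A consumption sketch (the shape comments are assumptions inferred from the collate code above, not verified against the repository):

def consume(dataloader):
    # Each batch from YoloDataLoader is now a 4-tuple instead of (images, targets).
    for batch_images, batch_targets, batch_reverse, batch_path in dataloader:
        # batch_images:  stacked transformed images, [B, 3, H, W]
        # batch_targets: padded boxes, [B, max_targets, 5]; columns 1:5 were
        #                multiplied by self.image_size in the collate function
        # batch_reverse: [B, 5], one [scale, pad_left, pad_top, pad_left, pad_top] row per image
        # batch_path:    tuple of B source image paths
        pass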
yolo/tools/solver.py CHANGED

@@ -72,7 +72,7 @@ class ModelTrainer:
         self.model.train()
         total_loss = 0
 
-        for images, targets in dataloader:
+        for images, targets, *_ in dataloader:
             loss, loss_each = self.train_one_batch(images, targets)
 
             total_loss += loss
@@ -136,8 +136,9 @@ class ModelTester:
 
         last_time = time.time()
         try:
-            for idx, images in enumerate(dataloader):
+            for idx, (images, rev_tensor, origin_frame) in enumerate(dataloader):
                 images = images.to(self.device)
+                rev_tensor = rev_tensor.to(self.device)
                 with torch.no_grad():
                     predicts = self.model(images)
                     predicts = self.vec2box(predicts["Main"])
@@ -192,8 +193,8 @@ class ModelValidator:
         iou_thresholds = torch.arange(0.5, 1.0, 0.05)
         map_all = []
         self.progress.start_one_epoch(len(dataloader))
-        for images, targets in dataloader:
-            images, targets = images.to(self.device), targets.to(self.device)
+        for images, targets, rev_tensor, img_paths in dataloader:
+            images, targets, rev_tensor = images.to(self.device), targets.to(self.device), rev_tensor.to(self.device)
             with torch.no_grad():
                 predicts = self.model(images)
                 predicts = self.vec2box(predicts["Main"])
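The diff does not show how rev_tensor is consumed after vec2box, but given the [scale, pad_left, pad_top, pad_left, pad_top] layout produced by PadAndResize, undoing the letterbox on predicted xyxy boxes reduces to subtracting the pads and dividing by the scale. A sketch with a hypothetical helper (reverse_boxes is not a name from the repository):

import torch

def reverse_boxes(boxes_xyxy: torch.Tensor, rev_tensor: torch.Tensor) -> torch.Tensor:
    """Map [N, 4] xyxy boxes from letterboxed coordinates back to the original image."""
    # rev_tensor is one [scale, pad_left, pad_top, pad_left, pad_top] row:
    # remove the padding offsets first, then divide out the resize scale.
    return (boxes_xyxy - rev_tensor[1:]) / rev_tensor[0]

boxes = torch.tensor([[50.0, 190.0, 200.0, 290.0]])  # boxes in the 640x640 letterboxed frame
rev = torch.tensor([0.5, 0.0, 140.0, 0.0, 140.0])
print(reverse_boxes(boxes, rev))                     # tensor([[100., 100., 400., 300.]])

This inverts the worked PadAndResize example shown earlier, recovering the original 1280x720 coordinates.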