henry000 commited on
Commit
3092710
Β·
1 Parent(s): ad7078a

πŸ› [Fix] bugs in dynamic shape in training

Browse files
yolo/tools/data_augmentation.py CHANGED
@@ -9,10 +9,11 @@ from torchvision.transforms import functional as TF
9
  class AugmentationComposer:
10
  """Composes several transforms together."""
11
 
12
- def __init__(self, transforms, image_size: int = [640, 640]):
13
  self.transforms = transforms
14
  # TODO: handle List of image_size [640, 640]
15
  self.pad_resize = PadAndResize(image_size)
 
16
 
17
  for transform in self.transforms:
18
  if hasattr(transform, "set_parent"):
@@ -122,7 +123,7 @@ class Mosaic:
122
 
123
  assert self.parent is not None, "Parent is not set. Mosaic cannot retrieve image size."
124
 
125
- img_sz = self.parent.image_size[0] # Assuming `image_size` is defined in parent
126
  more_data = self.parent.get_more_data(3) # get 3 more images randomly
127
 
128
  data = [(image, boxes)] + more_data
 
9
  class AugmentationComposer:
10
  """Composes several transforms together."""
11
 
12
+ def __init__(self, transforms, image_size: int = [640, 640], base_size: int = 640):
13
  self.transforms = transforms
14
  # TODO: handle List of image_size [640, 640]
15
  self.pad_resize = PadAndResize(image_size)
16
+ self.base_size = base_size
17
 
18
  for transform in self.transforms:
19
  if hasattr(transform, "set_parent"):
 
123
 
124
  assert self.parent is not None, "Parent is not set. Mosaic cannot retrieve image size."
125
 
126
+ img_sz = self.parent.base_size # Assuming `image_size` is defined in parent
127
  more_data = self.parent.get_more_data(3) # get 3 more images randomly
128
 
129
  data = [(image, boxes)] + more_data
yolo/tools/data_loader.py CHANGED
@@ -30,11 +30,11 @@ class YoloDataset(Dataset):
30
  self.image_size = data_cfg.image_size
31
  phase_name = dataset_cfg.get(phase, phase)
32
  self.batch_size = data_cfg.batch_size
33
- self.dynamic_shape = getattr(data_cfg, "dynamic_shape", True)
34
  self.base_size = mean(self.image_size)
35
 
36
  transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
37
- self.transform = AugmentationComposer(transforms, self.image_size)
38
  self.transform.get_more_data = self.get_more_data
39
  self.img_paths, self.bboxes, self.ratios = tensorlize(self.load_data(Path(dataset_cfg.path), phase_name))
40
 
@@ -53,20 +53,21 @@ class YoloDataset(Dataset):
53
 
54
  if not cache_path.exists():
55
  logger.info(f":factory: Generating {phase_name} cache")
56
- data = self.filter_data(dataset_path, phase_name)
57
  torch.save(data, cache_path)
58
  else:
59
  data = torch.load(cache_path, weights_only=False)
60
  logger.info(f":package: Loaded {phase_name} cache")
61
  return data
62
 
63
- def filter_data(self, dataset_path: Path, phase_name: str) -> list:
64
  """
65
  Filters and collects dataset information by pairing images with their corresponding labels.
66
 
67
  Parameters:
68
  images_path (Path): Path to the directory containing image files.
69
  labels_path (str): Path to the directory containing label files.
 
70
 
71
  Returns:
72
  list: A list of tuples, each containing the path to an image file and its associated segmentation as a tensor.
@@ -105,8 +106,11 @@ class YoloDataset(Dataset):
105
  labels = self.load_valid_labels(image_id, image_seg_annotations)
106
 
107
  img_path = images_path / image_name
108
- with Image.open(img_path) as img:
109
- width, height = img.size
 
 
 
110
  data.append((img_path, labels, width / height))
111
  valid_inputs += 1
112
 
 
30
  self.image_size = data_cfg.image_size
31
  phase_name = dataset_cfg.get(phase, phase)
32
  self.batch_size = data_cfg.batch_size
33
+ self.dynamic_shape = getattr(data_cfg, "dynamic_shape", False)
34
  self.base_size = mean(self.image_size)
35
 
36
  transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
37
+ self.transform = AugmentationComposer(transforms, self.image_size, self.base_size)
38
  self.transform.get_more_data = self.get_more_data
39
  self.img_paths, self.bboxes, self.ratios = tensorlize(self.load_data(Path(dataset_cfg.path), phase_name))
40
 
 
53
 
54
  if not cache_path.exists():
55
  logger.info(f":factory: Generating {phase_name} cache")
56
+ data = self.filter_data(dataset_path, phase_name, self.dynamic_shape)
57
  torch.save(data, cache_path)
58
  else:
59
  data = torch.load(cache_path, weights_only=False)
60
  logger.info(f":package: Loaded {phase_name} cache")
61
  return data
62
 
63
+ def filter_data(self, dataset_path: Path, phase_name: str, sort_image: bool = False) -> list:
64
  """
65
  Filters and collects dataset information by pairing images with their corresponding labels.
66
 
67
  Parameters:
68
  images_path (Path): Path to the directory containing image files.
69
  labels_path (str): Path to the directory containing label files.
70
+ sort_image (bool): If True, sorts the dataset by the width-to-height ratio of images in descending order.
71
 
72
  Returns:
73
  list: A list of tuples, each containing the path to an image file and its associated segmentation as a tensor.
 
106
  labels = self.load_valid_labels(image_id, image_seg_annotations)
107
 
108
  img_path = images_path / image_name
109
+ if sort_image:
110
+ with Image.open(img_path) as img:
111
+ width, height = img.size
112
+ else:
113
+ width, height = 0, 1
114
  data.append((img_path, labels, width / height))
115
  valid_inputs += 1
116
 
yolo/tools/solver.py CHANGED
@@ -85,6 +85,7 @@ class TrainModel(ValidateModel):
85
 
86
  def on_train_epoch_start(self):
87
  self.trainer.optimizers[0].next_epoch(ceil(len(self.train_loader) / self.trainer.world_size))
 
88
 
89
  def training_step(self, batch, batch_idx):
90
  lr_dict = self.trainer.optimizers[0].next_batch()
 
85
 
86
  def on_train_epoch_start(self):
87
  self.trainer.optimizers[0].next_epoch(ceil(len(self.train_loader) / self.trainer.world_size))
88
+ self.vec2box.update(self.cfg.image_size)
89
 
90
  def training_step(self, batch, batch_idx):
91
  lr_dict = self.trainer.optimizers[0].next_batch()