🐛 [Fix] bugs in dynamic shape in training
Browse files
- yolo/tools/data_augmentation.py +3 -2
- yolo/tools/data_loader.py +10 -6
- yolo/tools/solver.py +1 -0
yolo/tools/data_augmentation.py
CHANGED
@@ -9,10 +9,11 @@ from torchvision.transforms import functional as TF
|
|
9 |
class AugmentationComposer:
|
10 |
"""Composes several transforms together."""
|
11 |
|
12 |
-
def __init__(self, transforms, image_size: int = [640, 640]):
|
13 |
self.transforms = transforms
|
14 |
# TODO: handle List of image_size [640, 640]
|
15 |
self.pad_resize = PadAndResize(image_size)
|
|
|
16 |
|
17 |
for transform in self.transforms:
|
18 |
if hasattr(transform, "set_parent"):
|
@@ -122,7 +123,7 @@ class Mosaic:
|
|
122 |
|
123 |
assert self.parent is not None, "Parent is not set. Mosaic cannot retrieve image size."
|
124 |
|
125 |
-
img_sz = self.parent.image_size # Assuming `image_size` is defined in parent
|
126 |
more_data = self.parent.get_more_data(3) # get 3 more images randomly
|
127 |
|
128 |
data = [(image, boxes)] + more_data
|
|
|
9 |
class AugmentationComposer:
|
10 |
"""Composes several transforms together."""
|
11 |
|
12 |
+
def __init__(self, transforms, image_size: int = [640, 640], base_size: int = 640):
|
13 |
self.transforms = transforms
|
14 |
# TODO: handle List of image_size [640, 640]
|
15 |
self.pad_resize = PadAndResize(image_size)
|
16 |
+
self.base_size = base_size
|
17 |
|
18 |
for transform in self.transforms:
|
19 |
if hasattr(transform, "set_parent"):
|
|
|
123 |
|
124 |
assert self.parent is not None, "Parent is not set. Mosaic cannot retrieve image size."
|
125 |
|
126 |
+
img_sz = self.parent.base_size # Assuming `image_size` is defined in parent
|
127 |
more_data = self.parent.get_more_data(3) # get 3 more images randomly
|
128 |
|
129 |
data = [(image, boxes)] + more_data
|
yolo/tools/data_loader.py
CHANGED
@@ -30,11 +30,11 @@ class YoloDataset(Dataset):
|
|
30 |
self.image_size = data_cfg.image_size
|
31 |
phase_name = dataset_cfg.get(phase, phase)
|
32 |
self.batch_size = data_cfg.batch_size
|
33 |
-
self.dynamic_shape = getattr(data_cfg, "dynamic_shape", True)
|
34 |
self.base_size = mean(self.image_size)
|
35 |
|
36 |
transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
|
37 |
-
self.transform = AugmentationComposer(transforms, self.image_size)
|
38 |
self.transform.get_more_data = self.get_more_data
|
39 |
self.img_paths, self.bboxes, self.ratios = tensorlize(self.load_data(Path(dataset_cfg.path), phase_name))
|
40 |
|
@@ -53,20 +53,21 @@ class YoloDataset(Dataset):
|
|
53 |
|
54 |
if not cache_path.exists():
|
55 |
logger.info(f":factory: Generating {phase_name} cache")
|
56 |
-
data = self.filter_data(dataset_path, phase_name)
|
57 |
torch.save(data, cache_path)
|
58 |
else:
|
59 |
data = torch.load(cache_path, weights_only=False)
|
60 |
logger.info(f":package: Loaded {phase_name} cache")
|
61 |
return data
|
62 |
|
63 |
-
def filter_data(self, dataset_path: Path, phase_name: str) -> list:
|
64 |
"""
|
65 |
Filters and collects dataset information by pairing images with their corresponding labels.
|
66 |
|
67 |
Parameters:
|
68 |
images_path (Path): Path to the directory containing image files.
|
69 |
labels_path (str): Path to the directory containing label files.
|
|
|
70 |
|
71 |
Returns:
|
72 |
list: A list of tuples, each containing the path to an image file and its associated segmentation as a tensor.
|
@@ -105,8 +106,11 @@ class YoloDataset(Dataset):
|
|
105 |
labels = self.load_valid_labels(image_id, image_seg_annotations)
|
106 |
|
107 |
img_path = images_path / image_name
|
108 |
-
with Image.open(img_path) as img:
|
109 |
-
width, height = img.size
|
|
|
|
|
|
|
110 |
data.append((img_path, labels, width / height))
|
111 |
valid_inputs += 1
|
112 |
|
|
|
30 |
self.image_size = data_cfg.image_size
|
31 |
phase_name = dataset_cfg.get(phase, phase)
|
32 |
self.batch_size = data_cfg.batch_size
|
33 |
+
self.dynamic_shape = getattr(data_cfg, "dynamic_shape", False)
|
34 |
self.base_size = mean(self.image_size)
|
35 |
|
36 |
transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
|
37 |
+
self.transform = AugmentationComposer(transforms, self.image_size, self.base_size)
|
38 |
self.transform.get_more_data = self.get_more_data
|
39 |
self.img_paths, self.bboxes, self.ratios = tensorlize(self.load_data(Path(dataset_cfg.path), phase_name))
|
40 |
|
|
|
53 |
|
54 |
if not cache_path.exists():
|
55 |
logger.info(f":factory: Generating {phase_name} cache")
|
56 |
+
data = self.filter_data(dataset_path, phase_name, self.dynamic_shape)
|
57 |
torch.save(data, cache_path)
|
58 |
else:
|
59 |
data = torch.load(cache_path, weights_only=False)
|
60 |
logger.info(f":package: Loaded {phase_name} cache")
|
61 |
return data
|
62 |
|
63 |
+
def filter_data(self, dataset_path: Path, phase_name: str, sort_image: bool = False) -> list:
|
64 |
"""
|
65 |
Filters and collects dataset information by pairing images with their corresponding labels.
|
66 |
|
67 |
Parameters:
|
68 |
images_path (Path): Path to the directory containing image files.
|
69 |
labels_path (str): Path to the directory containing label files.
|
70 |
+
sort_image (bool): If True, sorts the dataset by the width-to-height ratio of images in descending order.
|
71 |
|
72 |
Returns:
|
73 |
list: A list of tuples, each containing the path to an image file and its associated segmentation as a tensor.
|
|
|
106 |
labels = self.load_valid_labels(image_id, image_seg_annotations)
|
107 |
|
108 |
img_path = images_path / image_name
|
109 |
+
if sort_image:
|
110 |
+
with Image.open(img_path) as img:
|
111 |
+
width, height = img.size
|
112 |
+
else:
|
113 |
+
width, height = 0, 1
|
114 |
data.append((img_path, labels, width / height))
|
115 |
valid_inputs += 1
|
116 |
|
yolo/tools/solver.py
CHANGED
@@ -85,6 +85,7 @@ class TrainModel(ValidateModel):
|
|
85 |
|
86 |
def on_train_epoch_start(self):
|
87 |
self.trainer.optimizers[0].next_epoch(ceil(len(self.train_loader) / self.trainer.world_size))
|
|
|
88 |
|
89 |
def training_step(self, batch, batch_idx):
|
90 |
lr_dict = self.trainer.optimizers[0].next_batch()
|
|
|
85 |
|
86 |
def on_train_epoch_start(self):
|
87 |
self.trainer.optimizers[0].next_epoch(ceil(len(self.train_loader) / self.trainer.world_size))
|
88 |
+
self.vec2box.update(self.cfg.image_size)
|
89 |
|
90 |
def training_step(self, batch, batch_idx):
|
91 |
lr_dict = self.trainer.optimizers[0].next_batch()
|