✨ [Add] a Mosaic data augmentation in dataloader
Browse files- config/data/augmentation.yaml +2 -1
- utils/dataargument.py +57 -6
- utils/dataloader.py +14 -3
config/data/augmentation.yaml
CHANGED
@@ -1 +1,2 @@
|
|
1 |
-
RandomHorizontalFlip: 0.5
|
|
|
|
1 |
+
RandomHorizontalFlip: 0.5
|
2 |
+
Mosaic: 0.5
|
utils/dataargument.py
CHANGED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
import torch
|
2 |
from torchvision.transforms import functional as TF
|
3 |
|
@@ -8,21 +10,70 @@ class Compose:
|
|
8 |
def __init__(self, transforms):
|
9 |
self.transforms = transforms
|
10 |
|
|
|
|
|
|
|
|
|
11 |
def __call__(self, image, boxes):
|
12 |
-
for
|
13 |
-
image, boxes =
|
14 |
return image, boxes
|
15 |
|
|
|
|
|
|
|
16 |
|
17 |
class RandomHorizontalFlip:
|
18 |
"""Randomly horizontally flips the image along with the bounding boxes."""
|
19 |
|
20 |
-
def __init__(self,
|
21 |
-
self.
|
22 |
|
23 |
def __call__(self, image, boxes):
|
24 |
-
if torch.rand(1) < self.
|
25 |
image = TF.hflip(image)
|
26 |
-
# Assuming boxes are in the format [cls, xmin, ymin, xmax, ymax]
|
27 |
boxes[:, [1, 3]] = 1 - boxes[:, [3, 1]]
|
28 |
return image, boxes
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
import numpy as np
|
3 |
import torch
|
4 |
from torchvision.transforms import functional as TF
|
5 |
|
|
|
10 |
def __init__(self, transforms):
    """Store the transform pipeline and hand this object to parent-aware transforms.

    Args:
        transforms: Iterable of callables; each one is later invoked as
            ``transform(image, boxes)``.
    """
    self.transforms = transforms

    # Any transform exposing ``set_parent`` receives a back-reference to this
    # pipeline; the generator is lazy, so check and call stay interleaved.
    parent_aware = (step for step in self.transforms if hasattr(step, "set_parent"))
    for step in parent_aware:
        step.set_parent(self)
|
16 |
+
|
17 |
def __call__(self, image, boxes):
    """Run every stored transform in order on the sample.

    Args:
        image: Image passed to the first transform.
        boxes: Box labels passed to the first transform.

    Returns:
        The ``(image, boxes)`` pair left after the last transform; the inputs
        come back untouched when the pipeline is empty.
    """
    out_image, out_boxes = image, boxes
    for step in self.transforms:
        out_image, out_boxes = step(out_image, out_boxes)
    return out_image, out_boxes
|
21 |
|
22 |
+
def get_more_data(self, num: int = 1):
    """Placeholder hook for fetching ``num`` extra ``(image, boxes)`` samples.

    The Mosaic transform in this file calls ``parent.get_more_data(3)``, so the
    placeholder accepts a sample count; without it an un-overridden hook would
    fail with a ``TypeError`` instead of the intended error below.

    Args:
        num: Number of extra samples requested (unused by the placeholder).

    Raises:
        NotImplementedError: Always, until the hook is replaced.
    """
    raise NotImplementedError("This method should be overridden by subclass instances!")
|
24 |
+
|
25 |
|
26 |
class RandomHorizontalFlip:
    """Randomly horizontally flips the image along with the bounding boxes."""

    def __init__(self, prob=0.5):
        """
        Args:
            prob: Flip threshold; the flip happens when a uniform sample from
                ``torch.rand`` is below this value.
        """
        self.prob = prob

    def __call__(self, image, boxes):
        """Flip ``image`` and mirror the x-coordinates of ``boxes`` with probability ``prob``.

        Args:
            image: Image accepted by ``TF.hflip``.
            boxes: 2-D array of labels; columns 1 and 3 are mirrored as
                ``1 - x``, so they are presumably normalized xmin/xmax
                (``[cls, xmin, ymin, xmax, ymax]``) -- confirm against the loader.

        Returns:
            ``(image, boxes)``. When a flip happens ``boxes`` is a new object;
            the caller's labels are never modified in place.
        """
        if torch.rand(1) < self.prob:
            image = TF.hflip(image)
            # Work on a copy: the dataset hands out its stored labels, and an
            # in-place flip would permanently mirror them for later epochs.
            boxes = boxes.clone() if torch.is_tensor(boxes) else boxes.copy()
            # New xmin is the mirrored old xmax and vice versa.
            boxes[:, [1, 3]] = 1 - boxes[:, [3, 1]]
        return image, boxes
|
37 |
+
|
38 |
+
|
39 |
+
class Mosaic:
    """Applies the Mosaic augmentation: tiles the current sample with extra samples on one canvas."""

    def __init__(self, prob=0.5):
        """
        Args:
            prob: Threshold for applying the augmentation; it is skipped when a
                uniform sample from ``torch.rand`` is ``>= prob``.
        """
        self.prob = prob
        # Filled in through ``set_parent``; must provide ``image_size`` and
        # ``get_more_data(num)``.
        self.parent = None

    def set_parent(self, parent):
        """Remember the object that supplies ``image_size`` and ``get_more_data``."""
        self.parent = parent

    def __call__(self, image, boxes):
        """Build a ``2 * image_size`` square mosaic from this sample and three extra ones.

        Args:
            image: Image exposing ``.size`` as ``(width, height)`` and usable
                with ``Image.paste`` (a PIL image).
            boxes: 2-D tensor whose columns are read as
                ``[cls, xmin, ymin, xmax, ymax]``; the coordinates are scaled by
                the tile size, i.e. treated as normalized to the tile.

        Returns:
            The inputs unchanged when the augmentation is skipped, otherwise
            ``(mosaic_image, labels)`` with the labels of all tiles concatenated
            and expressed relative to the mosaic canvas.

        Raises:
            RuntimeError: If no parent has been set.
        """
        if torch.rand(1) >= self.prob:
            return image, boxes

        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if self.parent is None:
            raise RuntimeError("Parent is not set. Mosaic cannot retrieve image size.")

        img_sz = self.parent.image_size
        more_data = self.parent.get_more_data(3)  # three extra random samples

        canvas_sz = 2 * img_sz
        samples = [(image, boxes)] + list(more_data)
        mosaic_image = Image.new("RGB", (canvas_sz, canvas_sz))
        # One direction per quadrant: each tile has a corner on the canvas center.
        vectors = np.array([(-1, -1), (0, -1), (-1, 0), (0, 0)])
        center = np.array([img_sz, img_sz])
        all_labels = []

        for (tile, tile_boxes), vector in zip(samples, vectors):
            tile_w, tile_h = tile.size
            # Plain ints keep PIL and the tensor arithmetic free of NumPy scalars.
            offset_x, offset_y = (int(v) for v in center + vector * np.array([tile_w, tile_h]))

            mosaic_image.paste(tile, (offset_x, offset_y))

            # Tile-relative -> pixel -> canvas-relative coordinates.
            xmin = (tile_boxes[:, 1] * tile_w + offset_x) / canvas_sz
            ymin = (tile_boxes[:, 2] * tile_h + offset_y) / canvas_sz
            xmax = (tile_boxes[:, 3] * tile_w + offset_x) / canvas_sz
            ymax = (tile_boxes[:, 4] * tile_h + offset_y) / canvas_sz

            adjusted_boxes = torch.stack([tile_boxes[:, 0], xmin, ymin, xmax, ymax], dim=1)
            # A tile larger than ``image_size`` is cropped by ``paste``; keep the
            # box coordinates inside the visible canvas as well.
            adjusted_boxes[:, 1:] = adjusted_boxes[:, 1:].clamp(0.0, 1.0)
            all_labels.append(adjusted_boxes)

        all_labels = torch.cat(all_labels, dim=0)
        return mosaic_image, all_labels
|
utils/dataloader.py
CHANGED
@@ -10,14 +10,17 @@ from tqdm.rich import tqdm
|
|
10 |
import diskcache as dc
|
11 |
from typing import Union
|
12 |
from drawer import draw_bboxes
|
13 |
-
from dataargument import Compose, RandomHorizontalFlip
|
14 |
|
15 |
|
16 |
class YoloDataset(Dataset):
|
17 |
-
def __init__(self, dataset_cfg: dict, phase: str = "train", transform=None):
|
18 |
phase_name = dataset_cfg.get(phase, phase)
|
|
|
19 |
|
20 |
self.transform = transform
|
|
|
|
|
21 |
self.data = self.load_data(dataset_cfg.path, phase_name)
|
22 |
|
23 |
def load_data(self, dataset_path, phase_name):
|
@@ -105,9 +108,17 @@ class YoloDataset(Dataset):
|
|
105 |
logger.warning("No valid BBox in {}", label_path)
|
106 |
return None
|
107 |
|
108 |
-
def
|
109 |
img_path, bboxes = self.data[idx]
|
110 |
img = Image.open(img_path).convert("RGB")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
if self.transform:
|
112 |
img, bboxes = self.transform(img, bboxes)
|
113 |
return img, bboxes
|
|
|
10 |
import diskcache as dc
|
11 |
from typing import Union
|
12 |
from drawer import draw_bboxes
|
13 |
+
from dataargument import Compose, RandomHorizontalFlip, Mosaic
|
14 |
|
15 |
|
16 |
class YoloDataset(Dataset):
|
17 |
+
def __init__(self, dataset_cfg: dict, phase: str = "train", image_size: int = 640, transform=None):
    """Set up the dataset for one phase and wire the transform to this dataset.

    Args:
        dataset_cfg: Config supporting ``.get(phase, phase)`` and a ``.path``
            attribute (so presumably an attribute-style mapping, not a plain
            ``dict`` -- confirm against the caller).
        phase: Phase key; the config may map it to a different phase name.
        image_size: Stored on the dataset and, when a transform is given,
            copied onto it.
        transform: Optional callable applied to ``(image, boxes)``. When given,
            it receives ``get_more_data`` and ``image_size`` attributes so that
            transforms such as Mosaic can pull extra samples.
    """
    phase_name = dataset_cfg.get(phase, phase)
    self.image_size = image_size

    self.transform = transform
    # ``transform`` defaults to None; only wire the hooks when one was supplied.
    if self.transform is not None:
        self.transform.get_more_data = self.get_more_data
        self.transform.image_size = self.image_size
    self.data = self.load_data(dataset_cfg.path, phase_name)
|
25 |
|
26 |
def load_data(self, dataset_path, phase_name):
|
|
|
108 |
logger.warning("No valid BBox in {}", label_path)
|
109 |
return None
|
110 |
|
111 |
+
def get_data(self, idx):
    """Load the untransformed sample stored at ``idx``.

    Args:
        idx: Index into ``self.data``, whose entries unpack as
            ``(img_path, bboxes)``.

    Returns:
        ``(img, bboxes)`` where ``img`` is the image converted to RGB and
        ``bboxes`` is the stored label object itself (not a copy).
    """
    img_path, bboxes = self.data[idx]
    # ``convert`` yields an independent, fully loaded image, so the file handle
    # can be released as soon as the conversion is done.
    with Image.open(img_path) as source:
        img = source.convert("RGB")
    return img, bboxes
|
115 |
+
|
116 |
+
def get_more_data(self, num: int = 1):
    """Return ``num`` samples drawn uniformly at random (with replacement).

    Args:
        num: Number of samples to draw.

    Returns:
        List of whatever ``get_data`` returns for each drawn index.
    """
    # ``tolist`` hands ``get_data`` plain Python ints rather than 0-d tensors.
    indices = torch.randint(0, len(self), (num,)).tolist()
    return [self.get_data(idx) for idx in indices]
|
119 |
+
|
120 |
+
def __getitem__(self, idx) -> "tuple[Image.Image, torch.Tensor]":
    """Return the sample at ``idx``, passed through the transform when one is set.

    The result is always an ``(image, boxes)`` pair, hence the tuple annotation
    (the element types are the ones the original annotation named; what the
    transform actually returns is not visible here -- confirm).

    Args:
        idx: Index forwarded to ``get_data``.

    Returns:
        ``(img, bboxes)`` from ``get_data``, or the transform's output for it.
    """
    img, bboxes = self.get_data(idx)
    if self.transform:
        img, bboxes = self.transform(img, bboxes)
    return img, bboxes
|