henry000 committed on
Commit
237de06
·
1 Parent(s): d8aafaa

✨ [Add] a Mosaic data augment in dataloader

Browse files
config/data/augmentation.yaml CHANGED
@@ -1 +1,2 @@
1
- RandomHorizontalFlip: 0.5
 
 
1
+ RandomHorizontalFlip: 0.5
2
+ Mosaic: 0.5
utils/dataargument.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import torch
2
  from torchvision.transforms import functional as TF
3
 
@@ -8,21 +10,70 @@ class Compose:
8
  def __init__(self, transforms):
9
  self.transforms = transforms
10
 
 
 
 
 
11
  def __call__(self, image, boxes):
12
- for t in self.transforms:
13
- image, boxes = t(image, boxes)
14
  return image, boxes
15
 
 
 
 
16
 
17
  class RandomHorizontalFlip:
18
  """Randomly horizontally flips the image along with the bounding boxes."""
19
 
20
- def __init__(self, p=0.5):
21
- self.p = p
22
 
23
  def __call__(self, image, boxes):
24
- if torch.rand(1) < self.p:
25
  image = TF.hflip(image)
26
- # Assuming boxes are in the format [cls, xmin, ymin, xmax, ymax]
27
  boxes[:, [1, 3]] = 1 - boxes[:, [3, 1]]
28
  return image, boxes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import numpy as np
3
  import torch
4
  from torchvision.transforms import functional as TF
5
 
 
10
class Compose:
    """Run a sequence of ``(image, boxes)`` transforms one after another.

    Any transform that exposes a ``set_parent`` method receives a reference
    to this object at construction time, so it can call back into
    ``get_more_data`` when it needs additional samples.
    """

    def __init__(self, transforms):
        """Store the transforms and register this object as their parent.

        Args:
            transforms: callables that take ``(image, boxes)`` and return the
                transformed ``(image, boxes)`` pair.
        """
        self.transforms = transforms

        for transform in self.transforms:
            if hasattr(transform, "set_parent"):
                transform.set_parent(self)

    def __call__(self, image, boxes):
        """Apply every transform in order and return the final ``(image, boxes)``."""
        for transform in self.transforms:
            image, boxes = transform(image, boxes)
        return image, boxes

    def get_more_data(self, num: int = 1):
        """Placeholder hook for fetching ``num`` additional samples.

        The parameter mirrors how the hook is invoked by child transforms
        (with a sample count), so a missing override surfaces as the
        intended ``NotImplementedError`` rather than a ``TypeError``.

        Raises:
            NotImplementedError: always, until the hook is replaced by the
                object that owns the data.
        """
        raise NotImplementedError("This method should be overridden by subclass instances!")
24
+
25
 
26
class RandomHorizontalFlip:
    """Randomly horizontally flips the image along with the bounding boxes."""

    def __init__(self, prob=0.5):
        """
        Args:
            prob: probability of flipping on each call.
        """
        self.prob = prob

    def __call__(self, image, boxes):
        """Flip ``image`` and mirror the boxes with probability ``prob``.

        Args:
            image: image accepted by ``TF.hflip``.
            boxes: 2-D array/tensor whose columns 1 and 3 are mirrored as
                ``1 - x`` (i.e. treated as normalised xmin / xmax).

        Returns:
            The ``(image, boxes)`` pair; the inputs are returned untouched
            when no flip is drawn, otherwise a flipped image and a new boxes
            object are returned.
        """
        if torch.rand(1) >= self.prob:
            return image, boxes
        return TF.hflip(image), self._flip_boxes(boxes)

    @staticmethod
    def _flip_boxes(boxes):
        """Return a copy of ``boxes`` with the x extent mirrored around 0.5."""
        # Work on a copy: the caller may hand in labels it keeps around
        # (e.g. a dataset's stored annotations), and flipping those in place
        # would leave them mirrored for every later access.
        flipped = boxes.clone() if torch.is_tensor(boxes) else boxes.copy()
        # The mirrored xmax becomes the new xmin and vice versa.
        flipped[:, [1, 3]] = 1 - boxes[:, [3, 1]]
        return flipped
37
+
38
+
39
class Mosaic:
    """Stitch the current sample and three extra samples into one 2x2 mosaic.

    The extra samples and the tile size are obtained from the parent object
    registered through ``set_parent``: it must provide an ``image_size``
    attribute and a ``get_more_data(num)`` method.
    """

    def __init__(self, prob=0.5):
        """
        Args:
            prob: probability of building a mosaic on each call.
        """
        self.prob = prob
        self.parent = None

    def set_parent(self, parent):
        """Remember the object that supplies ``image_size`` and ``get_more_data``."""
        self.parent = parent

    def __call__(self, image, boxes):
        """Return a mosaic of four samples, or the input unchanged.

        Args:
            image: PIL image of the current sample.
            boxes: tensor with rows ``[cls, xmin, ymin, xmax, ymax]`` whose
                coordinates are normalised to the image.

        Returns:
            ``(image, boxes)`` untouched when no mosaic is drawn; otherwise a
            new ``(2 * image_size)``-square RGB image and the concatenated
            boxes of all tiles, re-normalised to the mosaic canvas.

        Raises:
            RuntimeError: if a mosaic is requested before ``set_parent`` was
                called.
        """
        if torch.rand(1) >= self.prob:
            return image, boxes

        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if self.parent is None:
            raise RuntimeError("Parent is not set. Mosaic cannot retrieve image size.")

        img_sz = self.parent.image_size
        more_data = self.parent.get_more_data(3)  # three additional samples

        data = [(image, boxes)] + list(more_data)
        canvas_sz = 2 * img_sz
        mosaic_image = Image.new("RGB", (canvas_sz, canvas_sz))
        # Offsets in units of each tile's own size, relative to the canvas
        # center: top-left, top-right, bottom-left, bottom-right.
        vectors = np.array([(-1, -1), (0, -1), (-1, 0), (0, 0)])
        center = np.array([img_sz, img_sz])
        all_labels = []

        for (tile, tile_boxes), vector in zip(data, vectors):
            this_w, this_h = tile.size
            # Plain Python ints keep both the paste call and the tensor
            # arithmetic below free of numpy scalar types.
            left, top = (int(v) for v in center + vector * np.array([this_w, this_h]))
            mosaic_image.paste(tile, (left, top))

            # Tile-normalised -> pixel position on the canvas -> canvas-normalised.
            xmin = (tile_boxes[:, 1] * this_w + left) / canvas_sz
            ymin = (tile_boxes[:, 2] * this_h + top) / canvas_sz
            xmax = (tile_boxes[:, 3] * this_w + left) / canvas_sz
            ymax = (tile_boxes[:, 4] * this_h + top) / canvas_sz

            all_labels.append(torch.stack([tile_boxes[:, 0], xmin, ymin, xmax, ymax], dim=1))

        return mosaic_image, torch.cat(all_labels, dim=0)
utils/dataloader.py CHANGED
@@ -10,14 +10,17 @@ from tqdm.rich import tqdm
10
  import diskcache as dc
11
  from typing import Union
12
  from drawer import draw_bboxes
13
- from dataargument import Compose, RandomHorizontalFlip
14
 
15
 
16
  class YoloDataset(Dataset):
17
- def __init__(self, dataset_cfg: dict, phase: str = "train", transform=None):
18
  phase_name = dataset_cfg.get(phase, phase)
 
19
 
20
  self.transform = transform
 
 
21
  self.data = self.load_data(dataset_cfg.path, phase_name)
22
 
23
  def load_data(self, dataset_path, phase_name):
@@ -105,9 +108,17 @@ class YoloDataset(Dataset):
105
  logger.warning("No valid BBox in {}", label_path)
106
  return None
107
 
108
- def __getitem__(self, idx) -> Union[Image.Image, torch.Tensor]:
109
  img_path, bboxes = self.data[idx]
110
  img = Image.open(img_path).convert("RGB")
 
 
 
 
 
 
 
 
111
  if self.transform:
112
  img, bboxes = self.transform(img, bboxes)
113
  return img, bboxes
 
10
  import diskcache as dc
11
  from typing import Union
12
  from drawer import draw_bboxes
13
+ from dataargument import Compose, RandomHorizontalFlip, Mosaic
14
 
15
 
16
  class YoloDataset(Dataset):
17
+ def __init__(self, dataset_cfg: dict, phase: str = "train", image_size: int = 640, transform=None):
18
  phase_name = dataset_cfg.get(phase, phase)
19
+ self.image_size = image_size
20
 
21
  self.transform = transform
22
+ self.transform.get_more_data = self.get_more_data
23
+ self.transform.image_size = self.image_size
24
  self.data = self.load_data(dataset_cfg.path, phase_name)
25
 
26
  def load_data(self, dataset_path, phase_name):
 
108
  logger.warning("No valid BBox in {}", label_path)
109
  return None
110
 
111
+ def get_data(self, idx):
112
  img_path, bboxes = self.data[idx]
113
  img = Image.open(img_path).convert("RGB")
114
+ return img, bboxes
115
+
116
+ def get_more_data(self, num: int = 1):
117
+ indices = torch.randint(0, len(self), (num,))
118
+ return [self.get_data(idx) for idx in indices]
119
+
120
+ def __getitem__(self, idx) -> Union[Image.Image, torch.Tensor]:
121
+ img, bboxes = self.get_data(idx)
122
  if self.transform:
123
  img, bboxes = self.transform(img, bboxes)
124
  return img, bboxes