henry000 committed on
Commit
49d58b9
·
1 Parent(s): 06e6ab2

⚗️ [Add] MixUp augment, not sure it can work with Mosaic

Browse files
config/data/augmentation.yaml CHANGED
@@ -1,2 +1,3 @@
1
- RandomHorizontalFlip: 0.5
2
- Mosaic: 0.5
 
 
1
+ Mosaic: 1
2
+ MixUp: 1
3
+ RandomHorizontalFlip: 0.5
utils/data_augment.py CHANGED
@@ -2,6 +2,7 @@ from PIL import Image
2
  import numpy as np
3
  import torch
4
  from torchvision.transforms import functional as TF
 
5
 
6
 
7
  class Compose:
@@ -77,3 +78,37 @@ class Mosaic:
77
 
78
  all_labels = torch.cat(all_labels, dim=0)
79
  return mosaic_image, all_labels
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import numpy as np
3
  import torch
4
  from torchvision.transforms import functional as TF
5
+ from torchvision.transforms.functional import to_tensor, to_pil_image
6
 
7
 
8
  class Compose:
 
78
 
79
  all_labels = torch.cat(all_labels, dim=0)
80
  return mosaic_image, all_labels
81
+
82
+
83
class MixUp:
    """Blend an image with a second dataset sample and merge their boxes.

    The second sample is fetched through a parent object, which must be
    attached with ``set_parent`` before the transform is applied.
    """

    def __init__(self, prob=0.5, alpha=1.0):
        """Store the augmentation settings.

        Args:
            prob: Probability of applying MixUp on a given call.
            alpha: Value used for both shape parameters of the Beta
                distribution the blend weight is drawn from. A non-positive
                value fixes the blend weight at 0.5.
        """
        self.alpha = alpha
        self.prob = prob
        self.parent = None

    def set_parent(self, parent):
        """Set the parent dataset object for accessing dataset methods."""
        self.parent = parent

    def __call__(self, image, boxes):
        """Apply MixUp with probability ``prob``.

        Args:
            image: Image accepted by ``to_tensor``.
            boxes: Tensor of labels for ``image``, one row per box.

        Returns:
            ``(image, boxes)`` untouched when the augmentation is skipped,
            otherwise the blended PIL image and the rows of both box tensors
            concatenated along dim 0.

        Raises:
            AssertionError: If no parent has been set.
        """
        if torch.rand(1) >= self.prob:
            return image, boxes

        assert self.parent is not None, "Parent is not set. MixUp cannot retrieve additional data."

        # Retrieve another image and its boxes from the parent dataset.
        image2, boxes2 = self.parent.get_more_data()[0]

        # Blend weight; cast to a plain float so tensor arithmetic below does
        # not depend on numpy-scalar/tensor interop.
        lam = float(np.random.beta(self.alpha, self.alpha)) if self.alpha > 0 else 0.5

        # Mix images pixel-wise.
        # NOTE(review): the weighted sum presumably requires both images to
        # share the same size - confirm get_more_data() guarantees this.
        image1, image2 = to_tensor(image), to_tensor(image2)
        mixed_image = lam * image1 + (1 - lam) * image2

        # Every object from both images is still visible in the blend, so keep
        # both label sets unchanged. Scaling the rows by the blend weight would
        # corrupt the box values instead of describing the mixed image.
        mixed_boxes = torch.cat([boxes, boxes2])

        return to_pil_image(mixed_image), mixed_boxes
utils/dataloader.py CHANGED
@@ -10,7 +10,7 @@ from tqdm.rich import tqdm
10
  import diskcache as dc
11
  from typing import Union
12
  from drawer import draw_bboxes
13
- from data_augment import Compose, RandomHorizontalFlip, Mosaic
14
 
15
 
16
  class YoloDataset(Dataset):
 
10
  import diskcache as dc
11
  from typing import Union
12
  from drawer import draw_bboxes
13
+ from data_augment import Compose, RandomHorizontalFlip, Mosaic, MixUp
14
 
15
 
16
  class YoloDataset(Dataset):