lucytuan commited on
Commit
e3d53d5
·
2 Parent(s): 2ad038a 5ce12fa

Merge branch 'DATASET' of github.com:LucyTuan/yolov9mit into DATASET

Browse files
config/data/augmentation.yaml CHANGED
@@ -1,3 +1,3 @@
1
- RandomHorizontalFlip: 0.5
2
- RandomVerticalFlip: 0.5
3
- Mosaic: 0.5
 
1
+ Mosaic: 1
2
+ MixUp: 1
3
+ RandomHorizontalFlip: 0.5
utils/{dataargument.py → data_augment.py} RENAMED
@@ -2,6 +2,7 @@ from PIL import Image
2
  import numpy as np
3
  import torch
4
  from torchvision.transforms import functional as TF
 
5
 
6
 
7
  class Compose:
@@ -90,3 +91,37 @@ class Mosaic:
90
 
91
  all_labels = torch.cat(all_labels, dim=0)
92
  return mosaic_image, all_labels
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import numpy as np
3
  import torch
4
  from torchvision.transforms import functional as TF
5
+ from torchvision.transforms.functional import to_tensor, to_pil_image
6
 
7
 
8
  class Compose:
 
91
 
92
  all_labels = torch.cat(all_labels, dim=0)
93
  return mosaic_image, all_labels
94
+
95
+
96
+ class MixUp:
97
+ """Applies the MixUp augmentation to a pair of images and their corresponding boxes."""
98
+
99
+ def __init__(self, prob=0.5, alpha=1.0):
100
+ self.alpha = alpha
101
+ self.prob = prob
102
+ self.parent = None
103
+
104
+ def set_parent(self, parent):
105
+ """Set the parent dataset object for accessing dataset methods."""
106
+ self.parent = parent
107
+
108
+ def __call__(self, image, boxes):
109
+ if torch.rand(1) >= self.prob:
110
+ return image, boxes
111
+
112
+ assert self.parent is not None, "Parent is not set. MixUp cannot retrieve additional data."
113
+
114
+ # Retrieve another image and its boxes randomly from the dataset
115
+ image2, boxes2 = self.parent.get_more_data()[0]
116
+
117
+ # Calculate the mixup lambda parameter
118
+ lam = np.random.beta(self.alpha, self.alpha) if self.alpha > 0 else 0.5
119
+
120
+ # Mix images
121
+ image1, image2 = to_tensor(image), to_tensor(image2)
122
+ mixed_image = lam * image1 + (1 - lam) * image2
123
+
124
+ # Mix bounding boxes
125
+ mixed_boxes = torch.cat([lam * boxes, (1 - lam) * boxes2])
126
+
127
+ return to_pil_image(mixed_image), mixed_boxes
utils/dataloader.py CHANGED
@@ -10,7 +10,7 @@ from tqdm.rich import tqdm
10
  import diskcache as dc
11
  from typing import Union
12
  from drawer import draw_bboxes
13
- from dataargument import Compose, RandomHorizontalFlip, Mosaic, RandomVerticalFlip
14
 
15
 
16
  class YoloDataset(Dataset):
 
10
  import diskcache as dc
11
  from typing import Union
12
  from drawer import draw_bboxes
13
+ from data_augment import Compose, RandomHorizontalFlip, RandomVerticalFlip, Mosaic, MixUp
14
 
15
 
16
  class YoloDataset(Dataset):