zzzzzeee commited on
Commit
4bacde3
·
verified ·
1 Parent(s): 10ee601

Create data_augmentation.py

Browse files
Files changed (1) hide show
  1. data_augmentation.py +220 -0
data_augmentation.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from math import sin, cos, pi
4
+ import numbers
5
+ import random
6
+
7
+ """
8
+ Data augmentation functions.
9
+
10
+ There are some problems with torchvision data augmentation functions:
11
+ 1. they work only on PIL images, which means they cannot be applied to tensors with more than 3 channels,
12
+ and they require a lot of conversion from Numpy -> PIL -> Tensor
13
+
14
+ 2. they do not provide access to the internal transformations (affine matrices) used, which prevent
15
+ applying them for more complex tasks, such as transformation of an optic flow field (for which
16
+ the inverse transformation must be known).
17
+
18
+ For these reasons, we implement my own data augmentation functions
19
+ (strongly inspired by torchvision transforms) that operate directly
20
+ on Torch Tensor variables, and that allow to transform an optic flow field as well.
21
+ """
22
+
23
+
24
+ class Compose(object):
25
+ """Composes several transforms together.
26
+ Args:
27
+ transforms (list of ``Transform`` objects): list of transforms to compose.
28
+ Example:
29
+ >>> transforms.Compose([
30
+ >>> transforms.CenterCrop(10),
31
+ >>> transforms.ToTensor(),
32
+ >>> ])
33
+ """
34
+
35
+ def __init__(self, transforms):
36
+ self.transforms = transforms
37
+
38
+ def __call__(self, x, is_flow=False):
39
+ for t in self.transforms:
40
+ x = t(x, is_flow)
41
+ return x
42
+
43
+ def __repr__(self):
44
+ format_string = self.__class__.__name__ + '('
45
+ for t in self.transforms:
46
+ format_string += '\n'
47
+ format_string += ' {0}'.format(t)
48
+ format_string += '\n)'
49
+ return format_string
50
+
51
+
52
+ class CenterCrop(object):
53
+ """Center crop the tensor to a certain size.
54
+ """
55
+
56
+ def __init__(self, size, preserve_mosaicing_pattern=False):
57
+ if isinstance(size, numbers.Number):
58
+ self.size = (int(size), int(size))
59
+ else:
60
+ self.size = size
61
+
62
+ self.preserve_mosaicing_pattern = preserve_mosaicing_pattern
63
+
64
+ def __call__(self, x, is_flow=False):
65
+ """
66
+ x: [C x H x W] Tensor to be rotated.
67
+ is_flow: this parameter does not have any effect
68
+ Returns:
69
+ Tensor: Cropped tensor.
70
+ """
71
+
72
+ w, h = x.shape[2], x.shape[1]
73
+ th, tw = self.size
74
+ # print(tw, w)
75
+ # print(th, h)
76
+ assert(th <= h)
77
+ assert(tw <= w)
78
+ i = int(round((h - th) / 2.))
79
+ j = int(round((w - tw) / 2.))
80
+
81
+ if self.preserve_mosaicing_pattern:
82
+ # make sure that i and j are even, to preserve
83
+ # the mosaicing pattern
84
+ if i % 2 == 1:
85
+ i = i + 1
86
+ if j % 2 == 1:
87
+ j = j + 1
88
+
89
+ return x[:, i:i + th, j:j + tw]
90
+
91
+ def __repr__(self):
92
+ return self.__class__.__name__ + '(size={0})'.format(self.size)
93
+
94
+
95
+ class RandomCrop(object):
96
+ """Crop the tensor at a random location.
97
+ """
98
+
99
+ def __init__(self, size, preserve_mosaicing_pattern=False):
100
+ if isinstance(size, numbers.Number):
101
+ self.size = (int(size), int(size))
102
+ else:
103
+ self.size = size
104
+
105
+ self.preserve_mosaicing_pattern = preserve_mosaicing_pattern
106
+
107
+ @staticmethod
108
+ def get_params(x, output_size):
109
+ w, h = x.shape[2], x.shape[1]
110
+ th, tw = output_size
111
+ assert(th <= h)
112
+ assert(tw <= w)
113
+ if w == tw and h == th:
114
+ return 0, 0, h, w
115
+
116
+ i = random.randint(0, h - th)
117
+ j = random.randint(0, w - tw)
118
+
119
+ return i, j, th, tw
120
+
121
+ def __call__(self, x, is_flow=False):
122
+ """
123
+ x: [C x H x W] Tensor to be rotated.
124
+ is_flow: this parameter does not have any effect
125
+ Returns:
126
+ Tensor: Cropped tensor.
127
+ """
128
+ i, j, h, w = self.get_params(x, self.size)
129
+
130
+ if self.preserve_mosaicing_pattern:
131
+ # make sure that i and j are even, to preserve the mosaicing pattern
132
+ if i % 2 == 1:
133
+ i = i + 1
134
+ if j % 2 == 1:
135
+ j = j + 1
136
+
137
+ return x[:, i:i + h, j:j + w]
138
+
139
+ def __repr__(self):
140
+ return self.__class__.__name__ + '(size={0})'.format(self.size)
141
+
142
+
143
+ class RandomRotationFlip(object):
144
+ """Rotate the image by angle.
145
+ """
146
+
147
+ def __init__(self, degrees, p_hflip=0.5, p_vflip=0.5):
148
+ if isinstance(degrees, numbers.Number):
149
+ if degrees < 0:
150
+ raise ValueError("If degrees is a single number, it must be positive.")
151
+ self.degrees = (-degrees, degrees)
152
+ else:
153
+ if len(degrees) != 2:
154
+ raise ValueError("If degrees is a sequence, it must be of len 2.")
155
+ self.degrees = degrees
156
+
157
+ self.p_hflip = p_hflip
158
+ self.p_vflip = p_vflip
159
+
160
+ @staticmethod
161
+ def get_params(degrees, p_hflip, p_vflip):
162
+ """Get parameters for ``rotate`` for a random rotation.
163
+ Returns:
164
+ sequence: params to be passed to ``rotate`` for random rotation.
165
+ """
166
+ angle = random.uniform(degrees[0], degrees[1])
167
+ angle_rad = angle * pi / 180.0
168
+
169
+ M_original_transformed = torch.FloatTensor([[cos(angle_rad), -sin(angle_rad), 0],
170
+ [sin(angle_rad), cos(angle_rad), 0],
171
+ [0, 0, 1]])
172
+
173
+ if random.random() < p_hflip:
174
+ M_original_transformed[:, 0] *= -1
175
+
176
+ if random.random() < p_vflip:
177
+ M_original_transformed[:, 1] *= -1
178
+
179
+ M_transformed_original = torch.inverse(M_original_transformed)
180
+
181
+ M_original_transformed = M_original_transformed[:2, :].unsqueeze(dim=0) # 3 x 3 -> N x 2 x 3
182
+ M_transformed_original = M_transformed_original[:2, :].unsqueeze(dim=0)
183
+
184
+ return M_original_transformed, M_transformed_original
185
+
186
+ def __call__(self, x, is_flow=False):
187
+ """
188
+ x: [C x H x W] Tensor to be rotated.
189
+ is_flow: if True, x is an [2 x H x W] displacement field, which will also be transformed
190
+ Returns:
191
+ Tensor: Rotated tensor.
192
+ """
193
+ assert(len(x.shape) == 3)
194
+
195
+ if is_flow:
196
+ assert(x.shape[0] == 2)
197
+
198
+ M_original_transformed, M_transformed_original = self.get_params(self.degrees, self.p_hflip, self.p_vflip)
199
+ affine_grid = F.affine_grid(M_original_transformed, x.unsqueeze(dim=0).shape, align_corners=False)
200
+ transformed = F.grid_sample(x.unsqueeze(dim=0), affine_grid, align_corners=False)
201
+
202
+ if is_flow:
203
+ # Apply the same transformation to the flow field
204
+ A00 = M_transformed_original[0, 0, 0]
205
+ A01 = M_transformed_original[0, 0, 1]
206
+ A10 = M_transformed_original[0, 1, 0]
207
+ A11 = M_transformed_original[0, 1, 1]
208
+ vx = transformed[:, 0, :, :].clone()
209
+ vy = transformed[:, 1, :, :].clone()
210
+ transformed[:, 0, :, :] = A00 * vx + A01 * vy
211
+ transformed[:, 1, :, :] = A10 * vx + A11 * vy
212
+
213
+ return transformed.squeeze(dim=0)
214
+
215
+ def __repr__(self):
216
+ format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees)
217
+ format_string += ', p_flip={:.2f}'.format(self.p_hflip)
218
+ format_string += ', p_vlip={:.2f}'.format(self.p_vflip)
219
+ format_string += ')'
220
+ return format_string