Spaces:
Runtime error
Runtime error
Create data_augmentation.py
Browse files- data_augmentation.py +220 -0
data_augmentation.py
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from math import sin, cos, pi
|
4 |
+
import numbers
|
5 |
+
import random
|
6 |
+
|
7 |
+
"""
|
8 |
+
Data augmentation functions.
|
9 |
+
|
10 |
+
There are some problems with torchvision data augmentation functions:
|
11 |
+
1. they work only on PIL images, which means they cannot be applied to tensors with more than 3 channels,
|
12 |
+
and they require a lot of conversion from Numpy -> PIL -> Tensor
|
13 |
+
|
14 |
+
2. they do not provide access to the internal transformations (affine matrices) used, which prevent
|
15 |
+
applying them for more complex tasks, such as transformation of an optic flow field (for which
|
16 |
+
the inverse transformation must be known).
|
17 |
+
|
18 |
+
For these reasons, we implement my own data augmentation functions
|
19 |
+
(strongly inspired by torchvision transforms) that operate directly
|
20 |
+
on Torch Tensor variables, and that allow to transform an optic flow field as well.
|
21 |
+
"""
|
22 |
+
|
23 |
+
|
24 |
+
class Compose(object):
|
25 |
+
"""Composes several transforms together.
|
26 |
+
Args:
|
27 |
+
transforms (list of ``Transform`` objects): list of transforms to compose.
|
28 |
+
Example:
|
29 |
+
>>> transforms.Compose([
|
30 |
+
>>> transforms.CenterCrop(10),
|
31 |
+
>>> transforms.ToTensor(),
|
32 |
+
>>> ])
|
33 |
+
"""
|
34 |
+
|
35 |
+
def __init__(self, transforms):
|
36 |
+
self.transforms = transforms
|
37 |
+
|
38 |
+
def __call__(self, x, is_flow=False):
|
39 |
+
for t in self.transforms:
|
40 |
+
x = t(x, is_flow)
|
41 |
+
return x
|
42 |
+
|
43 |
+
def __repr__(self):
|
44 |
+
format_string = self.__class__.__name__ + '('
|
45 |
+
for t in self.transforms:
|
46 |
+
format_string += '\n'
|
47 |
+
format_string += ' {0}'.format(t)
|
48 |
+
format_string += '\n)'
|
49 |
+
return format_string
|
50 |
+
|
51 |
+
|
52 |
+
class CenterCrop(object):
|
53 |
+
"""Center crop the tensor to a certain size.
|
54 |
+
"""
|
55 |
+
|
56 |
+
def __init__(self, size, preserve_mosaicing_pattern=False):
|
57 |
+
if isinstance(size, numbers.Number):
|
58 |
+
self.size = (int(size), int(size))
|
59 |
+
else:
|
60 |
+
self.size = size
|
61 |
+
|
62 |
+
self.preserve_mosaicing_pattern = preserve_mosaicing_pattern
|
63 |
+
|
64 |
+
def __call__(self, x, is_flow=False):
|
65 |
+
"""
|
66 |
+
x: [C x H x W] Tensor to be rotated.
|
67 |
+
is_flow: this parameter does not have any effect
|
68 |
+
Returns:
|
69 |
+
Tensor: Cropped tensor.
|
70 |
+
"""
|
71 |
+
|
72 |
+
w, h = x.shape[2], x.shape[1]
|
73 |
+
th, tw = self.size
|
74 |
+
# print(tw, w)
|
75 |
+
# print(th, h)
|
76 |
+
assert(th <= h)
|
77 |
+
assert(tw <= w)
|
78 |
+
i = int(round((h - th) / 2.))
|
79 |
+
j = int(round((w - tw) / 2.))
|
80 |
+
|
81 |
+
if self.preserve_mosaicing_pattern:
|
82 |
+
# make sure that i and j are even, to preserve
|
83 |
+
# the mosaicing pattern
|
84 |
+
if i % 2 == 1:
|
85 |
+
i = i + 1
|
86 |
+
if j % 2 == 1:
|
87 |
+
j = j + 1
|
88 |
+
|
89 |
+
return x[:, i:i + th, j:j + tw]
|
90 |
+
|
91 |
+
def __repr__(self):
|
92 |
+
return self.__class__.__name__ + '(size={0})'.format(self.size)
|
93 |
+
|
94 |
+
|
95 |
+
class RandomCrop(object):
|
96 |
+
"""Crop the tensor at a random location.
|
97 |
+
"""
|
98 |
+
|
99 |
+
def __init__(self, size, preserve_mosaicing_pattern=False):
|
100 |
+
if isinstance(size, numbers.Number):
|
101 |
+
self.size = (int(size), int(size))
|
102 |
+
else:
|
103 |
+
self.size = size
|
104 |
+
|
105 |
+
self.preserve_mosaicing_pattern = preserve_mosaicing_pattern
|
106 |
+
|
107 |
+
@staticmethod
|
108 |
+
def get_params(x, output_size):
|
109 |
+
w, h = x.shape[2], x.shape[1]
|
110 |
+
th, tw = output_size
|
111 |
+
assert(th <= h)
|
112 |
+
assert(tw <= w)
|
113 |
+
if w == tw and h == th:
|
114 |
+
return 0, 0, h, w
|
115 |
+
|
116 |
+
i = random.randint(0, h - th)
|
117 |
+
j = random.randint(0, w - tw)
|
118 |
+
|
119 |
+
return i, j, th, tw
|
120 |
+
|
121 |
+
def __call__(self, x, is_flow=False):
|
122 |
+
"""
|
123 |
+
x: [C x H x W] Tensor to be rotated.
|
124 |
+
is_flow: this parameter does not have any effect
|
125 |
+
Returns:
|
126 |
+
Tensor: Cropped tensor.
|
127 |
+
"""
|
128 |
+
i, j, h, w = self.get_params(x, self.size)
|
129 |
+
|
130 |
+
if self.preserve_mosaicing_pattern:
|
131 |
+
# make sure that i and j are even, to preserve the mosaicing pattern
|
132 |
+
if i % 2 == 1:
|
133 |
+
i = i + 1
|
134 |
+
if j % 2 == 1:
|
135 |
+
j = j + 1
|
136 |
+
|
137 |
+
return x[:, i:i + h, j:j + w]
|
138 |
+
|
139 |
+
def __repr__(self):
|
140 |
+
return self.__class__.__name__ + '(size={0})'.format(self.size)
|
141 |
+
|
142 |
+
|
143 |
+
class RandomRotationFlip(object):
|
144 |
+
"""Rotate the image by angle.
|
145 |
+
"""
|
146 |
+
|
147 |
+
def __init__(self, degrees, p_hflip=0.5, p_vflip=0.5):
|
148 |
+
if isinstance(degrees, numbers.Number):
|
149 |
+
if degrees < 0:
|
150 |
+
raise ValueError("If degrees is a single number, it must be positive.")
|
151 |
+
self.degrees = (-degrees, degrees)
|
152 |
+
else:
|
153 |
+
if len(degrees) != 2:
|
154 |
+
raise ValueError("If degrees is a sequence, it must be of len 2.")
|
155 |
+
self.degrees = degrees
|
156 |
+
|
157 |
+
self.p_hflip = p_hflip
|
158 |
+
self.p_vflip = p_vflip
|
159 |
+
|
160 |
+
@staticmethod
|
161 |
+
def get_params(degrees, p_hflip, p_vflip):
|
162 |
+
"""Get parameters for ``rotate`` for a random rotation.
|
163 |
+
Returns:
|
164 |
+
sequence: params to be passed to ``rotate`` for random rotation.
|
165 |
+
"""
|
166 |
+
angle = random.uniform(degrees[0], degrees[1])
|
167 |
+
angle_rad = angle * pi / 180.0
|
168 |
+
|
169 |
+
M_original_transformed = torch.FloatTensor([[cos(angle_rad), -sin(angle_rad), 0],
|
170 |
+
[sin(angle_rad), cos(angle_rad), 0],
|
171 |
+
[0, 0, 1]])
|
172 |
+
|
173 |
+
if random.random() < p_hflip:
|
174 |
+
M_original_transformed[:, 0] *= -1
|
175 |
+
|
176 |
+
if random.random() < p_vflip:
|
177 |
+
M_original_transformed[:, 1] *= -1
|
178 |
+
|
179 |
+
M_transformed_original = torch.inverse(M_original_transformed)
|
180 |
+
|
181 |
+
M_original_transformed = M_original_transformed[:2, :].unsqueeze(dim=0) # 3 x 3 -> N x 2 x 3
|
182 |
+
M_transformed_original = M_transformed_original[:2, :].unsqueeze(dim=0)
|
183 |
+
|
184 |
+
return M_original_transformed, M_transformed_original
|
185 |
+
|
186 |
+
def __call__(self, x, is_flow=False):
|
187 |
+
"""
|
188 |
+
x: [C x H x W] Tensor to be rotated.
|
189 |
+
is_flow: if True, x is an [2 x H x W] displacement field, which will also be transformed
|
190 |
+
Returns:
|
191 |
+
Tensor: Rotated tensor.
|
192 |
+
"""
|
193 |
+
assert(len(x.shape) == 3)
|
194 |
+
|
195 |
+
if is_flow:
|
196 |
+
assert(x.shape[0] == 2)
|
197 |
+
|
198 |
+
M_original_transformed, M_transformed_original = self.get_params(self.degrees, self.p_hflip, self.p_vflip)
|
199 |
+
affine_grid = F.affine_grid(M_original_transformed, x.unsqueeze(dim=0).shape, align_corners=False)
|
200 |
+
transformed = F.grid_sample(x.unsqueeze(dim=0), affine_grid, align_corners=False)
|
201 |
+
|
202 |
+
if is_flow:
|
203 |
+
# Apply the same transformation to the flow field
|
204 |
+
A00 = M_transformed_original[0, 0, 0]
|
205 |
+
A01 = M_transformed_original[0, 0, 1]
|
206 |
+
A10 = M_transformed_original[0, 1, 0]
|
207 |
+
A11 = M_transformed_original[0, 1, 1]
|
208 |
+
vx = transformed[:, 0, :, :].clone()
|
209 |
+
vy = transformed[:, 1, :, :].clone()
|
210 |
+
transformed[:, 0, :, :] = A00 * vx + A01 * vy
|
211 |
+
transformed[:, 1, :, :] = A10 * vx + A11 * vy
|
212 |
+
|
213 |
+
return transformed.squeeze(dim=0)
|
214 |
+
|
215 |
+
def __repr__(self):
|
216 |
+
format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees)
|
217 |
+
format_string += ', p_flip={:.2f}'.format(self.p_hflip)
|
218 |
+
format_string += ', p_vlip={:.2f}'.format(self.p_vflip)
|
219 |
+
format_string += ')'
|
220 |
+
return format_string
|