Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn.functional as F | |
from math import sin, cos, pi | |
import numbers | |
import random | |
"""
Data augmentation functions.
There are some problems with torchvision data augmentation functions:
1. they work only on PIL images, which means they cannot be applied to tensors with more than 3 channels,
and they require a lot of conversion from Numpy -> PIL -> Tensor
2. they do not provide access to the internal transformations (affine matrices) used, which prevents
applying them to more complex tasks, such as the transformation of an optic flow field (for which
the inverse transformation must be known).
For these reasons, we implement our own data augmentation functions
(strongly inspired by torchvision transforms) that operate directly
on Torch Tensor variables, and that also allow transforming an optic flow field.
"""
class Compose(object):
    """Chain several transforms into a single callable.

    Args:
        transforms (list of ``Transform`` objects): transforms to apply,
            in order. Each one is called as ``t(x, is_flow)``.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, x, is_flow=False):
        """Feed ``x`` through every transform in order and return the result.

        ``is_flow`` is forwarded unchanged to each transform.
        """
        result = x
        for transform in self.transforms:
            result = transform(result, is_flow)
        return result

    def __repr__(self):
        # One line per transform, wrapped in "ClassName(" ... ")".
        body = ''.join('\n' + ' {0}'.format(transform)
                       for transform in self.transforms)
        return self.__class__.__name__ + '(' + body + '\n)'
class CenterCrop(object):
    """Crop a tensor around its center to a fixed size.

    Args:
        size: a single number (square crop) or a ``(height, width)`` pair.
        preserve_mosaicing_pattern: if True, odd crop offsets are bumped to
            the next even index so the crop starts on an even row/column.
    """

    def __init__(self, size, preserve_mosaicing_pattern=False):
        # A scalar size means a square crop; anything else is stored as given.
        self.size = (int(size), int(size)) if isinstance(size, numbers.Number) else size
        self.preserve_mosaicing_pattern = preserve_mosaicing_pattern

    def __call__(self, x, is_flow=False):
        """
        x: [C x H x W] Tensor to be cropped.
        is_flow: this parameter does not have any effect
        Returns:
            Tensor: Cropped tensor.
        """
        height, width = x.shape[1], x.shape[2]
        crop_h, crop_w = self.size

        assert(crop_h <= height)
        assert(crop_w <= width)

        top = int(round((height - crop_h) / 2.))
        left = int(round((width - crop_w) / 2.))

        if self.preserve_mosaicing_pattern:
            # Adding the parity moves an odd offset to the next even one
            # and leaves an even offset untouched.
            top += top % 2
            left += left % 2

        return x[:, top:top + crop_h, left:left + crop_w]

    def __repr__(self):
        return '{0}(size={1})'.format(self.__class__.__name__, self.size)
class RandomCrop(object):
    """Crop the tensor at a random location.

    Args:
        size: a single number (square crop) or a ``(height, width)`` pair.
        preserve_mosaicing_pattern: if True, the crop's top-left corner is
            snapped to even row/column indices.
    """

    def __init__(self, size, preserve_mosaicing_pattern=False):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.preserve_mosaicing_pattern = preserve_mosaicing_pattern

    # Must be a staticmethod: the signature has no ``self`` and ``__call__``
    # invokes it as ``self.get_params(x, self.size)``.
    @staticmethod
    def get_params(x, output_size):
        """Draw a random crop window for ``x``.

        Args:
            x: [C x H x W] tensor to be cropped.
            output_size: ``(height, width)`` of the requested crop.

        Returns:
            tuple: ``(i, j, h, w)`` -- top row, left column, crop height and
            crop width.

        Raises:
            AssertionError: if the requested crop is larger than ``x``.
        """
        w, h = x.shape[2], x.shape[1]
        th, tw = output_size
        assert(th <= h)
        assert(tw <= w)
        if w == tw and h == th:
            return 0, 0, h, w
        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw

    def __call__(self, x, is_flow=False):
        """
        x: [C x H x W] Tensor to be cropped.
        is_flow: this parameter does not have any effect
        Returns:
            Tensor: Cropped tensor.
        """
        i, j, h, w = self.get_params(x, self.size)
        if self.preserve_mosaicing_pattern:
            # Make sure that i and j are even, to preserve the mosaicing
            # pattern. Odd offsets are snapped DOWN: rounding up could push
            # the window past the tensor border when the offset is already
            # maximal, silently returning a crop smaller than requested.
            i -= i % 2
            j -= j % 2
        return x[:, i:i + h, j:j + w]

    def __repr__(self):
        return self.__class__.__name__ + '(size={0})'.format(self.size)
class RandomRotationFlip(object):
    """Randomly rotate, and optionally flip, a [C x H x W] tensor.

    Args:
        degrees: a single non-negative number ``d`` (the angle is drawn from
            ``[-d, d]``) or a ``(min, max)`` pair of angles, in degrees.
        p_hflip: probability of negating the first column of the affine matrix
            (the x axis of the sampling grid).
        p_vflip: probability of negating the second column of the affine matrix
            (the y axis of the sampling grid).

    Raises:
        ValueError: if ``degrees`` is a negative number or a sequence whose
            length is not 2.
    """

    def __init__(self, degrees, p_hflip=0.5, p_vflip=0.5):
        if isinstance(degrees, numbers.Number):
            if degrees < 0:
                raise ValueError("If degrees is a single number, it must be positive.")
            self.degrees = (-degrees, degrees)
        else:
            if len(degrees) != 2:
                raise ValueError("If degrees is a sequence, it must be of len 2.")
            self.degrees = degrees
        self.p_hflip = p_hflip
        self.p_vflip = p_vflip

    # Must be a staticmethod: the signature has no ``self`` and ``__call__``
    # invokes it through the instance.
    @staticmethod
    def get_params(degrees, p_hflip, p_vflip):
        """Draw a random rotation/flip and return it with its inverse.

        Returns:
            tuple: ``(M_original_transformed, M_transformed_original)``, two
            float32 tensors of shape [1 x 2 x 3]. The first is the affine
            matrix passed to ``F.affine_grid``; the second is the top two rows
            of its inverse.
        """
        angle = random.uniform(degrees[0], degrees[1])
        angle_rad = angle * pi / 180.0
        M_original_transformed = torch.tensor(
            [[cos(angle_rad), -sin(angle_rad), 0],
             [sin(angle_rad), cos(angle_rad), 0],
             [0, 0, 1]],
            dtype=torch.float32)
        if random.random() < p_hflip:
            M_original_transformed[:, 0] *= -1
        if random.random() < p_vflip:
            M_original_transformed[:, 1] *= -1
        # Invert the full 3 x 3 matrix before dropping its last row.
        M_transformed_original = torch.inverse(M_original_transformed)
        M_original_transformed = M_original_transformed[:2, :].unsqueeze(dim=0)  # 3 x 3 -> N x 2 x 3
        M_transformed_original = M_transformed_original[:2, :].unsqueeze(dim=0)
        return M_original_transformed, M_transformed_original

    def __call__(self, x, is_flow=False):
        """
        x: [C x H x W] Tensor to be rotated.
        is_flow: if True, x is a [2 x H x W] displacement field, whose vectors will also be transformed
        Returns:
            Tensor: Rotated tensor.
        """
        assert(len(x.shape) == 3)
        if is_flow:
            assert(x.shape[0] == 2)
        M_original_transformed, M_transformed_original = self.get_params(
            self.degrees, self.p_hflip, self.p_vflip)

        # grid_sample needs the grid on the same device and with the same
        # dtype as its input, so align the affine matrix with x first.
        if x.is_floating_point():
            M_original_transformed = M_original_transformed.to(device=x.device, dtype=x.dtype)
        else:
            M_original_transformed = M_original_transformed.to(device=x.device)

        batched = x.unsqueeze(dim=0)  # C x H x W -> 1 x C x H x W
        affine_grid = F.affine_grid(M_original_transformed, batched.shape, align_corners=False)
        transformed = F.grid_sample(batched, affine_grid, align_corners=False)

        if is_flow:
            # The field has been resampled; now apply the linear part of the
            # inverse matrix to the vector stored at every pixel.
            A00 = M_transformed_original[0, 0, 0]
            A01 = M_transformed_original[0, 0, 1]
            A10 = M_transformed_original[0, 1, 0]
            A11 = M_transformed_original[0, 1, 1]
            # Clone so the second assignment still reads the un-mixed vx.
            vx = transformed[:, 0, :, :].clone()
            vy = transformed[:, 1, :, :].clone()
            transformed[:, 0, :, :] = A00 * vx + A01 * vy
            transformed[:, 1, :, :] = A10 * vx + A11 * vy
        return transformed.squeeze(dim=0)

    def __repr__(self):
        format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees)
        format_string += ', p_hflip={:.2f}'.format(self.p_hflip)
        format_string += ', p_vflip={:.2f}'.format(self.p_vflip)
        format_string += ')'
        return format_string