|
import torch |
|
import torch.nn.functional as F |
|
from math import sin, cos, pi |
|
import numbers |
|
import random |
|
|
|
""" |
|
Data augmentation functions.

There are some problems with torchvision data augmentation functions:

1. they work only on PIL images, which means they cannot be applied to tensors with more than 3 channels,
   and they require a lot of conversion from Numpy -> PIL -> Tensor

2. they do not provide access to the internal transformations (affine matrices) used, which prevents
   applying them to more complex tasks, such as transformation of an optic flow field (for which
   the inverse transformation must be known).

For these reasons, we implement our own data augmentation functions
(strongly inspired by torchvision transforms) that operate directly
on Torch Tensor variables, and that also allow transforming an optic flow field.
|
""" |
|
|
|
|
|
class Compose(object):
    """Chain several transforms into a single callable.

    Args:
        transforms (list of ``Transform`` objects): transforms to apply, in order.
            Each one is called as ``transform(x, is_flow)``.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, x, is_flow=False):
        """Apply every transform in sequence and return the final result.

        The ``is_flow`` flag is forwarded unchanged to each transform.
        """
        result = x
        for transform in self.transforms:
            # The output of one transform is the input of the next one.
            result = transform(result, is_flow)
        return result

    def __repr__(self):
        # One line per transform, wrapped between "ClassName(" and ")".
        pieces = [self.__class__.__name__ + '(']
        pieces.extend('\n' + ' {0}'.format(transform) for transform in self.transforms)
        pieces.append('\n)')
        return ''.join(pieces)
|
|
|
|
|
class CenterCrop(object):
    """Crop the central region of a tensor to a given size.

    Args:
        size: either a single number (square crop) or a ``(height, width)`` pair.
        preserve_mosaicing_pattern: if True, odd crop offsets are increased by one
            so that the crop always starts on an even row and an even column.
    """

    def __init__(self, size, preserve_mosaicing_pattern=False):
        # A scalar size is expanded to a square (size, size) crop.
        self.size = (int(size), int(size)) if isinstance(size, numbers.Number) else size
        self.preserve_mosaicing_pattern = preserve_mosaicing_pattern

    def __call__(self, x, is_flow=False):
        """
        x: [C x H x W] Tensor to be cropped.
        is_flow: this parameter does not have any effect
        Returns:
            Tensor: Cropped tensor.
        """
        height, width = x.shape[1], x.shape[2]
        crop_h, crop_w = self.size

        # The requested crop must fit inside the input.
        assert(crop_h <= height)
        assert(crop_w <= width)

        top = int(round((height - crop_h) / 2.))
        left = int(round((width - crop_w) / 2.))

        if self.preserve_mosaicing_pattern:
            # Adds 1 to odd offsets and 0 to even ones.
            top += top % 2
            left += left % 2

        return x[:, top:top + crop_h, left:left + crop_w]

    def __repr__(self):
        return '{0}(size={1})'.format(self.__class__.__name__, self.size)
|
|
|
|
|
class RandomCrop(object):
    """Crop the tensor at a random location.

    Args:
        size: either a single number (square crop) or a ``(height, width)`` pair.
        preserve_mosaicing_pattern: if True, the crop offsets are forced to be even,
            so the crop always starts on an even row and an even column
            (presumably to keep a 2x2 mosaic pattern aligned -- confirm with callers).
    """

    def __init__(self, size, preserve_mosaicing_pattern=False):
        if isinstance(size, numbers.Number):
            # A scalar size is expanded to a square (size, size) crop.
            self.size = (int(size), int(size))
        else:
            self.size = size

        self.preserve_mosaicing_pattern = preserve_mosaicing_pattern

    @staticmethod
    def get_params(x, output_size):
        """Draw a random crop window of size ``output_size`` inside ``x``.

        Args:
            x: [C x H x W] Tensor that will be cropped.
            output_size: ``(height, width)`` of the crop.

        Returns:
            tuple: ``(i, j, h, w)`` -- top row, left column, height and width of the crop.
        """
        w, h = x.shape[2], x.shape[1]
        th, tw = output_size
        assert(th <= h)
        assert(tw <= w)
        if w == tw and h == th:
            # Nothing to choose: the crop covers the whole input.
            return 0, 0, h, w

        # random.randint is inclusive on both ends, so every valid offset can be drawn.
        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)

        return i, j, th, tw

    def __call__(self, x, is_flow=False):
        """
        x: [C x H x W] Tensor to be cropped.
        is_flow: this parameter does not have any effect
        Returns:
            Tensor: Cropped tensor, always of the requested size.
        """
        i, j, h, w = self.get_params(x, self.size)

        if self.preserve_mosaicing_pattern:
            # Snap odd offsets DOWN to the previous even value. Rounding up could
            # push the window past the border when the drawn offset is already the
            # largest valid one, and slicing would then silently return a crop
            # smaller than requested. An odd offset is >= 1, so offset - 1 is
            # always a valid, even position.
            i -= i % 2
            j -= j % 2

        return x[:, i:i + h, j:j + w]

    def __repr__(self):
        return self.__class__.__name__ + '(size={0})'.format(self.size)
|
|
|
|
|
class RandomRotationFlip(object):
    """Randomly rotate and flip a tensor, optionally treating it as a flow field.

    Args:
        degrees: either a single non-negative number ``d`` (the angle is drawn from
            ``[-d, d]``) or a sequence ``(min, max)`` of length 2.
        p_hflip: probability of negating the first column of the affine matrix.
        p_vflip: probability of negating the second column of the affine matrix.

    Raises:
        ValueError: if ``degrees`` is a negative number or a sequence whose
            length is not 2.
    """

    def __init__(self, degrees, p_hflip=0.5, p_vflip=0.5):
        if isinstance(degrees, numbers.Number):
            if degrees < 0:
                raise ValueError("If degrees is a single number, it must be positive.")
            self.degrees = (-degrees, degrees)
        else:
            if len(degrees) != 2:
                raise ValueError("If degrees is a sequence, it must be of len 2.")
            self.degrees = degrees

        self.p_hflip = p_hflip
        self.p_vflip = p_vflip

    @staticmethod
    def get_params(degrees, p_hflip, p_vflip):
        """Draw a random rotation/flip and return its affine matrix and the inverse.

        Args:
            degrees: ``(min, max)`` range (in degrees) the angle is drawn from.
            p_hflip: probability of negating the first matrix column.
            p_vflip: probability of negating the second matrix column.

        Returns:
            tuple: ``(M_original_transformed, M_transformed_original)``, two
            float32 tensors of shape [1 x 2 x 3]. The first is the matrix handed
            to ``F.affine_grid``; the second holds the top two rows of its inverse.
        """
        # Draw order (angle, hflip, vflip) is kept so seeded runs stay reproducible.
        angle = random.uniform(degrees[0], degrees[1])
        angle_rad = angle * pi / 180.0
        cos_a, sin_a = cos(angle_rad), sin(angle_rad)

        M_original_transformed = torch.tensor([[cos_a, -sin_a, 0.0],
                                               [sin_a, cos_a, 0.0],
                                               [0.0, 0.0, 1.0]], dtype=torch.float32)

        if random.random() < p_hflip:
            M_original_transformed[:, 0] *= -1

        if random.random() < p_vflip:
            M_original_transformed[:, 1] *= -1

        # Invert the full 3x3 homogeneous matrix before dropping the last row.
        M_transformed_original = torch.inverse(M_original_transformed)

        # Keep the top 2x3 part and add the batch dimension affine_grid expects.
        M_original_transformed = M_original_transformed[:2, :].unsqueeze(dim=0)
        M_transformed_original = M_transformed_original[:2, :].unsqueeze(dim=0)

        return M_original_transformed, M_transformed_original

    def __call__(self, x, is_flow=False):
        """
        x: [C x H x W] Tensor to be rotated.
        is_flow: if True, x is an [2 x H x W] displacement field, which will also be transformed
        Returns:
            Tensor: Rotated tensor, with the same shape as x.
        """
        assert(len(x.shape) == 3)

        if is_flow:
            assert(x.shape[0] == 2)

        M_original_transformed, M_transformed_original = self.get_params(
            self.degrees, self.p_hflip, self.p_vflip)

        # grid_sample requires the grid to share the input's dtype and device, so
        # move the (CPU, float32) matrix to wherever x lives before building the grid.
        if x.is_floating_point():
            theta = M_original_transformed.to(device=x.device, dtype=x.dtype)
        else:
            theta = M_original_transformed.to(device=x.device)

        batch = x.unsqueeze(dim=0)
        affine_grid = F.affine_grid(theta, batch.shape, align_corners=False)
        transformed = F.grid_sample(batch, affine_grid, align_corners=False)

        if is_flow:
            # Resampling only moves the vectors to their new location; their
            # components must also be mixed with the 2x2 linear part of the inverse.
            # NOTE(review): the matrix is expressed in affine_grid's normalized
            # coordinates -- confirm this is the intended convention for non-square inputs.
            A00 = M_transformed_original[0, 0, 0]
            A01 = M_transformed_original[0, 0, 1]
            A10 = M_transformed_original[0, 1, 0]
            A11 = M_transformed_original[0, 1, 1]
            # Clone first: both outputs depend on both original components.
            vx = transformed[:, 0, :, :].clone()
            vy = transformed[:, 1, :, :].clone()
            transformed[:, 0, :, :] = A00 * vx + A01 * vy
            transformed[:, 1, :, :] = A10 * vx + A11 * vy

        return transformed.squeeze(dim=0)

    def __repr__(self):
        format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees)
        format_string += ', p_hflip={:.2f}'.format(self.p_hflip)
        format_string += ', p_vflip={:.2f}'.format(self.p_vflip)
        format_string += ')'
        return format_string
|
|