import torch
from torch import nn
import kornia.augmentation as K


class ImageAugmentations(nn.Module):
    def __init__(self, output_size, augmentations_number, p=0.7):
        super().__init__()
        self.output_size = output_size
        self.augmentations_number = augmentations_number

        self.augmentations = nn.Sequential(
            K.RandomAffine(degrees=15, translate=0.1, p=p, padding_mode="border"),  # type: ignore
            K.RandomPerspective(0.7, p=p),
        )

        self.avg_pool = nn.AdaptiveAvgPool2d((self.output_size, self.output_size))

    def forward(self, input):
        """Extends the input batch with augmentations

        If the input consists of images [I1, I2], the extended augmented output
        will be [I1_resized, I2_resized, I1_aug1, I2_aug1, I1_aug2, I2_aug2, ...]

        Args:
            input ([type]): input batch of shape [batch, C, H, W]

        Returns:
            updated batch: of shape [batch * augmentations_number, C, H, W]
        """
        # We want to multiply the number of images in the batch (in contrast to regular
        # augmentations that do not change the number of samples in the batch).
        resized_images = self.avg_pool(input)
        resized_images = torch.tile(resized_images, dims=(self.augmentations_number, 1, 1, 1))

        batch_size = input.shape[0]
        # We want at least one non-augmented image
        non_augmented_batch = resized_images[:batch_size]
        augmented_batch = self.augmentations(resized_images[batch_size:])
        updated_batch = torch.cat([non_augmented_batch, augmented_batch], dim=0)

        return updated_batch
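

# A minimal usage sketch, not part of the original file: the values output_size=224,
# augmentations_number=8, and the dummy batch shape are illustrative assumptions.
# The module resizes the batch and stacks (augmentations_number - 1) augmented copies
# after the resized originals, so the output batch grows by augmentations_number.
if __name__ == "__main__":
    augmenter = ImageAugmentations(output_size=224, augmentations_number=8)
    dummy_batch = torch.randn(2, 3, 256, 256)  # [batch, C, H, W]
    out = augmenter(dummy_batch)
    print(out.shape)  # expected: torch.Size([16, 3, 224, 224])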