import torch from torch import Tensor, nn import torch.nn.functional as F import torchvision from torchvision import transforms from PIL import Image import numpy as np import matplotlib.pyplot as plt from sklearn.decomposition import PCA class RandomAffineAndRetMat(torch.nn.Module): def __init__( self, degrees, translate=None, scale=None, shear=None, interpolation=torchvision.transforms.InterpolationMode.NEAREST, fill=0, center=None, ): super().__init__() self.degrees = degrees self.translate = translate self.scale = scale self.shear = shear self.interpolation = interpolation self.fill = fill self.center = center def forward(self, img): """ img (PIL Image or Tensor): Image to be transformed. Returns: PIL Image or Tensor: Affine transformed image. """ fill = self.fill if isinstance(img, Tensor): if isinstance(fill, (int, float)): fill = [float(fill)] * transforms.functional.get_image_num_channels(img) else: fill = [float(f) for f in fill] img_size = transforms.functional.get_image_size(img) ret = transforms.RandomAffine.get_params(self.degrees, self.translate, self.scale, self.shear, img_size) transformed_image = transforms.functional.affine(img, *ret, interpolation=self.interpolation, fill=fill, center=self.center) affine_matrix = self.get_affine_matrix_from_params(ret) return transformed_image, affine_matrix def get_affine_matrix_from_params(self, params): degrees, translate, scale, shear = params degrees = torch.tensor(degrees) shear = torch.tensor(shear) # パラメータを変換行列に変換 rotation_matrix = torch.tensor([[torch.cos(torch.deg2rad(degrees)), -torch.sin(torch.deg2rad(degrees)), 0], [torch.sin(torch.deg2rad(degrees)), torch.cos(torch.deg2rad(degrees)), 0], [0, 0, 1]]) translation_matrix = torch.tensor([[1, 0, translate[0]], [0, 1, translate[1]], [0, 0, 1]]).to(torch.float32) scaling_matrix = torch.tensor([[scale, 0, 0], [0, scale, 0], [0, 0, 1]]) shearing_matrix = torch.tensor([[1, -torch.tan(torch.deg2rad(shear[0])), 0], [-torch.tan(torch.deg2rad(shear[1])), 1, 0], [0, 0, 1]]) # 変換行列を合成 affine_matrix = translation_matrix.mm(rotation_matrix).mm(scaling_matrix).mm(shearing_matrix) return affine_matrix