import numpy as np | |
import torch | |
class AgeTransformer(object):
    """Append a constant "target age" channel to an image tensor.

    The extra channel is filled with ``int(target_age) / 100``, so ages in
    ``[0, 100]`` map to values in ``[0, 1]``.

    Args:
        target_age: An age (anything ``int()`` accepts, e.g. ``45`` or
            ``"45"``), or the string ``"uniform_random"`` to draw a fresh
            integer age in ``[0, 100]`` on every call.
    """

    def __init__(self, target_age):
        self.target_age = target_age

    def __call__(self, img):
        img = self.add_aging_channel(img)
        return img

    def add_aging_channel(self, img):
        """Return ``img`` with one age channel concatenated after its channels.

        Args:
            img: Tensor shaped ``(..., C, H, W)``; the third-from-last
                dimension is treated as the channel axis.

        Returns:
            Tensor shaped ``(..., C + 1, H, W)`` whose last channel is filled
            with the normalized target age.
        """
        target_age = self.__get_target_age()
        # Truncate to an integer age, then scale: ages in [0, 100] -> [0, 1].
        target_age = int(target_age) / 100
        # Build the channel on img's device so torch.cat does not hit a
        # CPU/GPU mismatch, and in img's dtype when it is floating point so
        # the output keeps the input's precision. Non-floating images get the
        # default float dtype so the fractional age is not truncated.
        if img.is_floating_point():
            dtype = img.dtype
        else:
            dtype = torch.get_default_dtype()
        channel_shape = (*img.shape[:-3], 1, *img.shape[-2:])
        age_channel = torch.full(
            channel_shape, target_age, dtype=dtype, device=img.device
        )
        img = torch.cat((img, age_channel), dim=-3)
        return img

    def __get_target_age(self):
        """Return the configured age, or a random integer in [0, 100]."""
        if self.target_age == "uniform_random":
            # randint's upper bound is exclusive, so high=101 makes 100 reachable.
            return np.random.randint(low=0, high=101, size=1)[0]
        return self.target_age