File size: 624 Bytes
ed697ed |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
import numpy as np
import torch
class AgeTransformer(object):
    """Append a constant "target age" channel to an image tensor.

    The age is divided by 100 and broadcast over the spatial extent of the
    image as one extra channel, so a ``(C, H, W)`` input becomes
    ``(C + 1, H, W)``.

    Args:
        target_age: Either an age value accepted by ``int()`` (e.g. ``45`` or
            ``"45"``) or the string ``"uniform_random"``, in which case a new
            integer age in ``[0, 100]`` is drawn on every call.
    """

    def __init__(self, target_age):
        self.target_age = target_age

    def __call__(self, img):
        img = self.add_aging_channel(img)
        return img

    def add_aging_channel(self, img):
        """Return ``img`` with a constant age channel concatenated on dim 0.

        Args:
            img: A 3-D ``(C, H, W)`` tensor (dims 1 and 2 give the spatial
                size of the new channel, which is concatenated along dim 0).

        Returns:
            A ``(C + 1, H, W)`` tensor whose last channel is ``int(age) / 100``.
        """
        target_age = self.__get_target_age()
        # Ages in [0, 100] map to [0, 1]. int() also accepts numeric strings
        # (and truncates fractional ages).
        target_age = int(target_age) / 100
        # Build the channel on the input's device, and with the input's dtype
        # for floating images, so torch.cat works for CUDA / half tensors.
        # Non-float images keep the original behaviour (default float dtype).
        if img.is_floating_point():
            fill_dtype = img.dtype
        else:
            fill_dtype = torch.get_default_dtype()
        age_channel = torch.full(
            (1, img.shape[1], img.shape[2]),
            target_age,
            dtype=fill_dtype,
            device=img.device,
        )
        return torch.cat((img, age_channel))

    def __get_target_age(self):
        """Return the configured age, or draw one uniformly from [0, 100]."""
        if self.target_age == "uniform_random":
            # ``high`` is exclusive, so 101 makes an age of 100 reachable.
            return np.random.randint(low=0, high=101)
        return self.target_age
|