File size: 624 Bytes
ed697ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import numpy as np
import torch


class AgeTransformer(object):
	"""Append a constant "target age" channel to a (C, H, W) image tensor.

	The extra channel is filled with ``int(target_age) / 100``; for ages in
	[0, 100] this places the conditioning value in [0, 1].

	Args:
		target_age: Either a value accepted by ``int()`` (e.g. ``30`` or
			``"30"``), used on every call, or the string ``"uniform_random"``,
			in which case a fresh integer age in [0, 100] is drawn per call.
	"""

	def __init__(self, target_age):
		self.target_age = target_age

	def __call__(self, img):
		"""Return ``img`` with the normalized age channel appended."""
		return self.add_aging_channel(img)

	def add_aging_channel(self, img):
		"""Concatenate a constant age channel onto ``img`` along dim 0.

		Args:
			img: Tensor of shape (C, H, W).

		Returns:
			Tensor of shape (C + 1, H, W) whose last channel holds the
			normalized target age (``int(age) / 100``).
		"""
		target_age = int(self.__get_target_age()) / 100  # age in [0, 100] -> [0, 1]
		# Build the channel on img's device so GPU inputs don't hit a device
		# mismatch in torch.cat, and keep img's floating dtype so the transform
		# does not silently change it (e.g. float16 -> float32).
		dtype = img.dtype if img.is_floating_point() else torch.get_default_dtype()
		age_channel = torch.full(
			(1, img.shape[1], img.shape[2]), target_age, dtype=dtype, device=img.device
		)
		return torch.cat((img, age_channel))

	def __get_target_age(self):
		"""Return the configured age, or a random integer in [0, 100] for "uniform_random"."""
		if self.target_age == "uniform_random":
			# high is exclusive, so 101 makes 100 reachable.
			return np.random.randint(low=0, high=101)
		return self.target_age