|
from abc import ABC, abstractmethod

import torchvision.transforms as transforms
|
|
|
|
|
class TransformsConfig(ABC):
    """Abstract base for transform configurations.

    Subclasses implement :meth:`get_transforms` to build their transform
    pipelines. Inheriting from ``ABC`` is what makes ``@abstractmethod``
    effective: without ``ABCMeta`` the decorator is ignored, and a subclass
    that forgets to override ``get_transforms`` would be instantiable and
    silently return ``None``.
    """

    def __init__(self, opts):
        """Store the options object for use by subclasses.

        Args:
            opts: Options object; it is only stored here, never inspected.
        """
        self.opts = opts

    @abstractmethod
    def get_transforms(self):
        """Return the transform pipelines for this configuration."""
|
|
|
|
|
class AgingTransforms(TransformsConfig):
    """Transform configuration for the aging task.

    Every pipeline resizes to ``IMAGE_SIZE``, converts to a tensor, and
    normalizes each of the three channels with ``NORM_MEAN`` / ``NORM_STD``.
    Only the ground-truth training pipeline adds a random horizontal flip.
    """

    # Size passed to transforms.Resize; override in a subclass to change it.
    IMAGE_SIZE = (256, 256)
    # Per-channel mean/std for transforms.Normalize. Applied to ToTensor's
    # [0, 1] output, (x - 0.5) / 0.5 rescales values to [-1, 1].
    NORM_MEAN = (0.5, 0.5, 0.5)
    NORM_STD = (0.5, 0.5, 0.5)

    def get_transforms(self):
        """Build the dictionary of transform pipelines.

        Returns:
            dict with keys 'transform_gt_train', 'transform_source',
            'transform_test' and 'transform_inference'. Each value is a
            fresh ``transforms.Compose``, except 'transform_source', which
            is ``None``.
        """
        return {
            'transform_gt_train': self._pipeline(flip=True),
            # No source-specific transform is configured; how a None entry
            # is handled is up to the consumer -- TODO confirm at call site.
            'transform_source': None,
            'transform_test': self._pipeline(),
            'transform_inference': self._pipeline(),
        }

    def _pipeline(self, flip=False):
        """Return a new Compose: resize -> [random flip] -> tensor -> normalize."""
        steps = [transforms.Resize(self.IMAGE_SIZE)]
        if flip:
            # Training-time augmentation: mirror with probability 0.5.
            steps.append(transforms.RandomHorizontalFlip(0.5))
        steps.extend([
            transforms.ToTensor(),
            # Fresh lists per instance, as the original literals provided.
            transforms.Normalize(list(self.NORM_MEAN), list(self.NORM_STD)),
        ])
        return transforms.Compose(steps)
|
|