# tennant's picture
# upload
# 7b0a1ef
import PIL
import torchvision.transforms as T
def pair(t):
    """Return *t* unchanged if it is a tuple, else the 2-tuple ``(t, t)``.

    Used to turn a scalar size into a ``(height, width)`` pair. Only real
    ``tuple`` instances pass through; lists and other sequences are wrapped.
    """
    if isinstance(t, tuple):
        return t
    return (t, t)
def stage1_transform(img_size=256, is_train=True, scale=0.8):
    """Build the stage-1 image preprocessing pipeline.

    Steps:
      1. Resize (bicubic) to a *square* of side ``int(img_size / scale)``.
         Both sides are set explicitly, so non-square inputs are distorted.
      2. Crop to ``img_size`` x ``img_size``. Training uses a random crop
         followed by a random horizontal flip (p=0.5). Evaluation uses a
         deterministic center crop.
      3. ``ToTensor`` followed by ``Normalize`` with mean = std = 0.5 per
         channel, which maps ``ToTensor``'s [0, 1] range to [-1, 1].

    Args:
        img_size: Side length of the square output crop.
        is_train: Whether to apply random-crop + flip augmentation instead of
            a center crop.
        scale: Ratio of crop size to resize size. With the default 0.8 the
            image is resized to ``img_size / 0.8`` before cropping. Presumably
            expected in (0, 1]; values > 1 make the resized image smaller than
            the crop.

    Returns:
        A ``torchvision.transforms.Compose`` callable.
    """
    resize = pair(int(img_size / scale))
    # Use the InterpolationMode enum: passing PIL integer constants
    # (PIL.Image.BICUBIC) to torchvision transforms is deprecated since
    # torchvision 0.13. The resulting resampling mode is the same.
    t = [T.Resize(resize, interpolation=T.InterpolationMode.BICUBIC)]
    if is_train:
        t += [T.RandomCrop(img_size), T.RandomHorizontalFlip(p=0.5)]
    else:
        t.append(T.CenterCrop(img_size))
    # Normalize takes 3-element mean/std, so 3-channel input is assumed.
    t += [
        T.ToTensor(),
        T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
    return T.Compose(t)
def stage2_transform(img_size=256, is_train=True, scale=0.8):
    """Build the stage-2 image preprocessing pipeline.

    This is the same as the stage-1 pipeline except that training applies
    *no* horizontal flip, only a random crop.

    Steps:
      1. Resize (bicubic) to a *square* of side ``int(img_size / scale)``.
         Both sides are set explicitly, so non-square inputs are distorted.
      2. Crop to ``img_size`` x ``img_size``: random crop when training,
         center crop otherwise.
      3. ``ToTensor`` followed by ``Normalize`` with mean = std = 0.5 per
         channel, which maps ``ToTensor``'s [0, 1] range to [-1, 1].

    Args:
        img_size: Side length of the square output crop.
        is_train: Whether to use a random crop instead of a center crop.
        scale: Ratio of crop size to resize size. Presumably expected in
            (0, 1]; values > 1 make the resized image smaller than the crop.

    Returns:
        A ``torchvision.transforms.Compose`` callable.
    """
    resize = pair(int(img_size / scale))
    # Use the InterpolationMode enum: passing PIL integer constants
    # (PIL.Image.BICUBIC) to torchvision transforms is deprecated since
    # torchvision 0.13. The resulting resampling mode is the same.
    t = [T.Resize(resize, interpolation=T.InterpolationMode.BICUBIC)]
    t.append(T.RandomCrop(img_size) if is_train else T.CenterCrop(img_size))
    # Normalize takes 3-element mean/std, so 3-channel input is assumed.
    t += [
        T.ToTensor(),
        T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
    return T.Compose(t)