File size: 1,055 Bytes
7b0a1ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import PIL
import torchvision.transforms as T

def pair(t):
    """Return *t* unchanged if it is already a tuple; otherwise return (t, t)."""
    if isinstance(t, tuple):
        return t
    return (t, t)

def stage1_transform(img_size=256, is_train=True, scale=0.8):
    """Build the stage-1 image preprocessing pipeline.

    Args:
        img_size: Side length (pixels) of the final square crop.
        is_train: If True, apply augmentation (random crop + horizontal
            flip); otherwise use a deterministic center crop.
        scale: Fraction of the resized edge kept by the crop; the image is
            first resized so each side is ``int(img_size / scale)``.

    Returns:
        A torchvision ``T.Compose`` that maps a PIL image to a normalized
        tensor with values in [-1, 1].
    """
    # Resize slightly larger than the crop so the (random) crop has margin.
    resize = pair(int(img_size / scale))
    t = [T.Resize(resize, interpolation=PIL.Image.BICUBIC)]
    if is_train:
        t.append(T.RandomCrop(img_size))
        t.append(T.RandomHorizontalFlip(p=0.5))
    else:
        t.append(T.CenterCrop(img_size))
    t.append(T.ToTensor())
    # Map [0, 1] tensors to [-1, 1].  NOTE: the original line ended with a
    # stray trailing comma, building and discarding a 1-tuple; removed.
    t.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))

    return T.Compose(t)
        
def stage2_transform(img_size=256, is_train=True, scale=0.8):
    """Build the stage-2 image preprocessing pipeline.

    Identical to ``stage1_transform`` except that training mode does NOT
    apply a horizontal flip — only a random crop.

    Args:
        img_size: Side length (pixels) of the final square crop.
        is_train: If True, use a random crop; otherwise a center crop.
        scale: Fraction of the resized edge kept by the crop; the image is
            first resized so each side is ``int(img_size / scale)``.

    Returns:
        A torchvision ``T.Compose`` that maps a PIL image to a normalized
        tensor with values in [-1, 1].
    """
    # Resize slightly larger than the crop so the (random) crop has margin.
    resize = pair(int(img_size / scale))
    t = [T.Resize(resize, interpolation=PIL.Image.BICUBIC)]
    if is_train:
        t.append(T.RandomCrop(img_size))
    else:
        t.append(T.CenterCrop(img_size))
    t.append(T.ToTensor())
    # Map [0, 1] tensors to [-1, 1].  NOTE: the original line ended with a
    # stray trailing comma, building and discarding a 1-tuple; removed.
    t.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))

    return T.Compose(t)