File size: 1,070 Bytes
32b542e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from torchvision import transforms as T
try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    from PIL import Image
    BICUBIC = Image.BICUBIC


def clip_transforms(mode='train', img_size=224, flip_prob=0.5,
                    mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Build the torchvision preprocessing pipeline for a dataset split.

    Every pipeline resizes with bicubic interpolation to ``img_size``,
    crops to ``img_size``, converts to a tensor and normalizes with
    ``mean``/``std``. With an int ``img_size``, ``T.Resize`` scales the
    shorter edge and keeps the aspect ratio, so the crop trims the longer edge.

    Args:
        mode: ``'train'``, ``'val'`` or ``'test'``. ``'train'`` uses a random
            crop followed by a random horizontal flip. ``'val'`` and
            ``'test'`` use a deterministic center crop.
        img_size: Size passed to both ``T.Resize`` and the crop transform.
        flip_prob: Probability of a horizontal flip (train mode only).
        mean: Per-channel mean passed to ``T.Normalize``.
        std: Per-channel standard deviation passed to ``T.Normalize``.

    Returns:
        A ``T.Compose`` transform to apply to input images.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if mode not in ('train', 'test', 'val'):
        raise ValueError(
            f"mode must be one of 'train', 'val', 'test'; got {mode!r}")

    # NOTE(review): the default mean/std are the standard ImageNet statistics,
    # not the normalization constants published with OpenAI CLIP -- confirm
    # this is intended for the model being fed.
    normalize = T.Normalize(mean=mean, std=std)

    if mode == 'train':
        # Stochastic crop + flip for augmentation during training.
        crop_stage = [
            T.RandomCrop(img_size),
            T.RandomHorizontalFlip(flip_prob),
        ]
    else:
        # Deterministic crop so evaluation is reproducible.
        crop_stage = [T.CenterCrop(img_size)]

    return T.Compose(
        [T.Resize(img_size, BICUBIC), *crop_stage, T.ToTensor(), normalize]
    )