File size: 2,971 Bytes
b84549f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from torchvision import transforms
import torch


def one_d_image_train_aug(to_3_channels=False):
    """Training transform for single-channel images (MNIST-style stats):
    resize to 32, convert to tensor, optionally replicate the channel to
    3 channels, then normalize.

    Args:
        to_3_channels: if True, tile the single channel three times so the
            output matches 3-channel model inputs.
    """
    mean, std = (0.1307, 0.1307, 0.1307), (0.3081, 0.3081, 0.3081)
    steps = [
        transforms.Resize(32),
        # transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
    ]
    if to_3_channels:
        steps.append(transforms.Lambda(lambda t: torch.cat([t] * 3)))
    else:
        steps.append(transforms.Lambda(lambda t: t))
    steps.append(transforms.Normalize(mean, std))
    return transforms.Compose(steps)


def one_d_image_test_aug(to_3_channels=False):
    """Test-time transform for single-channel images (MNIST-style stats):
    resize to 32, convert to tensor, optionally repeat to 3 channels,
    then normalize.

    Args:
        to_3_channels: if True, tile the single channel three times.
    """
    mean = (0.1307, 0.1307, 0.1307)
    std = (0.3081, 0.3081, 0.3081)
    if to_3_channels:
        expand = transforms.Lambda(lambda t: torch.cat([t] * 3))
    else:
        expand = transforms.Lambda(lambda t: t)
    return transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        expand,
        transforms.Normalize(mean, std),
    ])


def cifar_like_image_train_aug(mean=None, std=None):
    """Training transform for CIFAR-like 32x32 images: resize to 40, padded
    random crop back to 32, horizontal flip, tensor conversion, normalization.

    Args:
        mean: per-channel normalization mean; defaults to CIFAR-10 statistics.
        std: per-channel normalization std; defaults to CIFAR-10 statistics.
    """
    # Fill in mean and std independently: the original `if mean is None`
    # branch set BOTH, which silently discarded a caller-supplied `std`
    # and crashed Normalize when `mean` was given without `std`.
    if mean is None:
        mean = (0.4914, 0.4822, 0.4465)
    if std is None:
        std = (0.2023, 0.1994, 0.2010)
    return transforms.Compose([
        transforms.Resize(40), # NOTE: this is critical!!! or you may crop a small part of an image
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])


def cifar_like_image_test_aug(mean=None, std=None):
    """Test-time transform for CIFAR-like images: resize to 32, tensor
    conversion, normalization (no random augmentation).

    Args:
        mean: per-channel normalization mean; defaults to CIFAR-10 statistics.
        std: per-channel normalization std; defaults to CIFAR-10 statistics.
    """
    # Default mean and std independently so a caller-supplied `std` is
    # respected (the original branch overwrote both whenever mean was None,
    # and crashed Normalize when mean was given but std was not).
    if mean is None:
        mean = (0.4914, 0.4822, 0.4465)
    if std is None:
        std = (0.2023, 0.1994, 0.2010)
    return transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

def imagenet_like_image_train_aug():
    """Training transform with ImageNet statistics: 256x256 resize, padded
    224x224 random crop, random horizontal flip, tensor conversion,
    normalization."""
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.Resize((256, 256)),
        transforms.RandomCrop((224, 224), padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std),
    ]
    return transforms.Compose(steps)


def imagenet_like_image_test_aug():
    """Test-time transform with ImageNet statistics: 224x224 resize,
    tensor conversion, normalization (no random augmentation)."""
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std),
    ]
    return transforms.Compose(steps)


def cityscapes_like_image_train_aug():
    """Training transform for Cityscapes-like input images: 224x224 resize,
    tensor conversion, ImageNet-statistics normalization. No random
    augmentation is applied."""
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])
    
def cityscapes_like_image_test_aug():
    """Test-time transform for Cityscapes-like input images: 224x224 resize,
    tensor conversion, ImageNet-statistics normalization."""
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])
    
def cityscapes_like_label_aug():
    """Transform for segmentation label maps: resize to 224x224 with NEAREST
    interpolation, then convert to a LongTensor of class IDs.

    NEAREST is required here: the default bilinear interpolation would blend
    neighboring integer class IDs into new, invalid label values.
    """
    import numpy as np
    return transforms.Compose([
        transforms.Resize((224, 224),
                          interpolation=transforms.InterpolationMode.NEAREST),
        transforms.Lambda(lambda x: torch.from_numpy(np.array(x)).long())
    ])


def pil_image_to_tensor(img_size=224):
    """Resize a PIL image to (img_size, img_size) and convert it to a tensor.

    Args:
        img_size: target side length of the square output image.
    """
    resize = transforms.Resize((img_size, img_size))
    to_tensor = transforms.ToTensor()
    return transforms.Compose([resize, to_tensor])