File size: 3,239 Bytes
edcf5ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
"""
data_augmentation    Script  ver: Sep 1st 20:30

dataset structure: ImageNet
image folder dataset is used.
"""

from torchvision import transforms


def data_augmentation(data_augmentation_mode=0, edge_size=384):
    """Build the train/val torchvision transform pipelines for a dataset preset.

    Args:
        data_augmentation_mode: which preset to use (matched with ``==``):
            0 -- ROSE + MARS: random rotation and flips, then a 700 px
                 center crop, square resize and color jitter.
            1 -- Cervical: square resize first, then vertical/horizontal
                 flips and color jitter (no rotation, no crop).
            2 -- warwick: same as mode 0 but with a 360 px center crop.
            3 -- square inputs: flips, then a plain square resize and
                 color jitter.
        edge_size: side length in pixels, passed to ``transforms.Resize``
            as ``[edge_size, edge_size]``.

    Returns:
        A dict ``{'train': Compose, 'val': Compose}``. The 'val' pipeline
        applies the same crop/resize as 'train' but no random augmentation.
        For any other mode a message is printed and the integer ``-1`` is
        returned instead of a dict.
        NOTE(review): callers must check for the ``-1`` sentinel; raising
        ValueError would be clearer but would change the existing contract.
    """

    def _square_resize():
        return transforms.Resize([edge_size, edge_size])

    def _jitter():
        # HSL shift operation: random brightness/contrast/saturation/hue
        return transforms.ColorJitter(brightness=0.15, contrast=0.3, saturation=0.3, hue=0.06)

    def _rotate_flip_crop(crop_size):
        # Rotate and flip the full image first, then keep only the central
        # crop_size x crop_size area for classification.
        train_steps = [
            transforms.RandomRotation((0, 180)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.CenterCrop(crop_size),
            _square_resize(),
            _jitter(),
            transforms.ToTensor(),
        ]
        val_steps = [
            transforms.CenterCrop(crop_size),
            _square_resize(),
            transforms.ToTensor(),
        ]
        return {'train': transforms.Compose(train_steps),
                'val': transforms.Compose(val_steps)}

    def _cervical():
        # Resize first, then flip (vertical before horizontal) and jitter.
        return {
            'train': transforms.Compose([
                _square_resize(),
                transforms.RandomVerticalFlip(),
                transforms.RandomHorizontalFlip(),
                _jitter(),
                transforms.ToTensor(),
            ]),
            'val': transforms.Compose([_square_resize(), transforms.ToTensor()]),
        }

    def _square_input():
        # Square inputs need no crop: flip, then just resize.
        return {
            'train': transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
                _square_resize(),
                _jitter(),
                transforms.ToTensor(),
            ]),
            'val': transforms.Compose([_square_resize(), transforms.ToTensor()]),
        }

    # Checked in order with ``==`` (not a dict lookup) so that unhashable or
    # unusual mode values simply fall through to the error branch.
    presets = (
        (0, lambda: _rotate_flip_crop(700)),  # ROSE + MARS
        (1, _cervical),                       # Cervical
        (2, lambda: _rotate_flip_crop(360)),  # warwick
        (3, _square_input),                   # square input: just resize
    )
    for mode, build in presets:
        if data_augmentation_mode == mode:
            return build()

    print('no legal data augmentation is selected')
    return -1