File size: 5,219 Bytes
4a1f918
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import random
import numpy as np
import torch
from torchvision import transforms
from torchvision.transforms import functional as F
from torch.nn.functional import pad


class GLAS_Transform():
    """Joint image/mask augmentation and preprocessing for GlaS samples.

    Reads its settings from ``config['data_transforms']``:
    ``rotation_angle``, ``saturation``, ``brightness``, ``img_size`` and the
    boolean switches ``use_horizontal_flip``, ``use_rotation``,
    ``use_saturation``, ``use_brightness``, ``use_cjitter``, ``use_affine``
    and ``use_random_crop``.

    NOTE(review): ``pixel_mean``/``pixel_std`` are shaped (3, 1, 1), so ``img``
    presumably is a 3-channel (3, H, W) tensor and ``mask`` a (C, H, W) tensor
    sharing H and W with it -- TODO confirm against the dataset.
    """

    def __init__(self, config):
        # Per-channel statistics used to centre images (see ``_normalize``).
        self.pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
        self.pixel_std = torch.Tensor([53.395, 57.12, 57.375]).view(-1, 1, 1)
        self.degree = config['data_transforms']['rotation_angle']
        self.saturation = config['data_transforms']['saturation']
        self.brightness = config['data_transforms']['brightness']
        self.img_size = config['data_transforms']['img_size']
        # Shorter edge -> img_size-1 with the longer edge capped at img_size;
        # Resize requires max_size > size, hence the -1.
        self.resize = transforms.Resize(self.img_size - 1, max_size=self.img_size, antialias=True)
        # Masks hold discrete labels: resize them with nearest-neighbour so no
        # interpolated (fractional or mixed) label values are created.
        self.mask_resize = transforms.Resize(
            self.img_size - 1,
            max_size=self.img_size,
            interpolation=transforms.InterpolationMode.NEAREST,
        )

        self.data_transforms = config['data_transforms']

    def __call__(self, img, mask, apply_norm=True, is_train=True):
        """Transform an (image, mask) pair.

        Args:
            img: image tensor.
            mask: mask tensor aligned with ``img``.
            apply_norm: if True, min-max rescale ``img`` to [0, 255] and then
                standardise it with ``pixel_mean``/``pixel_std``.
            is_train: if True, apply the enabled random augmentations first.

        Returns:
            ``(img, mask)``, each ``img_size`` x ``img_size`` spatially.
            Cropping (``use_random_crop``) is applied regardless of
            ``is_train``; otherwise both are resized and zero-padded.
        """
        if is_train:
            img, mask = self._augment(img, mask)

        if self.data_transforms['use_random_crop']:
            img, mask = self._random_crop(img, mask)
        else:
            img, mask = self._resize_and_pad(img, mask)

        if apply_norm:
            img = self._normalize(img)
        return img, mask

    def _augment(self, img, mask):
        """Apply each enabled random training augmentation (geometric ones to both)."""
        cfg = self.data_transforms

        # Horizontal flip with probability 0.5.
        if cfg['use_horizontal_flip'] and random.random() < 0.5:
            img = F.hflip(img)
            mask = F.hflip(mask)

        # Rotation with probability 0.5 by an integer angle in [1, degree].
        if cfg['use_rotation'] and random.random() < 0.5:
            deg = 1 + random.randrange(self.degree)
            img = F.rotate(img, angle=deg)
            mask = F.rotate(mask, angle=deg)

        # Fixed saturation adjustment with probability 0.2 (image only).
        if cfg['use_saturation'] and random.random() < 0.2:
            img = F.adjust_saturation(img, self.saturation)

        # Brightness scaled by a random factor in [0.5, 1) * brightness, p=0.5.
        if cfg['use_brightness'] and random.random() < 0.5:
            img = F.adjust_brightness(img, self.brightness * max(0.5, random.random()))

        # Colour jitter with probability 0.5. A factor of 1 leaves the image
        # unchanged (0 would give a black / solid-grey / greyscale image), so the
        # factors jitter around 1, as torchvision's ColorJitter does.
        if cfg['use_cjitter'] and random.random() < 0.5:
            brightness = random.uniform(1 - 0.2, 1 + 0.2)
            contrast = random.uniform(1 - 0.2, 1 + 0.2)
            saturation = random.uniform(1 - 0.2, 1 + 0.2)
            hue = random.uniform(-0.1, 0.1)
            img = F.adjust_brightness(img, brightness_factor=brightness)
            img = F.adjust_contrast(img, contrast_factor=contrast)
            img = F.adjust_saturation(img, saturation_factor=saturation)
            img = F.adjust_hue(img, hue_factor=hue)

        # Affine warp with probability 0.5; the same parameters go to both so
        # image and mask stay aligned.
        if cfg['use_affine'] and random.random() < 0.5:
            scale = random.uniform(0.9, 1)
            img = F.affine(img, translate=[5, 5], scale=scale, angle=5, shear=0)
            mask = F.affine(mask, translate=[5, 5], scale=scale, angle=5, shear=0)

        return img, mask

    def _random_crop(self, img, mask, max_attempts=20):
        """Take an ``img_size`` x ``img_size`` crop, preferring one with foreground.

        Random windows are drawn until one contains a non-zero mask pixel. On
        the ``max_attempts``-th try the window is instead centred on a randomly
        chosen foreground pixel; it may then extend past the border, which
        torchvision's ``crop`` zero-pads. If the mask has no foreground at all,
        the first random window is returned.

        Raises:
            ValueError: if ``max_attempts`` < 1.
        """
        if max_attempts < 1:
            raise ValueError("max_attempts must be >= 1")
        crop_size = (self.img_size, self.img_size)
        has_foreground = bool(mask.any())

        for attempt in range(1, max_attempts + 1):
            i, j, h, w = transforms.RandomCrop.get_params(img, crop_size)
            if has_foreground and attempt >= max_attempts:
                # Fallback: centre the window on a random foreground pixel.
                # nonzero() lists indices in row-major order; the last two
                # columns are the (row, col) spatial coordinates.
                fg = torch.nonzero(mask != 0)
                k = random.randrange(len(fg))
                i = int(fg[k, -2]) - h // 2
                j = int(fg[k, -1]) - w // 2
            cropped_img = F.crop(img, i, j, h, w)
            cropped_mask = F.crop(mask, i, j, h, w)
            if not has_foreground or cropped_mask.any():
                break
        return cropped_img, cropped_mask

    def _resize_and_pad(self, img, mask):
        """Resize so the longer side is ``img_size``, then zero-pad bottom/right to square."""
        img = self.resize(img)
        mask = self.mask_resize(mask)
        h, w = img.shape[-2:]
        # pad() takes (left, right, top, bottom) for the last two dimensions.
        padding = (0, self.img_size - w, 0, self.img_size - h)
        img = pad(img, padding, value=0)
        mask = pad(mask, padding, value=0)
        return img, mask

    def _normalize(self, img):
        """Min-max rescale ``img`` to [0, 255], then centre it with SAM's mean/std."""
        b_min, b_max = 0, 255
        lo, hi = img.min(), img.max()
        img = img - lo
        span = hi - lo
        # A constant image has zero range; leave it all zeros instead of
        # dividing by zero and producing NaNs.
        if span > 0:
            img = img / span
        img = img * (b_max - b_min) + b_min
        img = torch.clamp(img, b_min, b_max)

        # Centre around SAM's expected per-channel mean and std.
        return (img - self.pixel_mean) / self.pixel_std