import random
import numpy as np
import torch
from torchvision import transforms
from torchvision.transforms import functional as F
from torch.nn.functional import pad
class ATR_Transform():
    def __init__(self, config):
        self.pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
        self.pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)  # SAM's ImageNet-derived std on a 0-255 scale
        self.degree = config['data_transforms']['rotation_angle']
        self.saturation = config['data_transforms']['saturation']
        self.brightness = config['data_transforms']['brightness']
        self.img_size = config['data_transforms']['img_size']
        self.resize = transforms.Resize(self.img_size-1, max_size=self.img_size, antialias=True)
        self.data_transforms = config['data_transforms']
    def __call__(self, img, mask, apply_norm=True, is_train=True):
        if is_train:
            #flip horizontally with some probability
            if self.data_transforms['use_horizontal_flip']:
                p = random.random()
                if p < 0.5:
                    img = F.hflip(img)
                    mask = F.hflip(mask)

            #rotate with some probability
            if self.data_transforms['use_rotation']:
                p = random.random()
                if p < 0.5:
                    deg = 1 + random.choice(list(range(self.degree)))
                    img = F.rotate(img, angle=deg)
                    mask = F.rotate(mask, angle=deg)

            #adjust saturation with some probability
            if self.data_transforms['use_saturation']:
                p = random.random()
                if p < 0.2:
                    img = F.adjust_saturation(img, self.saturation)

            #adjust brightness with some probability
            if self.data_transforms['use_brightness']:
                p = random.random()
                if p < 0.5:
                    img = F.adjust_brightness(img, self.brightness * max(0.5, random.random()))

            #apply color jitter with some probability
            if self.data_transforms['use_cjitter']:
                p = random.random()
                if p < 0.5:
                    brightness = random.uniform(0, 0.2)
                    contrast = random.uniform(0, 0.2)
                    saturation = random.uniform(0, 0.2)
                    hue = random.uniform(0, 0.1)
                    img = F.adjust_brightness(img, brightness_factor=brightness)
                    img = F.adjust_contrast(img, contrast_factor=contrast)
                    img = F.adjust_saturation(img, saturation_factor=saturation)
                    img = F.adjust_hue(img, hue_factor=hue)

            #affine transforms with some probability
            if self.data_transforms['use_affine']:
                p = random.random()
                if p < 0.5:
                    scale = random.uniform(0.9, 1)
                    img = F.affine(img, translate=[5, 5], scale=scale, angle=5, shear=0)
                    mask = F.affine(mask, translate=[5, 5], scale=scale, angle=5, shear=0)
        #take random crops of img_size x img_size such that the label is non-zero
        if self.data_transforms['use_random_crop']:
            fallback = 20
            fall_back_ctr = 0
            repeat_flag = True
            while repeat_flag:
                fall_back_ctr += 1
                t = transforms.RandomCrop((self.img_size, self.img_size))
                i, j, h, w = t.get_params(img, (self.img_size, self.img_size))

                #if the mask is all zeros, any crop will do, so exit after this attempt
                if not mask.any():
                    repeat_flag = False

                #fallback to avoid long loops: centre the crop on a foreground pixel
                if fall_back_ctr >= fallback:
                    temp1, temp2, temp3 = np.where(mask != 0)
                    point_of_interest = random.choice(list(range(len(temp2))))
                    i = temp2[point_of_interest] - (h // 2)
                    j = temp3[point_of_interest] - (w // 2)
                    repeat_flag = False

                cropped_img = F.crop(img, i, j, h, w)
                cropped_mask = F.crop(mask, i, j, h, w)
                #keep the crop as soon as it contains some of the label
                if cropped_mask.any():
                    repeat_flag = False
            img = cropped_img
            mask = cropped_mask
        else:
            #if no random crops then resize so the longer side fits img_size
            b_min = 0
            img = self.resize(img)
            mask = self.resize(mask)
            #pad to img_size x img_size if necessary
            h, w = img.shape[-2:]
            padh = self.img_size - h
            padw = self.img_size - w
            img = pad(img, (0, padw, 0, padh), value=b_min)
            mask = pad(mask, (0, padw, 0, padh), value=b_min)
        #apply centering based on SAM's expected mean and variance
        if apply_norm:
            #scale intensities to 0-255
            b_min, b_max = 0, 255
            img = (img - img.min()) / (img.max() - img.min())
            img = img * (b_max - b_min) + b_min
            img = torch.clamp(img, b_min, b_max)
            #center around SAM's expected mean
            img = (img - self.pixel_mean) / self.pixel_std
        return img, mask
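
# ----------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal, hypothetical config
# containing the keys the class reads above, applied to a random image/mask
# pair. The key names mirror the lookups in __init__/__call__; the values are
# illustrative assumptions only.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    example_config = {
        'data_transforms': {
            'rotation_angle': 15,
            'saturation': 1.5,
            'brightness': 1.2,
            'img_size': 1024,
            'use_horizontal_flip': True,
            'use_rotation': True,
            'use_saturation': False,
            'use_brightness': False,
            'use_cjitter': False,
            'use_affine': False,
            'use_random_crop': False,
        }
    }
    tfm = ATR_Transform(example_config)
    dummy_img = torch.randint(0, 256, (3, 800, 600)).float()
    dummy_mask = torch.zeros((1, 800, 600))
    dummy_mask[:, 300:400, 200:300] = 1.0
    out_img, out_mask = tfm(dummy_img, dummy_mask, apply_norm=True, is_train=True)
    # with use_random_crop=False both outputs are resized and padded to img_size
    print(out_img.shape, out_mask.shape)  # torch.Size([3, 1024, 1024]) torch.Size([1, 1024, 1024])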