|
import random |
|
import numpy as np |
|
import torch |
|
from torchvision import transforms |
|
from torchvision.transforms import functional as F |
|
from torch.nn.functional import pad |
|
|
|
|
|
class ENDOVIS_18_Transform():
    """Augmentation / preprocessing pipeline for EndoVis-18 image+mask pairs.

    Training path: optional horizontal flip, rotation, saturation jitter,
    brightness jitter, and a foreground-seeking random crop to
    (img_size, img_size).  Eval path: resize so the longer side equals
    img_size, then zero-pad bottom/right to a square canvas.  Optionally the
    image is intensity-windowed to [0, 255] and standardised with fixed
    per-channel statistics (ImageNet mean/std on the 0-255 scale).
    """

    def __init__(self, config):
        """
        Args:
            config: mapping with a ``data_transforms`` section supplying
                ``rotation_angle``, ``saturation``, ``brightness``,
                ``img_size``, ``a_min``, ``a_max`` and the ``use_*``
                boolean switches read in ``__call__``.
        """
        # Per-channel statistics shaped (C,1,1) so they broadcast over (C,H,W).
        self.pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
        self.pixel_std = torch.Tensor([53.395, 57.12, 57.375]).view(-1, 1, 1)

        dt = config['data_transforms']
        self.degree = dt['rotation_angle']       # max rotation angle (degrees)
        self.saturation = dt['saturation']       # fixed saturation factor
        self.brightness = dt['brightness']       # upper bound of brightness factor
        self.img_size = dt['img_size']           # target square side length
        # Resize the shorter side toward img_size-1 but cap the longer side at
        # img_size, so after resizing max(h, w) == img_size and the remainder
        # is zero-padded in _resize_and_pad.
        self.resize = transforms.Resize(self.img_size - 1, max_size=self.img_size,
                                        antialias=True)

        self.data_transforms = dt

    def __call__(self, img, mask, apply_norm=True, is_train=True):
        """Transform one image/mask pair.

        Args:
            img: image tensor, assumed (C,H,W) — TODO confirm with caller.
            mask: segmentation mask tensor with the same spatial size as img.
            apply_norm: if True, window intensities and standardise the image.
            is_train: if True apply augmentations, else resize + pad.

        Returns:
            (img, mask) tuple after the selected transforms.
        """
        if is_train:
            if self.data_transforms['use_horizontal_flip'] and random.random() < 0.5:
                # Flip image and mask together to keep them aligned.
                img = F.hflip(img)
                mask = F.hflip(mask)

            if self.data_transforms['use_rotation'] and random.random() < 0.5:
                # Same random angle in [1, degree] for both tensors.
                deg = random.randint(1, self.degree)
                img = F.rotate(img, angle=deg)
                mask = F.rotate(mask, angle=deg)

            if self.data_transforms['use_saturation'] and random.random() < 0.2:
                img = F.adjust_saturation(img, self.saturation)

            if self.data_transforms['use_brightness'] and random.random() < 0.5:
                # Brightness factor drawn uniformly from [0, brightness).
                img = F.adjust_brightness(img, self.brightness * random.random())

            if self.data_transforms['use_random_crop']:
                img, mask = self._random_crop(img, mask)
        else:
            img, mask = self._resize_and_pad(img, mask)

        if apply_norm:
            img = self._normalize(img)

        return img, mask

    def _random_crop(self, img, mask):
        """Crop an (img_size, img_size) window that contains mask foreground.

        Randomly samples windows until one hits foreground; after ``fallback``
        misses, centres the window on a randomly chosen foreground pixel.
        If the mask has no foreground at all, the first sampled window is
        accepted as-is.
        """
        fallback = 20  # budget of purely random attempts before targeting
        attempts = 0
        size = (self.img_size, self.img_size)

        while True:
            attempts += 1
            # get_params is static; no need to build a RandomCrop instance.
            i, j, h, w = transforms.RandomCrop.get_params(img, size)

            if not mask.any():
                # No foreground anywhere: any window is acceptable.
                # (Bug fix: the original cleared its loop flag here but never
                # applied the crop, returning an un-cropped, wrong-sized pair.)
                return F.crop(img, i, j, h, w), F.crop(mask, i, j, h, w)

            if attempts >= fallback:
                # Too many blind misses: centre the crop on a random
                # foreground pixel.  np.where assumes mask is (C,H,W):
                # axis 1 = rows, axis 2 = cols — TODO confirm mask layout.
                _, rows, cols = np.where(mask != 0)
                k = random.randrange(len(rows))
                i = rows[k] - (h // 2)
                j = cols[k] - (w // 2)
                # i/j may be slightly out of bounds near image edges;
                # F.crop pads the out-of-range area with zeros.
                return F.crop(img, i, j, h, w), F.crop(mask, i, j, h, w)

            cropped_img = F.crop(img, i, j, h, w)
            cropped_mask = F.crop(mask, i, j, h, w)
            if cropped_mask.any():
                # Window contains foreground — accept it.
                return cropped_img, cropped_mask

    def _resize_and_pad(self, img, mask):
        """Resize so the longer side equals img_size, then zero-pad
        bottom/right to a square (img_size, img_size) canvas."""
        img = self.resize(img)
        mask = self.resize(mask)

        h, w = img.shape[-2:]
        padh = self.img_size - h
        padw = self.img_size - w
        # pad() order is (left, right, top, bottom) for the last two dims.
        img = pad(img, (0, padw, 0, padh), value=0)
        mask = pad(mask, (0, padw, 0, padh), value=0)
        return img, mask

    def _normalize(self, img):
        """Window intensities from [a_min, a_max] to [0, 255], clamp, then
        standardise with the fixed per-channel mean/std."""
        a_min = self.data_transforms['a_min']
        a_max = self.data_transforms['a_max']
        b_min, b_max = 0, 255

        img = (img - a_min) / (a_max - a_min)
        img = img * (b_max - b_min) + b_min
        img = torch.clamp(img, b_min, b_max)

        return (img - self.pixel_mean) / self.pixel_std