# SALT-SAM/AllinonSAM/data_transforms/cholec_8k_transform.py
# (uploaded via huggingface_hub, commit 4a1f918)
import random
import numpy as np
import torch
from torchvision import transforms
from torchvision.transforms import functional as F
from torch.nn.functional import pad
class Cholec_8k_Transform():
    """Augmentation + preprocessing pipeline for the CholecSeg8k dataset,
    producing image tensors normalized to SAM's expected pixel statistics.

    Config keys read (all under ``config['data_transforms']``):
    ``rotation_angle``, ``saturation``, ``brightness``, ``img_size``,
    ``use_horizontal_flip``, ``use_rotation``, ``use_saturation``,
    ``use_brightness``, ``use_random_crop``, ``a_min``, ``a_max``.
    """

    def __init__(self, config):
        # SAM's expected per-channel mean/std on the 0-255 intensity scale,
        # shaped (C,1,1) so they broadcast over (C,H,W) images.
        self.pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
        self.pixel_std = torch.Tensor([53.395, 57.12, 57.375]).view(-1, 1, 1)
        self.degree = config['data_transforms']['rotation_angle']
        self.saturation = config['data_transforms']['saturation']
        self.brightness = config['data_transforms']['brightness']
        self.img_size = config['data_transforms']['img_size']
        # Resize the shorter edge to img_size-1 while capping the longer edge
        # at img_size; any remainder is zero-padded in _resize_and_pad.
        self.resize = transforms.Resize(self.img_size - 1, max_size=self.img_size,
                                        antialias=True)
        self.data_transforms = config['data_transforms']

    def _random_foreground_crop(self, img, mask):
        """Take one img_size x img_size crop, retrying until it contains
        foreground; after 20 attempts, center the crop on a foreground pixel."""
        max_attempts = 20
        out_hw = (self.img_size, self.img_size)
        for attempt in range(1, max_attempts + 1):
            top, left, h, w = transforms.RandomCrop.get_params(img, out_hw)
            if not mask.any():
                # BUGFIX: the original left img/mask *uncropped* when the mask
                # was all background (the final `.any()` gate never assigned
                # the crops), so the output size was not uniform. Crop anyway.
                return F.crop(img, top, left, h, w), F.crop(mask, top, left, h, w)
            if attempt == max_attempts:
                # Fallback to avoid long loops: pick a random foreground pixel
                # and center the window on it.
                # assumes mask is (C,H,W) — np.where unpacks 3 axes; TODO confirm
                _, ys, xs = np.where(mask != 0)
                k = random.randrange(len(ys))
                top = int(ys[k]) - h // 2
                left = int(xs[k]) - w // 2
                # BUGFIX: clamp the window inside the image; the original could
                # produce negative offsets, which wrap around on tensor
                # indexing in older torchvision versions.
                top = max(0, min(top, img.shape[-2] - h))
                left = max(0, min(left, img.shape[-1] - w))
                return F.crop(img, top, left, h, w), F.crop(mask, top, left, h, w)
            cropped_img = F.crop(img, top, left, h, w)
            cropped_mask = F.crop(mask, top, left, h, w)
            if cropped_mask.any():
                return cropped_img, cropped_mask
        return img, mask  # unreachable; loop always returns

    def _resize_and_pad(self, img, mask):
        """Resize so the longer edge is at most img_size, then zero-pad the
        bottom/right edges to exactly img_size x img_size."""
        img = self.resize(img)
        mask = self.resize(mask)
        h, w = img.shape[-2:]
        pad_h = self.img_size - h
        pad_w = self.img_size - w
        img = pad(img, (0, pad_w, 0, pad_h), value=0)
        mask = pad(mask, (0, pad_w, 0, pad_h), value=0)
        return img, mask

    def __call__(self, img, mask, apply_norm=True, is_train=True):
        """Apply train-time augmentations (flip / rotate / saturation /
        brightness, then crop or resize+pad), then optionally normalize the
        image for SAM. Returns the transformed ``(img, mask)`` pair.

        img: (C,H,W) tensor with intensities in [a_min, a_max] — TODO confirm.
        mask: segmentation mask with the same spatial size as img.
        """
        if is_train:
            # Horizontal flip with probability 0.5.
            if self.data_transforms['use_horizontal_flip']:
                if random.random() < 0.5:
                    img = F.hflip(img)
                    mask = F.hflip(mask)
            # Rotate both tensors by a random angle in [1, rotation_angle]
            # with probability 0.5.
            # BUGFIX: random.choice(range(self.degree)) raised on
            # rotation_angle <= 0; the degree guard skips rotation instead.
            if self.data_transforms['use_rotation'] and self.degree >= 1:
                if random.random() < 0.5:
                    deg = random.randint(1, self.degree)
                    img = F.rotate(img, angle=deg)
                    mask = F.rotate(mask, angle=deg)
            # Saturation jitter with probability 0.2.
            if self.data_transforms['use_saturation']:
                if random.random() < 0.2:
                    img = F.adjust_saturation(img, self.saturation)
            # Brightness jitter with probability 0.5. NOTE(review): the factor
            # can be arbitrarily close to 0 (near-black image) — confirm intended.
            if self.data_transforms['use_brightness']:
                if random.random() < 0.5:
                    img = F.adjust_brightness(img, self.brightness * random.random())
            # Random foreground-containing crops of img_size x img_size;
            # otherwise fall back to plain resizing + padding.
            # NOTE(review): with is_train=False neither branch runs, so eval
            # images are returned at their original size — confirm intended.
            if self.data_transforms['use_random_crop']:
                img, mask = self._random_foreground_crop(img, mask)
            else:
                img, mask = self._resize_and_pad(img, mask)
        # Center intensities around SAM's expected mean and variance.
        if apply_norm:
            # Rescale from [a_min, a_max] to [0, 255], clamping outliers.
            b_min, b_max = 0, 255
            img = (img - self.data_transforms['a_min']) / \
                  (self.data_transforms['a_max'] - self.data_transforms['a_min'])
            img = img * (b_max - b_min) + b_min
            img = torch.clamp(img, b_min, b_max)
            # Standardize with SAM's pixel statistics.
            img = (img - self.pixel_mean) / self.pixel_std
        return img, mask