# PuzzleTuning_VPT / PuzzleTuning / utils / online_augmentations.py
# (Hugging Face page residue: author "Tianyinus", commit "init submit", edcf5ee verified)
"""
Online Augmentations May 23rd 2023 21:30
ref:
CutOut, Mixup, CutMix based on
https://blog.csdn.net/cp1314971/article/details/106612060
"""
import cv2
import torch
import numpy as np
import torch.nn.functional as F
from scipy.special import perm
from torchvision.transforms import Resize
from torchvision.transforms import ToPILImage, ToTensor
from utils.visual_usage import patchify, unpatchify
from utils.fmix import sample_mask, FMixBase # Fmix
# generate random bounding box
def rand_bbox(size, lam):
    """Sample a random box on a feature map, covering roughly (1 - lam) of its area.

    :param size: tensor size tuple [B, C, dim2, dim3]; the box spans dims 2 and 3
    :param lam: mixing coefficient; the box edge ratio is sqrt(1 - lam)
    :return: (bbx1, bby1, bbx2, bby2) corner coordinates, clipped to the map edges
    """
    dim2, dim3 = size[2], size[3]
    edge_ratio = np.sqrt(1. - lam)
    half_w = np.int64(dim2 * edge_ratio) // 2
    half_h = np.int64(dim3 * edge_ratio) // 2
    # uniformly sampled box center
    center_x = np.random.randint(dim2)
    center_y = np.random.randint(dim3)
    return (np.clip(center_x - half_w, 0, dim2),
            np.clip(center_y - half_h, 0, dim3),
            np.clip(center_x + half_w, 0, dim2),
            np.clip(center_y + half_h, 0, dim3))
def saliency_bbox(img, lam):
    """Sample a box centered on the most salient pixel of a single image.

    Uses OpenCV's static fine-grained saliency detector to find the peak
    saliency location, then places a box of area ratio ~(1 - lam) around it
    (SaliencyMix, arXiv:2006.01791).

    :param img: single image tensor of size [C, dim1, dim2]
    :param lam: mixing coefficient; the box edge ratio is sqrt(1 - lam)
    :return: (bbx1, bby1, bbx2, bby2) corner coordinates, clipped to the image
    """
    size = img.size()
    W = size[1]
    H = size[2]
    cut_rat = np.sqrt(1. - lam)
    # fix: np.int was deprecated in NumPy 1.20 and removed in 1.24 -> crashes;
    # use np.int64 for consistency with rand_bbox
    cut_w = np.int64(W * cut_rat)
    cut_h = np.int64(H * cut_rat)
    # initialize OpenCV's static fine grained saliency detector and
    # compute the saliency map (expects HWC layout, hence the transpose)
    temp_img = img.cpu().numpy().transpose(1, 2, 0)
    saliency = cv2.saliency.StaticSaliencyFineGrained_create()
    (success, saliencyMap) = saliency.computeSaliency(temp_img)
    saliencyMap = (saliencyMap * 255).astype("uint8")
    # the peak saliency position becomes the box center
    maximum_indices = np.unravel_index(np.argmax(saliencyMap, axis=None), saliencyMap.shape)
    x = maximum_indices[0]
    y = maximum_indices[1]
    bbx1 = np.clip(x - cut_w // 2, 0, W)
    bby1 = np.clip(y - cut_h // 2, 0, H)
    bbx2 = np.clip(x + cut_w // 2, 0, W)
    bby2 = np.clip(y + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2
# augmentation methods
class Cutout(object):
    def __init__(self, alpha=2, shuffle_p=1.0, class_num=2, batch_size=4, device='cpu'):
        """
        Cutout augmentation arXiv:1708.04552

        Zeroes one randomly placed rectangle per triggered sample.

        :param alpha: beta-distribution parameter for the region size
        :param shuffle_p: chance of trigger augmentation
        :param class_num: number of classification categories
        :param batch_size: batch_size of training
        :param device: CUDA or CPU
        """
        self.alpha = alpha
        self.class_num = class_num
        self.batch_size = batch_size
        self.p = shuffle_p
        self.device = torch.device(device)

    def __call__(self, inputs, labels, act=True):
        """Apply cutout per sample; returns (images, long_label, long_label)."""
        one_hot = torch.eye(self.class_num).to(self.device)[labels, :]  # one-hot hard label
        originals = inputs.clone().to(self.device)
        outputs = inputs.clone().to(self.device)
        ratios = []  # per-sample remaining-area ratio (-1 when skipped)
        for idx in range(self.batch_size):
            skip = np.random.randint(0, 101) > 100 * self.p
            if skip or not act:
                ratios.append(-1)
                continue
            lam = np.random.beta(self.alpha, self.alpha)
            x1, y1, x2, y2 = rand_bbox(originals.size(), lam)  # random box to erase
            outputs[idx, :, x1:x2, y1:y2] = 0
            # ratio of surviving original area, recorded for reference
            area = originals.size()[2] * originals.size()[3]
            ratios.append(1 - (x2 - x1) * (y2 - y1) / area)
        # NOTICE cutout uses the hard long label with ordinary cross-entropy,
        # not a soft label, since no second image is mixed in
        hard_label = one_hot.argmax(dim=1)
        return outputs, hard_label, hard_label
class CutMix(object):
    def __init__(self, alpha=2, shuffle_p=1.0, class_num=2, batch_size=4, device='cpu'):
        """
        CutMix augmentation arXiv:1905.04899

        Pastes a random box from a shuffled partner image into each triggered
        sample and mixes the labels by the pasted-area ratio.

        :param alpha: beta-distribution parameter for the region size
        :param shuffle_p: chance of trigger augmentation
        :param class_num: number of classification categories
        :param batch_size: batch_size of training
        :param device: CUDA or CPU
        """
        self.alpha = alpha
        self.class_num = class_num
        self.batch_size = batch_size
        # calibrate trigger chance p: rescale by the chance that the in-batch
        # permutation actually moves a sample
        full = perm(self.batch_size, self.batch_size)
        fixed = perm(self.batch_size - 1, self.batch_size - 1)
        self.p = shuffle_p * (full / (full - fixed))
        self.device = torch.device(device)

    def __call__(self, inputs, labels, act=True):
        """Apply cutmix; returns (images, soft_label, long_label)."""
        one_hot = torch.eye(self.class_num).to(self.device)[labels, :]  # one-hot hard label
        originals = inputs.clone().to(self.device)
        outputs = inputs.clone().to(self.device)
        ratios = []  # per-sample kept-area ratio (-1 when skipped)
        order = torch.randperm(self.batch_size, device=self.device)  # partner assignment
        partner_inputs = inputs[order].to(self.device)
        partner_labels = one_hot[order].to(self.device)
        for idx in range(self.batch_size):
            if np.random.randint(0, 101) > 100 * self.p or not act:
                ratios.append(-1)
                continue
            lam = np.random.beta(self.alpha, self.alpha)
            x1, y1, x2, y2 = rand_bbox(originals.size(), lam)  # random paste box
            outputs[idx, :, x1:x2, y1:y2] = partner_inputs[idx, :, x1:x2, y1:y2]
            # recompute lam as the exact area ratio kept from the original image
            area = originals.size()[2] * originals.size()[3]
            lam = 1 - (x2 - x1) * (y2 - y1) / area
            ratios.append(lam)
            one_hot[idx] = one_hot[idx] * lam + partner_labels[idx] * (1 - lam)
        return outputs, one_hot, one_hot.argmax(dim=1)
class Mixup(object):
    def __init__(self, alpha=2, shuffle_p=1.0, class_num=2, batch_size=4, device='cpu'):
        """
        Mixup augmentation arXiv:1710.09412

        Blends each triggered sample with a shuffled partner image and mixes
        the labels by the same coefficient.

        :param alpha: beta-distribution parameter for the mixing coefficient
        :param shuffle_p: chance of trigger augmentation
        :param class_num: number of classification categories
        :param batch_size: batch_size of training
        :param device: CUDA or CPU
        """
        self.alpha = alpha
        self.class_num = class_num
        self.batch_size = batch_size
        # calibrate trigger chance p: rescale by the chance that the in-batch
        # permutation actually moves a sample
        full = perm(self.batch_size, self.batch_size)
        fixed = perm(self.batch_size - 1, self.batch_size - 1)
        self.p = shuffle_p * (full / (full - fixed))
        self.device = torch.device(device)

    def __call__(self, inputs, labels, act=True):
        """Apply mixup; returns (images, soft_label, long_label)."""
        one_hot = torch.eye(self.class_num).to(self.device)[labels, :]  # one-hot hard label
        originals = inputs.clone().to(self.device)
        outputs = inputs.clone().to(self.device)
        ratios = []  # per-sample mixing coefficient (-1 when skipped)
        order = torch.randperm(self.batch_size, device=self.device)  # partner assignment
        partner_inputs = inputs[order].to(self.device)
        partner_labels = one_hot[order].to(self.device)
        for idx in range(self.batch_size):
            if np.random.randint(0, 101) > 100 * self.p or not act:
                ratios.append(-1)
                continue
            lam = np.random.beta(self.alpha, self.alpha)
            ratios.append(lam)
            # convex combination of the sample and its partner
            outputs[idx] = originals[idx] * lam + partner_inputs[idx] * (1 - lam)
            one_hot[idx] = one_hot[idx] * lam + partner_labels[idx] * (1 - lam)
        return outputs, one_hot, one_hot.argmax(dim=1)
class SaliencyMix(object):
    def __init__(self, alpha=1, shuffle_p=1.0, class_num=2, batch_size=4, device='cpu'):
        """
        SaliencyMix augmentation arXiv:2006.01791

        Like CutMix, but the pasted box is centered on the partner image's
        saliency peak instead of a uniformly random position.

        :param alpha: beta-distribution parameter for the region size
        :param shuffle_p: chance of trigger augmentation
        :param class_num: number of classification categories
        :param batch_size: batch_size of training
        :param device: CUDA or CPU
        """
        # ori batch_size=128
        self.alpha = alpha
        self.class_num = class_num
        self.batch_size = batch_size
        self.p = shuffle_p
        self.device = torch.device(device)

    def __call__(self, inputs, labels, act=True):
        """Apply saliencymix; returns (images, soft_label, long_label)."""
        one_hot = torch.eye(self.class_num).to(self.device)[labels, :]  # one-hot hard label
        originals = inputs.clone().to(self.device)
        outputs = inputs.clone().to(self.device)
        ratios = []  # per-sample kept-area ratio (-1 when skipped)
        order = torch.randperm(self.batch_size, device=self.device)  # partner assignment
        partner_inputs = inputs[order].to(self.device)
        partner_labels = one_hot[order].to(self.device)
        for idx in range(self.batch_size):
            triggered = np.random.randint(0, 101) <= 100 * self.p
            if not triggered or not act or self.alpha <= 0:
                ratios.append(-1)
                continue
            lam = np.random.beta(self.alpha, self.alpha)
            # box placed around the partner image's most salient pixel
            x1, y1, x2, y2 = saliency_bbox(partner_inputs[idx], lam)
            outputs[idx, :, x1:x2, y1:y2] = partner_inputs[idx, :, x1:x2, y1:y2]
            # recompute lam as the exact area ratio kept from the original image
            area = originals.size()[2] * originals.size()[3]
            lam = 1 - (x2 - x1) * (y2 - y1) / area
            ratios.append(lam)
            one_hot[idx] = one_hot[idx] * lam + partner_labels[idx] * (1 - lam)
        return outputs, one_hot, one_hot.argmax(dim=1)
class ResizeMix(object):
    def __init__(self, shuffle_p=1.0, class_num=2, batch_size=4, device='cpu'):
        """
        ResizeMix augmentation arXiv:2012.11101

        Shrinks the whole partner image into a random box of the original and
        mixes the labels by the pasted-area ratio.

        :param shuffle_p: chance of trigger augmentation
        :param class_num: number of classification categories
        :param batch_size: batch_size of training
        :param device: CUDA or CPU
        """
        # ori batch_size=512
        self.class_num = class_num
        self.batch_size = batch_size
        self.p = shuffle_p
        self.device = torch.device(device)

    def __call__(self, inputs, labels, alpha=0.1, beta=0.8, act=True):
        """Apply resizemix; returns (images, soft_label, long_label)."""
        one_hot = torch.eye(self.class_num).to(self.device)[labels, :]  # one-hot hard label
        originals = inputs.clone().to(self.device)
        outputs = inputs.clone().to(self.device)
        ratios = []  # per-sample kept-area ratio (-1 when skipped)
        order = torch.randperm(self.batch_size, device=self.device)  # partner assignment
        partner_inputs = inputs[order].to(self.device)
        partner_labels = one_hot[order].to(self.device)
        for idx in range(self.batch_size):
            if np.random.randint(0, 101) > 100 * self.p or not act:
                ratios.append(-1)
                continue
            # paste scale sampled uniformly in [alpha, beta]
            lam = np.random.uniform(alpha, beta)
            x1, y1, x2, y2 = rand_bbox(originals.size(), lam)  # box receiving the resized partner
            # shrink the whole partner image into the box: Tensor -> PIL -> resize -> Tensor
            shrink = Resize([x2 - x1, y2 - y1])
            pil_patch = shrink(ToPILImage()(partner_inputs[idx]))
            outputs[idx, :, x1:x2, y1:y2] = ToTensor()(pil_patch)
            # recompute lam as the exact area ratio kept from the original image
            area = originals.size()[2] * originals.size()[3]
            lam = 1 - (x2 - x1) * (y2 - y1) / area
            ratios.append(lam)
            one_hot[idx] = one_hot[idx] * lam + partner_labels[idx] * (1 - lam)
        return outputs, one_hot, one_hot.argmax(dim=1)
class FMix(FMixBase):
    def __init__(self, shuffle_p=1.0, class_num=2, batch_size=4, decay_power=3, alpha=1, size=(32, 32),
                 max_soft=0.0, reformulate=False, device='cpu'):
        """
        FMix augmentation arXiv:2002.12047

        Mixes each triggered sample with a shuffled partner through a
        Fourier-sampled binary mask shared by the whole batch.

        :param shuffle_p: chance of trigger augmentation
        :param class_num: number of classification categories
        :param batch_size: batch_size of training
        :param decay_power: decay_power of the Fourier mask sampler
        :param alpha: beta-distribution parameter for the mask area
        :param size: spatial size of the sampled mask
        :param max_soft: softening applied to the mask edges
        :param reformulate: use the reformulated FMix variant
        :param device: CUDA or CPU
        """
        # ori batch_size=128
        super().__init__(decay_power, alpha, size, max_soft, reformulate)
        self.class_num = class_num
        self.batch_size = batch_size
        self.p = shuffle_p
        self.device = torch.device(device)

    def __call__(self, inputs, labels, alpha=1, act=True):
        """Apply fmix; returns (images, soft_label, long_label)."""
        # one Fourier-sampled mask (and its area ratio lam) is shared batch-wide
        lam, mask = sample_mask(self.alpha, self.decay_power, self.size, self.max_soft, self.reformulate)
        mask = torch.from_numpy(mask).float().to(self.device)
        one_hot = torch.eye(self.class_num).to(self.device)[labels, :]  # one-hot hard label
        originals = inputs.clone().to(self.device)
        outputs = inputs.clone().to(self.device)
        ratios = []  # per-sample mixing coefficient (-1 when skipped)
        order = torch.randperm(self.batch_size, device=self.device)  # partner assignment
        partner_inputs = inputs[order].to(self.device)
        partner_labels = one_hot[order].to(self.device)
        for idx in range(self.batch_size):
            if np.random.randint(0, 101) > 100 * self.p or not act:
                ratios.append(-1)
                continue
            # masked region keeps the original, the rest comes from the partner
            outputs[idx] = mask * originals[idx] + (1 - mask) * partner_inputs[idx]
            ratios.append(lam)
            one_hot[idx] = one_hot[idx] * lam + partner_labels[idx] * (1 - lam)
        return outputs, one_hot, one_hot.argmax(dim=1)
# CellMix
class CellMix(object):
    def __init__(self, shuffle_p=1.0, class_num=2, strategy='In-place', group_shuffle_size=-1, device='cpu'):
        """
        CellMix augmentation arXiv:2301.11513

        :param shuffle_p: chance of trigger augmentation
        :param class_num: number of classification categories
        :param strategy: 'In-place' or 'Random' to shuffle the relation patches within the batch
        :param group_shuffle_size: the size of shuffling group in the batch, -1 to all
        :param device: CUDA or CPU
        """
        self.p = shuffle_p
        self.CLS = class_num  # classification category number of the task
        self.device = device
        self.strategy = strategy  # 'In-place' or 'Random'
        self.group_shuffle_size = group_shuffle_size  # -1 for whole batch

    def __call__(self, inputs, labels, fix_position_ratio=0.5, puzzle_patch_size=32, act=True):
        """
        Fix-position in-place shuffling.

        Performs cross-sample random selection to fix some patches in each image of
        the batch. After selection, the fixed patches are reserved, the rest of the
        patches are batch-wise in-place shuffled and then regrouped with the fixed
        patches.

        Cross-sample selection is done by argsorting random noise in dim 1 and applying
        it to every image within the batch. The in-place batch-wise shuffle is done by
        argsorting random noise in dim 0; the grouped variant argsorts noise within
        each group along the batch dimension.

        :param inputs: input image tensor, size of [B, 3, H, W]
        :param labels: hard long labels, size of [B]
        :param fix_position_ratio: float ratio of patches kept at their original position
        :param puzzle_patch_size: int patch size of shuffle
        :param act: set to False to force not triggering CellMix in validation,
                    set to True to trigger by chance p

        output: x, soft_label, long_label
            x: [B, 3, H, W] re-grouped image after cellmix augmentation
            soft_label: [B, CLS], soft label of the class distribution
            long_label: [B] hard long label for general use
        """
        # skipped path: return the inputs untouched with a one-hot soft label
        if np.random.randint(0, 101) > 100 * self.p or (not act):
            soft_label = torch.eye(self.CLS).to(self.device)[labels, :]  # one-hot hard label
            return inputs, soft_label, labels
        # break the images into puzzle patches of puzzle_patch_size: [B, num_patches, D]
        inputs = patchify(inputs, puzzle_patch_size)
        B, num_patches, D = inputs.shape
        # generate the pseudo-mask tracking patch ownership per class
        mask = torch.zeros([B, num_patches, self.CLS], device=inputs.device, requires_grad=False)  # no grad
        # mask of patches: (B, num_patches, cls) -- each patch carries a one-hot of
        # its source image's label; transform labels into the pseudo-mask
        B_idx = range(B)
        mask[B_idx, :, labels] = 1
        # number of fix-position puzzle patches
        len_fix_position = int(num_patches * fix_position_ratio)
        # one noise row shared across the batch so every image fixes the SAME positions
        noise = torch.rand(1, num_patches, device=self.device)
        noise = torch.repeat_interleave(noise, repeats=B, dim=0)
        # argsort the noise to obtain a patch-index permutation
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        # argsort again to obtain the inverse permutation (restores original order)
        ids_restore = torch.argsort(ids_shuffle, dim=1)
        # keep the first subset as fixed, the rest gets shuffled
        ids_fix = ids_shuffle[:, :len_fix_position]  # [B,num_patches] -> [B,fix_patches]
        ids_puzzle = ids_shuffle[:, len_fix_position:]  # [B,num_patches] -> [B,puzzle_patches]
        # ids_?.unsqueeze(-1).repeat(1, 1, D):
        # [B,?_patches] -> [B,?_patches,1] (each place holding an original patch idx) -> [B,?_patches,D]
        # torch.gather selects the patch groups: x_fixed [B,fix_patches,D], x_puzzle [B,puzzle_patches,D]
        x_fixed = torch.gather(inputs, dim=1, index=ids_fix.unsqueeze(-1).repeat(1, 1, D))
        x_puzzle = torch.gather(inputs, dim=1, index=ids_puzzle.unsqueeze(-1).repeat(1, 1, D))
        mask_fixed = torch.gather(mask, dim=1, index=ids_fix.unsqueeze(-1).repeat(1, 1, self.CLS))
        mask_puzzle = torch.gather(mask, dim=1, index=ids_puzzle.unsqueeze(-1).repeat(1, 1, self.CLS))
        if self.strategy == 'In-place' or self.strategy == 'Random':
            # the In-place strategy shuffles the relation patches within their
            # location, among the batch index
            B, num_shuffle_patches, D = x_puzzle.shape
            # noise tensor used to derive the shuffle index of the puzzle patches:
            # [B, num_shuffle_patches] noise in [0, 1]
            noise = torch.rand(B, num_shuffle_patches, device=self.device)
            if self.group_shuffle_size == -1 or self.group_shuffle_size == B:  # CellMix-Split
                # argsort along dim 0: shuffle each patch position across the whole batch
                in_place_shuffle_indices = torch.argsort(noise, dim=0)
            else:  # CellMix-Group
                # groups must evenly tile the batch
                assert B > self.group_shuffle_size > 0 and B % self.group_shuffle_size == 0
                grouped_indices_list = []
                for group_idx in range(B // self.group_shuffle_size):
                    # slice the noise for this group: [group_shuffle_size, N]
                    grouped_noise = noise[group_idx * self.group_shuffle_size:
                                          group_idx * self.group_shuffle_size + self.group_shuffle_size, :]
                    # argsort within the group along dim 0 (batch within the group)
                    grouped_indices = torch.argsort(grouped_noise, dim=0)
                    # offset the group-local indices back to batch coordinates
                    grouped_indices_list.append(grouped_indices + self.group_shuffle_size * group_idx)
                # concatenate group indices back into one [B, N] index tensor
                in_place_shuffle_indices = torch.cat(grouped_indices_list, dim=0)
            # torch.gather applies the shuffle (taking all idx from the shuffled indices)
            x_puzzle = torch.gather(x_puzzle, dim=0, index=in_place_shuffle_indices.unsqueeze(-1).repeat(1, 1, D))
            mask_puzzle = torch.gather(mask_puzzle, dim=0,
                                       index=in_place_shuffle_indices.unsqueeze(-1).repeat(1, 1, self.CLS))
        else:
            print('not a valid CellMix strategy')
        # pack the fixed and shuffled puzzle patches back together
        inputs = torch.cat([x_fixed, x_puzzle], dim=1)
        mask = torch.cat([mask_fixed, mask_puzzle], dim=1)
        # unshuffle (inverse permutation) to restore the fixed positions
        inputs = torch.gather(inputs, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, D))
        # torch.gather to generate the restored ownership mask
        mask = torch.gather(mask, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, self.CLS))
        # the Random strategy additionally shuffles patches WITHIN each image
        # (after the in-place cross-batch shuffle above)
        if self.strategy == 'Random':
            B, num_patches, D = inputs.shape
            # fresh noise for the intra-image shuffle
            noise = torch.rand(B, num_patches, device=self.device)  # [B,num_patches] noise in [0, 1]
            # argsort along dim 1: permute patch positions inside each sample
            all_shuffle_indices = torch.argsort(noise, dim=1)
            # torch.gather to shuffle
            inputs = torch.gather(inputs, dim=1, index=all_shuffle_indices.unsqueeze(-1).repeat(1, 1, D))
            # no need to shuffle the mask: a patch-wise shuffle within one sample
            # does not change per-sample class composition
        else:  # when strategy == 'In-place'
            pass
        # unpatchify to rebuild full-resolution images (e.g. B,3,224,224 / B,3,384,384)
        inputs = unpatchify(inputs, puzzle_patch_size)
        # transform the ownership mask into a soft label:
        # [B, num_patches, CLS] -> (B, CLS), normalized by the patch count
        soft_label = mask.sum(dim=1)  # (B, CLS)
        soft_label = soft_label / num_patches
        # long label, as a data-augmentation interface requirement
        long_label = soft_label.argmax(dim=1)
        return inputs, soft_label, long_label
# ask func
def get_online_augmentation(augmentation_name, p=0.5, class_num=2, batch_size=4, edge_size=224, device='cpu'):
    """
    Build the requested online augmentation object by name.

    :param augmentation_name: name of data-augmentation method (required)
    :param p: chance of triggering
    :param class_num: classification task num (required)
    :param batch_size: batch size (required)
    :param edge_size: edge size of img (required)
    :param device: cpu or cuda

    :return: the augmentation callable, or None for unimplemented / unknown names
    """
    # CellMix variants differ only in strategy and group size
    cellmix_configs = {
        'CellMix-Group': ('In-place', 2),
        'CellMix-Group4': ('In-place', 4),
        'CellMix-Split': ('In-place', -1),
        'CellMix-Random': ('Random', 2),
        'CellMix-Random4': ('Random', 4),
        'CellMix-Self': ('Random', 1),
        'CellMix-All': ('Random', -1),
    }
    if augmentation_name in cellmix_configs:
        strategy, group_size = cellmix_configs[augmentation_name]
        return CellMix(shuffle_p=p, class_num=class_num, strategy=strategy,
                       group_shuffle_size=group_size, device=device)
    if augmentation_name == 'Cutout':
        return Cutout(alpha=2, shuffle_p=p, class_num=class_num, batch_size=batch_size, device=device)
    if augmentation_name == 'CutMix':
        return CutMix(alpha=2, shuffle_p=p, class_num=class_num, batch_size=batch_size, device=device)
    if augmentation_name == 'Mixup':
        return Mixup(alpha=2, shuffle_p=p, class_num=class_num, batch_size=batch_size, device=device)
    if augmentation_name == 'SaliencyMix':
        # NOTE: this alpha corresponds to beta in the reference implementation
        return SaliencyMix(alpha=1, shuffle_p=p, class_num=class_num, batch_size=batch_size,
                           device=device)
    if augmentation_name == 'ResizeMix':
        return ResizeMix(shuffle_p=p, class_num=class_num, batch_size=batch_size, device=device)
    if augmentation_name == 'FMix':
        # FMix uses shuffle_p=1.0 because the trigger chance is determined inside its own design
        return FMix(shuffle_p=1.0, class_num=class_num, batch_size=batch_size,
                    size=(edge_size, edge_size), device=device)
    if augmentation_name == 'PuzzleMix':
        # fixme: all related parts have been taken out separately
        return None
    if augmentation_name == 'CoMix':
        # TODO CoMix
        return None
    if augmentation_name == 'RandomMix':
        # TODO RandomMix
        return None
    print('no valid counterparts augmentation selected')
    return None
if __name__ == '__main__':
    # Smoke test: load a cached batch from disk and visualize CellMix-Group output.
    #
    # Alternative examples kept for reference:
    # Augmentation = get_online_augmentation('CellMix-Split', p=0.5, class_num=2)
    # output, labels, GT_labels = Augmentation(x, label, fix_position_ratio=0.5, puzzle_patch_size=32, act=True)
    # Augmentation = get_online_augmentation('ResizeMix', p=0.5, class_num=2)
    # output, labels, GT_labels = Augmentation(x, label, act=True)
    x = torch.load("./temp-tensors/warwick.pt")
    label = torch.load("./temp-tensors/warwick_labels.pt")
    Augmentation = get_online_augmentation('CellMix-Group', p=1, class_num=2)
    output, labels, GT_labels = Augmentation(x, label, fix_position_ratio=0.5, puzzle_patch_size=32, act=True)
    print(labels, GT_labels)
    # display the first four composed images
    for sample_idx in range(4):
        ToPILImage()(output[sample_idx]).show()