|
""" |
|
Online Augmentations May 23rd 2023 21:30 |
|
ref: |
|
Cutout, Mixup and CutMix are based on:
|
https://blog.csdn.net/cp1314971/article/details/106612060 |
|
""" |
|
import cv2 |
|
import torch |
|
import numpy as np |
|
import torch.nn.functional as F |
|
from scipy.special import perm |
|
from torchvision.transforms import Resize |
|
from torchvision.transforms import ToPILImage, ToTensor |
|
|
|
from utils.visual_usage import patchify, unpatchify |
|
from utils.fmix import sample_mask, FMixBase |
|
|
|
|
|
|
|
def rand_bbox(size, lam):
    """Sample a random box whose side lengths are ``sqrt(1 - lam)`` of the extents.

    The box centre is drawn uniformly over dims 2 and 3 of ``size``; the corners
    are then clipped to the valid range, so the real area can be smaller than
    ``(1 - lam)`` of the whole plane.

    :param size: 4-element shape, e.g. ``tensor.size()`` of an image batch
    :param lam: ratio in [0, 1]; larger values give smaller boxes
    :return: (bbx1, bby1, bbx2, bby2); bbx* index dim 2 and bby* index dim 3
    """
    extent_x, extent_y = size[2], size[3]
    ratio = np.sqrt(1. - lam)
    half_x = np.int64(extent_x * ratio) // 2
    half_y = np.int64(extent_y * ratio) // 2

    # Draw the dim-2 centre before the dim-3 centre (keeps the RNG stream order).
    centre_x = np.random.randint(extent_x)
    centre_y = np.random.randint(extent_y)

    x_lo, x_hi = (np.clip(centre_x + offset, 0, extent_x) for offset in (-half_x, half_x))
    y_lo, y_hi = (np.clip(centre_y + offset, 0, extent_y) for offset in (-half_y, half_y))
    return x_lo, y_lo, x_hi, y_hi
|
|
|
|
|
def saliency_bbox(img, lam):
    """Return a box of side ratio ``sqrt(1 - lam)`` centred on the most salient pixel.

    Saliency is computed with OpenCV's fine-grained static saliency detector on
    ``img`` (moved to CPU and transposed from channel-first to channel-last).

    :param img: single image tensor, channel-first (dims 1 and 2 are spatial)
    :param lam: ratio in [0, 1]; larger values give smaller boxes
    :return: (bbx1, bby1, bbx2, bby2); bbx* index dim 1 and bby* index dim 2 of ``img``
    :raises RuntimeError: if OpenCV reports that saliency computation failed
    """
    size = img.size()
    W = size[1]
    H = size[2]
    cut_rat = np.sqrt(1. - lam)
    # Builtin int: the ``np.int`` alias was removed in NumPy 1.24.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    temp_img = img.cpu().numpy().transpose(1, 2, 0)
    saliency = cv2.saliency.StaticSaliencyFineGrained_create()
    success, saliency_map = saliency.computeSaliency(temp_img)
    if not success:
        raise RuntimeError('OpenCV fine-grained static saliency computation failed')
    # Quantise to 8 bits before argmax (ties are broken on the quantised map).
    saliency_map = (saliency_map * 255).astype("uint8")

    # Row/column of the first maximum; used as the box centre along dims 1 and 2.
    x, y = np.unravel_index(np.argmax(saliency_map, axis=None), saliency_map.shape)

    bbx1 = np.clip(x - cut_w // 2, 0, W)
    bby1 = np.clip(y - cut_h // 2, 0, H)
    bbx2 = np.clip(x + cut_w // 2, 0, W)
    bby2 = np.clip(y + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2
|
|
|
|
|
|
|
class Cutout(object):

    def __init__(self, alpha=2, shuffle_p=1.0, class_num=2, batch_size=4, device='cpu'):
        """
        Cutout augmentation arXiv:1708.04552

        :param alpha: Beta(alpha, alpha) parameter used to sample the box ratio
        :param shuffle_p: per-sample chance of triggering the augmentation
        :param class_num: number of classification categories
        :param batch_size: nominal batch size of training (kept for API compatibility;
                           the real batch size is read from ``inputs`` at call time)
        :param device: CUDA or CPU
        """
        self.alpha = alpha
        self.class_num = class_num
        self.batch_size = batch_size
        self.p = shuffle_p
        self.device = torch.device(device)

    def __call__(self, inputs, labels, act=True):
        """
        Zero a random box (dims 2 and 3) in each triggered sample of the batch.

        :param inputs: image batch; it is copied, never modified in place
        :param labels: integer class index per sample
        :param act: set to False to disable the augmentation (e.g. in validation)
        :return: (augmented inputs, long labels, long labels); Cutout does not change labels
        """
        labels = torch.eye(self.class_num).to(self.device)[labels, :]
        cutout_inputs = inputs.clone().to(self.device)

        # Iterate over the real batch: the last batch of an epoch may be smaller
        # than self.batch_size.
        for i in range(cutout_inputs.size(0)):
            # rand() < p fires with probability exactly p (never when p == 0).
            if not act or np.random.rand() >= self.p:
                continue
            lam = np.random.beta(self.alpha, self.alpha)
            bbx1, bby1, bbx2, bby2 = rand_bbox(cutout_inputs.size(), lam)
            cutout_inputs[i, :, bbx1:bbx2, bby1:bby2] = 0

        long_label = labels.argmax(dim=1)
        return cutout_inputs, long_label, long_label
|
|
|
|
|
class CutMix(object):

    def __init__(self, alpha=2, shuffle_p=1.0, class_num=2, batch_size=4, device='cpu'):
        """
        CutMix augmentation arXiv:1905.04899

        :param alpha: Beta(alpha, alpha) parameter used to sample the box ratio
        :param shuffle_p: chance of trigger augmentation
        :param class_num: number of classification categories
        :param batch_size: batch_size of training (also used to scale the trigger chance)
        :param device: CUDA or CPU
        """
        self.alpha = alpha
        self.class_num = class_num
        self.batch_size = batch_size

        # p is scaled by n! / (n! - (n-1)!) == n / (n - 1); presumably to compensate
        # for the 1/n chance that randperm pairs a sample with itself.
        self.p = shuffle_p * (perm(self.batch_size, self.batch_size)
                              / (perm(self.batch_size, self.batch_size) -
                                 perm(self.batch_size - 1, self.batch_size - 1)))
        self.device = torch.device(device)

    def __call__(self, inputs, labels, act=True):
        """
        Paste a random box from a shuffled partner sample into each triggered sample.

        :param inputs: image batch; boxes are cut along dims 2 and 3
        :param labels: integer class index per sample
        :param act: set to False to disable the augmentation (e.g. in validation)
        :return: (mixed inputs, soft labels [B, class_num], hard labels [B])
        """
        labels = torch.eye(self.class_num).to(self.device)[labels, :]
        ori_inputs = inputs.clone().to(self.device)
        cutmix_inputs = ori_inputs.clone()
        batch = ori_inputs.size(0)  # real batch size; may be < self.batch_size
        area = ori_inputs.size(2) * ori_inputs.size(3)

        # Shuffle tensors that already live on self.device so indices and data match.
        indices = torch.randperm(batch, device=self.device)
        shuffled_inputs = ori_inputs[indices]
        shuffled_labels = labels[indices]

        for i in range(batch):
            # rand() < p fires with probability exactly p (never when p == 0).
            if not act or np.random.rand() >= self.p:
                continue
            lam = np.random.beta(self.alpha, self.alpha)
            bbx1, bby1, bbx2, bby2 = rand_bbox(ori_inputs.size(), lam)
            cutmix_inputs[i, :, bbx1:bbx2, bby1:bby2] = shuffled_inputs[i, :, bbx1:bbx2, bby1:bby2]

            # Re-derive lam from the clipped box so labels match the pasted area.
            lam = 1 - float((bbx2 - bbx1) * (bby2 - bby1)) / area
            labels[i] = labels[i] * lam + shuffled_labels[i] * (1 - lam)

        long_label = labels.argmax(dim=1)
        return cutmix_inputs, labels, long_label
|
|
|
|
|
class Mixup(object):

    def __init__(self, alpha=2, shuffle_p=1.0, class_num=2, batch_size=4, device='cpu'):
        """
        Mixup augmentation arXiv:1710.09412

        :param alpha: Beta(alpha, alpha) parameter used to sample the mixing ratio
        :param shuffle_p: chance of trigger augmentation
        :param class_num: number of classification categories
        :param batch_size: batch_size of training (also used to scale the trigger chance)
        :param device: CUDA or CPU
        """
        self.alpha = alpha
        self.class_num = class_num
        self.batch_size = batch_size

        # p is scaled by n! / (n! - (n-1)!) == n / (n - 1); presumably to compensate
        # for the 1/n chance that randperm pairs a sample with itself.
        self.p = shuffle_p * (perm(self.batch_size, self.batch_size)
                              / (perm(self.batch_size, self.batch_size) -
                                 perm(self.batch_size - 1, self.batch_size - 1)))
        self.device = torch.device(device)

    def __call__(self, inputs, labels, act=True):
        """
        Blend each triggered sample with a shuffled partner: x = lam * x + (1 - lam) * x_partner.

        :param inputs: image batch; it is copied, never modified in place
        :param labels: integer class index per sample
        :param act: set to False to disable the augmentation (e.g. in validation)
        :return: (mixed inputs, soft labels [B, class_num], hard labels [B])
        """
        labels = torch.eye(self.class_num).to(self.device)[labels, :]
        ori_inputs = inputs.clone().to(self.device)
        mixup_inputs = ori_inputs.clone()
        batch = ori_inputs.size(0)  # real batch size; may be < self.batch_size

        # Shuffle tensors that already live on self.device so indices and data match.
        indices = torch.randperm(batch, device=self.device)
        shuffled_inputs = ori_inputs[indices]
        shuffled_labels = labels[indices]

        for i in range(batch):
            # rand() < p fires with probability exactly p (never when p == 0).
            if not act or np.random.rand() >= self.p:
                continue
            lam = float(np.random.beta(self.alpha, self.alpha))
            mixup_inputs[i] = ori_inputs[i] * lam + shuffled_inputs[i] * (1 - lam)
            labels[i] = labels[i] * lam + shuffled_labels[i] * (1 - lam)

        long_label = labels.argmax(dim=1)
        return mixup_inputs, labels, long_label
|
|
|
|
|
class SaliencyMix(object):

    def __init__(self, alpha=1, shuffle_p=1.0, class_num=2, batch_size=4, device='cpu'):
        """
        SaliencyMix augmentation arXiv:2006.01791

        :param alpha: Beta(alpha, alpha) parameter; alpha <= 0 disables the augmentation
        :param shuffle_p: per-sample chance of triggering the augmentation
        :param class_num: number of classification categories
        :param batch_size: nominal batch size of training (kept for API compatibility;
                           the real batch size is read from ``inputs`` at call time)
        :param device: CUDA or CPU
        """
        self.alpha = alpha
        self.class_num = class_num
        self.batch_size = batch_size
        self.p = shuffle_p
        self.device = torch.device(device)

    def __call__(self, inputs, labels, act=True):
        """
        Paste the most salient box of a shuffled partner into the same location of each triggered sample.

        :param inputs: image batch; boxes are cut along dims 2 and 3
        :param labels: integer class index per sample
        :param act: set to False to disable the augmentation (e.g. in validation)
        :return: (mixed inputs, soft labels [B, class_num], hard labels [B])
        """
        labels = torch.eye(self.class_num).to(self.device)[labels, :]
        ori_inputs = inputs.clone().to(self.device)
        saliencymix_inputs = ori_inputs.clone()
        batch = ori_inputs.size(0)  # real batch size; may be < self.batch_size
        area = ori_inputs.size(2) * ori_inputs.size(3)

        # Shuffle tensors that already live on self.device so indices and data match.
        indices = torch.randperm(batch, device=self.device)
        shuffled_inputs = ori_inputs[indices]
        shuffled_labels = labels[indices]

        for i in range(batch):
            # rand() < p fires with probability exactly p (never when p == 0).
            if not act or self.alpha <= 0 or np.random.rand() >= self.p:
                continue
            lam = np.random.beta(self.alpha, self.alpha)
            # The box is centred on the saliency peak of the *source* (partner) image.
            bbx1, bby1, bbx2, bby2 = saliency_bbox(shuffled_inputs[i], lam)
            saliencymix_inputs[i, :, bbx1:bbx2, bby1:bby2] = shuffled_inputs[i, :, bbx1:bbx2, bby1:bby2]

            # Re-derive lam from the clipped box so labels match the pasted area.
            lam = 1 - float((bbx2 - bbx1) * (bby2 - bby1)) / area
            labels[i] = labels[i] * lam + shuffled_labels[i] * (1 - lam)

        long_label = labels.argmax(dim=1)
        return saliencymix_inputs, labels, long_label
|
|
|
|
|
class ResizeMix(object):

    def __init__(self, shuffle_p=1.0, class_num=2, batch_size=4, device='cpu'):
        """
        ResizeMix augmentation arXiv:2012.11101

        :param shuffle_p: per-sample chance of triggering the augmentation
        :param class_num: number of classification categories
        :param batch_size: nominal batch size of training (kept for API compatibility;
                           the real batch size is read from ``inputs`` at call time)
        :param device: CUDA or CPU
        """
        self.class_num = class_num
        self.batch_size = batch_size
        self.p = shuffle_p
        self.device = torch.device(device)

    def __call__(self, inputs, labels, alpha=0.1, beta=0.8, act=True):
        """
        Paste a downscaled copy of a shuffled partner into a random box of each triggered sample.

        :param inputs: image batch; boxes are cut along dims 2 and 3
        :param labels: integer class index per sample
        :param alpha: lower bound of the uniform distribution lam is drawn from
        :param beta: upper bound of the uniform distribution lam is drawn from
        :param act: set to False to disable the augmentation (e.g. in validation)
        :return: (mixed inputs, soft labels [B, class_num], hard labels [B])
        """
        labels = torch.eye(self.class_num).to(self.device)[labels, :]
        ori_inputs = inputs.clone().to(self.device)
        resizemix_inputs = ori_inputs.clone()
        batch = ori_inputs.size(0)  # real batch size; may be < self.batch_size
        area = ori_inputs.size(2) * ori_inputs.size(3)

        # Shuffle tensors that already live on self.device so indices and data match.
        indices = torch.randperm(batch, device=self.device)
        shuffled_inputs = ori_inputs[indices]
        shuffled_labels = labels[indices]

        for i in range(batch):
            # rand() < p fires with probability exactly p (never when p == 0).
            if not act or np.random.rand() >= self.p:
                continue
            lam = np.random.uniform(alpha, beta)
            bbx1, bby1, bbx2, bby2 = rand_bbox(ori_inputs.size(), lam)
            box_h, box_w = int(bbx2 - bbx1), int(bby2 - bby1)
            if box_h == 0 or box_w == 0:
                # Degenerate box (only for tiny images): nothing to paste.
                continue

            # Resize the tensor directly. A PIL round trip would quantise to 8 bits
            # and corrupt values outside [0, 1], such as normalised inputs.
            resized = Resize([box_h, box_w], antialias=True)(shuffled_inputs[i])
            resizemix_inputs[i, :, bbx1:bbx2, bby1:bby2] = resized

            # Label weight of the partner equals the fraction of the area it covers.
            lam = 1 - float(box_h * box_w) / area
            labels[i] = labels[i] * lam + shuffled_labels[i] * (1 - lam)

        long_label = labels.argmax(dim=1)
        return resizemix_inputs, labels, long_label
|
|
|
|
|
class FMix(FMixBase):

    def __init__(self, shuffle_p=1.0, class_num=2, batch_size=4, decay_power=3, alpha=1, size=(32, 32),
                 max_soft=0.0, reformulate=False, device='cpu'):
        """
        FMix augmentation arXiv:2002.12047

        :param shuffle_p: per-sample chance of triggering the augmentation
        :param class_num: number of classification categories
        :param batch_size: nominal batch size of training (kept for API compatibility;
                           the real batch size is read from ``inputs`` at call time)
        :param decay_power: decay_power, forwarded to FMixBase
        :param alpha: alpha, forwarded to FMixBase
        :param size: size of the sampled mask, forwarded to FMixBase
        :param max_soft: max_soft, forwarded to FMixBase
        :param reformulate: reformulate, forwarded to FMixBase
        :param device: CUDA or CPU
        """
        super().__init__(decay_power, alpha, size, max_soft, reformulate)
        self.class_num = class_num
        self.batch_size = batch_size
        self.p = shuffle_p
        self.device = torch.device(device)

    def __call__(self, inputs, labels, alpha=1, act=True):
        """
        Mix each triggered sample with a shuffled partner through one Fourier-sampled mask.

        One (lam, mask) pair is sampled per call and shared by every triggered sample.

        :param inputs: image batch; it is copied, never modified in place
        :param labels: integer class index per sample
        :param alpha: unused; ``self.alpha`` (set at construction) is used instead
        :param act: set to False to disable the augmentation (e.g. in validation)
        :return: (mixed inputs, soft labels [B, class_num], hard labels [B])
        """
        # self.alpha / decay_power / size / max_soft / reformulate presumably come from
        # FMixBase.__init__; mask is assumed broadcastable to one sample, e.g. (1, H, W)
        # - TODO confirm against utils.fmix.sample_mask.
        lam, mask = sample_mask(self.alpha, self.decay_power, self.size, self.max_soft, self.reformulate)
        mask = torch.from_numpy(mask).float().to(self.device)

        labels = torch.eye(self.class_num).to(self.device)[labels, :]
        ori_inputs = inputs.clone().to(self.device)
        fmix_inputs = ori_inputs.clone()
        batch = ori_inputs.size(0)  # real batch size; may be < self.batch_size

        # Shuffle tensors that already live on self.device so indices and data match.
        indices = torch.randperm(batch, device=self.device)
        shuffled_inputs = ori_inputs[indices]
        shuffled_labels = labels[indices]

        for i in range(batch):
            # rand() < p fires with probability exactly p (never when p == 0).
            if not act or np.random.rand() >= self.p:
                continue
            fmix_inputs[i] = mask * ori_inputs[i] + (1 - mask) * shuffled_inputs[i]
            labels[i] = labels[i] * lam + shuffled_labels[i] * (1 - lam)

        long_label = labels.argmax(dim=1)
        return fmix_inputs, labels, long_label
|
|
|
|
|
|
|
class CellMix(object):

    def __init__(self, shuffle_p=1.0, class_num=2, strategy='In-place', group_shuffle_size=-1, device='cpu'):
        """
        CellMix augmentation arXiv:2301.11513

        :param shuffle_p: chance of triggering the augmentation (checked once per batch)
        :param class_num: number of classification categories
        :param strategy: 'In-place' or 'Random' to shuffle the relation patches within the batch
        :param group_shuffle_size: the size of the shuffling group in the batch, -1 for the whole batch
        :param device: CUDA or CPU
        """
        self.p = shuffle_p
        self.CLS = class_num
        # Stored as given (not wrapped in torch.device like the other augmentations here).
        self.device = device
        self.strategy = strategy
        self.group_shuffle_size = group_shuffle_size

    def __call__(self, inputs, labels, fix_position_ratio=0.5, puzzle_patch_size=32, act=True):
        """
        Fix-position in-place shuffling (CellMix).

        A random subset of patch positions, shared by every image in the batch, is kept fixed.
        The patches at the remaining positions are shuffled across the images of the batch
        (within groups of ``group_shuffle_size`` images when grouping is enabled). Each patch
        keeps its position and only changes which image it belongs to. The patches are then
        regrouped with the fixed ones. With strategy 'Random', the patches of each regrouped
        image are additionally permuted within that image.

        Cross-sample selection of the fixed positions is done by argsorting one random-noise row
        (repeated over the batch) along dim 1. The in-place batch-wise shuffle is done by
        argsorting random noise along the batch dimension (dim 0), separately per group when
        grouping is enabled.

        :param inputs: input image tensor, size of [B, 3, H, W]
        :param labels: integer class index per image (used to build one-hot per-patch masks)
        :param fix_position_ratio: fraction of patch positions kept fixed (the least remaining part of patches)
        :param puzzle_patch_size: patch size passed to patchify/unpatchify
        :param act: set to False to force not triggering CellMix (e.g. in validation); True triggers by chance p

        :return: (x, soft_label, long_label)
            x : [B, 3, H, W] regrouped image after CellMix augmentation
            soft_label : [B, CLS], fraction of each image's patches that came from each class
            long_label : [B], hard long label (argmax of soft_label)
        """
        # NOTE(review): randint(0, 101) is in 0..100, so this fires with probability
        # (floor(100 * p) + 1) / 101, i.e. ~1% even when p == 0.
        if np.random.randint(0, 101) > 100 * self.p or (not act):
            soft_label = torch.eye(self.CLS).to(self.device)[labels, :]
            return inputs, soft_label, labels

        # Presumably patchify returns (B, num_patches, patch_dim) - unpacked as such below.
        inputs = patchify(inputs, puzzle_patch_size)
        B, num_patches, D = inputs.shape

        # Per-patch one-hot class mask; it is shuffled in lockstep with the patches so
        # that the final soft label counts where every patch came from.
        # NOTE(review): mask lives on inputs.device while the noise/index tensors below use
        # self.device; torch.gather needs them to match - callers presumably pass inputs on self.device.
        mask = torch.zeros([B, num_patches, self.CLS], device=inputs.device, requires_grad=False)

        # Every patch of image b gets the one-hot vector of labels[b].
        B_idx = range(B)
        mask[B_idx, :, labels] = 1

        len_fix_position = int(num_patches * fix_position_ratio)

        # One noise row repeated over the batch -> every image fixes the same positions.
        noise = torch.rand(1, num_patches, device=self.device)
        noise = torch.repeat_interleave(noise, repeats=B, dim=0)

        ids_shuffle = torch.argsort(noise, dim=1)

        # Inverse permutation: used later to put patches back at their original positions.
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # The first len_fix_position positions stay fixed; the rest are shuffled.
        ids_fix = ids_shuffle[:, :len_fix_position]
        ids_puzzle = ids_shuffle[:, len_fix_position:]

        x_fixed = torch.gather(inputs, dim=1, index=ids_fix.unsqueeze(-1).repeat(1, 1, D))
        x_puzzle = torch.gather(inputs, dim=1, index=ids_puzzle.unsqueeze(-1).repeat(1, 1, D))
        mask_fixed = torch.gather(mask, dim=1, index=ids_fix.unsqueeze(-1).repeat(1, 1, self.CLS))
        mask_puzzle = torch.gather(mask, dim=1, index=ids_puzzle.unsqueeze(-1).repeat(1, 1, self.CLS))

        if self.strategy == 'In-place' or self.strategy == 'Random':

            B, num_shuffle_patches, D = x_puzzle.shape

            # Independent noise per (image, position): each column's argsort over dim 0 is a
            # separate permutation of the batch, so each position draws its patch from a random image.
            noise = torch.rand(B, num_shuffle_patches, device=self.device)

            if self.group_shuffle_size == -1 or self.group_shuffle_size == B:
                # Shuffle across the whole batch.
                in_place_shuffle_indices = torch.argsort(noise, dim=0)

            else:
                # NOTE(review): a final batch whose size is not a multiple of group_shuffle_size
                # (or is <= it) fails this assert.
                assert B > self.group_shuffle_size > 0 and B % self.group_shuffle_size == 0
                grouped_indices_list = []
                for group_idx in range(B // self.group_shuffle_size):
                    # Noise rows belonging to this group of consecutive images.
                    grouped_noise = noise[group_idx * self.group_shuffle_size:
                                          group_idx * self.group_shuffle_size + self.group_shuffle_size, :]
                    grouped_indices = torch.argsort(grouped_noise, dim=0)
                    # Offset the local indices to batch indices so shuffling stays inside the group.
                    grouped_indices_list.append(grouped_indices + self.group_shuffle_size * group_idx)

                in_place_shuffle_indices = torch.cat(grouped_indices_list, dim=0)

            # Move patches (and their class masks) between images, keeping their positions.
            x_puzzle = torch.gather(x_puzzle, dim=0, index=in_place_shuffle_indices.unsqueeze(-1).repeat(1, 1, D))
            mask_puzzle = torch.gather(mask_puzzle, dim=0,
                                       index=in_place_shuffle_indices.unsqueeze(-1).repeat(1, 1, self.CLS))
        else:
            # NOTE(review): an unknown strategy only prints; the batch is then returned
            # with no cross-image shuffling.
            print('not a valid CellMix strategy')

        # Regroup fixed + shuffled patches, then undo the position permutation.
        inputs = torch.cat([x_fixed, x_puzzle], dim=1)
        mask = torch.cat([mask_fixed, mask_puzzle], dim=1)

        inputs = torch.gather(inputs, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, D))

        mask = torch.gather(mask, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, self.CLS))

        if self.strategy == 'Random':
            B, num_patches, D = inputs.shape

            noise = torch.rand(B, num_patches, device=self.device)

            # Independent permutation of the patch positions inside each image.
            all_shuffle_indices = torch.argsort(noise, dim=1)

            # Only inputs are permuted; mask needs no update because the soft label sums over patches.
            inputs = torch.gather(inputs, dim=1, index=all_shuffle_indices.unsqueeze(-1).repeat(1, 1, D))

        else:
            pass

        inputs = unpatchify(inputs, puzzle_patch_size)

        # Fraction of each image's patches that originate from each class.
        soft_label = mask.sum(dim=1)
        soft_label = soft_label / num_patches

        long_label = soft_label.argmax(dim=1)

        return inputs, soft_label, long_label
|
|
|
|
|
|
|
def get_online_augmentation(augmentation_name, p=0.5, class_num=2, batch_size=4, edge_size=224, device='cpu'):
    """
    Build an online (batch-level) augmentation object by name.

    :param augmentation_name: name of the data-augmentation method, e.g. 'CellMix-Group', 'CutMix', 'FMix'
    :param p: chance of triggering (not used by 'FMix', which is always built with shuffle_p=1.0)
    :param class_num: number of classification categories
    :param batch_size: batch size (passed to the non-CellMix augmentations)
    :param edge_size: edge size of the square input images (only used by 'FMix')
    :param device: cpu or cuda

    augmentation_name, class_num, batch_size and edge_size must be provided explicitly.

    :return: the augmentation callable, or None for 'PuzzleMix', 'CoMix', 'RandomMix'
             (no online implementation here) and for unrecognised names
    """
    # CellMix variants differ only in their shuffling strategy and group size.
    cellmix_variants = {
        'CellMix-Group': ('In-place', 2),
        'CellMix-Group4': ('In-place', 4),
        'CellMix-Split': ('In-place', -1),
        'CellMix-Random': ('Random', 2),
        'CellMix-Random4': ('Random', 4),
        'CellMix-Self': ('Random', 1),
        'CellMix-All': ('Random', -1),
    }
    if augmentation_name in cellmix_variants:
        strategy, group_size = cellmix_variants[augmentation_name]
        return CellMix(shuffle_p=p, class_num=class_num, strategy=strategy,
                       group_shuffle_size=group_size, device=device)

    shared = dict(class_num=class_num, batch_size=batch_size, device=device)
    # Lazily-built counterpart augmentations (only the selected one is constructed).
    builders = {
        'Cutout': lambda: Cutout(alpha=2, shuffle_p=p, **shared),
        'CutMix': lambda: CutMix(alpha=2, shuffle_p=p, **shared),
        'Mixup': lambda: Mixup(alpha=2, shuffle_p=p, **shared),
        'SaliencyMix': lambda: SaliencyMix(alpha=1, shuffle_p=p, **shared),
        'ResizeMix': lambda: ResizeMix(shuffle_p=p, **shared),
        # FMix is always built with shuffle_p=1.0; presumably deliberate - confirm.
        'FMix': lambda: FMix(shuffle_p=1.0, size=(edge_size, edge_size), **shared),
    }
    if augmentation_name in builders:
        return builders[augmentation_name]()

    # Known method names that have no online implementation in this module.
    if augmentation_name in ('PuzzleMix', 'CoMix', 'RandomMix'):
        return None

    print('no valid counterparts augmentation selected')
    return None
|
|
|
|
|
if __name__ == '__main__':
    # Example usage:
    #   Augmentation = get_online_augmentation('CellMix-Split', p=0.5, class_num=2)
    #   output, labels, GT_labels = Augmentation(x, label, fix_position_ratio=0.5, puzzle_patch_size=32, act=True)
    #   print(labels, GT_labels)

    # Demo on locally saved tensors (presumably an image batch and its integer labels).
    x = torch.load("./temp-tensors/warwick.pt")
    label = torch.load("./temp-tensors/warwick_labels.pt")

    Augmentation = get_online_augmentation('CellMix-Group', p=1, class_num=2)
    output, labels, GT_labels = Augmentation(x, label, fix_position_ratio=0.5, puzzle_patch_size=32, act=True)
    print(labels, GT_labels)

    # Preview the first four augmented images.
    to_pil = ToPILImage()
    for sample_idx in range(4):
        composed_img = to_pil(output[sample_idx])
        composed_img.show()
|
|