"""
Online Augmentations May 23rd 2023 21:30
ref:
CutOut, Mixup, CutMix based on
https://blog.csdn.net/cp1314971/article/details/106612060
"""
import cv2
import torch
import numpy as np
import torch.nn.functional as F
from scipy.special import perm
from torchvision.transforms import Resize
from torchvision.transforms import ToPILImage, ToTensor
from utils.visual_usage import patchify, unpatchify
from utils.fmix import sample_mask, FMixBase # Fmix
# generate random bounding box
def rand_bbox(size, lam):
W = size[2]
H = size[3]
cut_rat = np.sqrt(1. - lam)
cut_w = np.int64(W * cut_rat)
cut_h = np.int64(H * cut_rat)
# uniform
cx = np.random.randint(W)
cy = np.random.randint(H)
bbx1 = np.clip(cx - cut_w // 2, 0, W)
bby1 = np.clip(cy - cut_h // 2, 0, H)
bbx2 = np.clip(cx + cut_w // 2, 0, W)
bby2 = np.clip(cy + cut_h // 2, 0, H)
return bbx1, bby1, bbx2, bby2
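# A quick illustrative check of rand_bbox (hypothetical numbers): for a 224x224 input and lam = 0.75,
# cut_rat = sqrt(1 - 0.75) = 0.5, so the box is about 112x112, i.e. roughly 1 - lam = 25% of the image area
# (less when the box is clipped at the border, which is why callers recompute lam from the final box).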
def saliency_bbox(img, lam):
size = img.size()
W = size[1]
H = size[2]
cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)  # np.int was removed from recent NumPy releases, use the builtin int
    cut_h = int(H * cut_rat)
# initialize OpenCV's static fine grained saliency detector and
# compute the saliency map
temp_img = img.cpu().numpy().transpose(1, 2, 0)
saliency = cv2.saliency.StaticSaliencyFineGrained_create()
(success, saliencyMap) = saliency.computeSaliency(temp_img)
saliencyMap = (saliencyMap * 255).astype("uint8")
maximum_indices = np.unravel_index(np.argmax(saliencyMap, axis=None), saliencyMap.shape)
x = maximum_indices[0]
y = maximum_indices[1]
bbx1 = np.clip(x - cut_w // 2, 0, W)
bby1 = np.clip(y - cut_h // 2, 0, H)
bbx2 = np.clip(x + cut_w // 2, 0, W)
bby2 = np.clip(y + cut_h // 2, 0, H)
return bbx1, bby1, bbx2, bby2
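# Note: cv2.saliency.StaticSaliencyFineGrained_create is provided by the opencv-contrib-python build;
# the plain opencv-python wheel does not ship the saliency module.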
# augmentation methods
class Cutout(object):
def __init__(self, alpha=2, shuffle_p=1.0, class_num=2, batch_size=4, device='cpu'):
"""
Cutout augmentation arXiv:1708.04552
:param alpha: alpha
        :param shuffle_p: chance of triggering the augmentation
:param class_num: number of classification categories
:param batch_size: batch_size of training
:param device: CUDA or CPU
"""
self.alpha = alpha
self.class_num = class_num
self.batch_size = batch_size
self.p = shuffle_p
self.device = torch.device(device)
def __call__(self, inputs, labels, act=True):
labels = torch.eye(self.class_num).to(self.device)[labels, :] # one-hot hard label
ori_inputs = inputs.clone().to(self.device) # duplicate inputs for ori inputs
cutout_inputs = inputs.clone().to(self.device) # duplicate inputs for outputs
lam_list = [] # a list to record operating ratio
for i in range(self.batch_size):
if np.random.randint(0, 101) > 100 * self.p or (not act):
                # skip the augmentation for this sample (chance 1 - p, or act=False)
lam_list.append(-1)
continue
lam = np.random.beta(self.alpha, self.alpha)
bbx1, bby1, bbx2, bby2 = rand_bbox(ori_inputs.size(), lam) # get random bbox
cutout_inputs[i, :, bbx1:bbx2, bby1:bby2] = 0
# update the ratio of (area of ori_image on new masked image) for soft-label
lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (ori_inputs.size()[2] * ori_inputs.size()[3]))
lam_list.append(lam)
long_label = labels.argmax(dim=1)
        # NOTICE: Cutout uses the hard (long) label with standard cross-entropy instead of a soft label with soft-label cross-entropy
return cutout_inputs, long_label, long_label
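# Minimal usage sketch for Cutout (illustrative; `imgs` [4, 3, 224, 224] and `label` [4] are assumed tensors):
#   aug = Cutout(alpha=2, shuffle_p=0.5, class_num=2, batch_size=4, device='cpu')
#   cut_imgs, labels_out, gt_labels = aug(imgs, label, act=True)
#   # for Cutout both returned labels are the hard (long) labels, ready for nn.CrossEntropyLoss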
class CutMix(object):
def __init__(self, alpha=2, shuffle_p=1.0, class_num=2, batch_size=4, device='cpu'):
"""
CutMix augmentation arXiv:1905.04899
:param alpha: alpha
        :param shuffle_p: chance of triggering the augmentation
:param class_num: number of classification categories
:param batch_size: batch_size of training
:param device: CUDA or CPU
"""
self.alpha = alpha
self.class_num = class_num
self.batch_size = batch_size
        # calibrate the trigger chance: rescale shuffle_p to keep the effective chance of a mixing operation per sample in each batch
self.p = shuffle_p * (perm(self.batch_size, self.batch_size)
/ (perm(self.batch_size, self.batch_size) -
perm(self.batch_size - 1, self.batch_size - 1)))
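        # Worked example of the calibration above: with batch_size = 4,
        # perm(4, 4) / (perm(4, 4) - perm(3, 3)) = 24 / (24 - 6) = 4 / 3, i.e. shuffle_p is scaled by B / (B - 1),
        # presumably to offset permutations that map a sample back onto itself.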
self.device = torch.device(device)
def __call__(self, inputs, labels, act=True):
labels = torch.eye(self.class_num).to(self.device)[labels, :] # one-hot hard label
ori_inputs = inputs.clone().to(self.device) # duplicate inputs for ori inputs
cutmix_inputs = inputs.clone().to(self.device) # duplicate inputs for outputs
lam_list = [] # a list to record operating ratio
indices = torch.randperm(self.batch_size, device=self.device) # shuffle indices
shuffled_inputs = inputs[indices].to(self.device)
shuffled_labels = labels[indices].to(self.device)
for i in range(self.batch_size):
if np.random.randint(0, 101) > 100 * self.p or (not act):
                # skip the augmentation for this sample (chance 1 - p, or act=False)
lam_list.append(-1)
continue
lam = np.random.beta(self.alpha, self.alpha)
bbx1, bby1, bbx2, bby2 = rand_bbox(ori_inputs.size(), lam) # get random bbox
cutmix_inputs[i, :, bbx1:bbx2, bby1:bby2] = \
shuffled_inputs[i, :, bbx1:bbx2, bby1:bby2]
# update the ratio of (area of ori_image on new image) for soft-label
lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (ori_inputs.size()[2] * ori_inputs.size()[3]))
lam_list.append(lam)
labels[i] = labels[i] * lam + shuffled_labels[i] * (1 - lam)
long_label = labels.argmax(dim=1)
return cutmix_inputs, labels, long_label
class Mixup(object):
def __init__(self, alpha=2, shuffle_p=1.0, class_num=2, batch_size=4, device='cpu'):
"""
Mixup augmentation arXiv:1710.09412
:param alpha: alpha
        :param shuffle_p: chance of triggering the augmentation
:param class_num: number of classification categories
:param batch_size: batch_size of training
:param device: CUDA or CPU
"""
self.alpha = alpha
self.class_num = class_num
self.batch_size = batch_size
        # calibrate the trigger chance: rescale shuffle_p to keep the effective chance of a mixing operation per sample in each batch
self.p = shuffle_p * (perm(self.batch_size, self.batch_size)
/ (perm(self.batch_size, self.batch_size) -
perm(self.batch_size - 1, self.batch_size - 1)))
self.device = torch.device(device)
def __call__(self, inputs, labels, act=True):
labels = torch.eye(self.class_num).to(self.device)[labels, :] # one-hot hard label
ori_inputs = inputs.clone().to(self.device) # duplicate inputs for ori inputs
mixup_inputs = inputs.clone().to(self.device) # duplicate inputs for outputs
lam_list = [] # a list to record operating ratio
indices = torch.randperm(self.batch_size, device=self.device) # shuffle indices
shuffled_inputs = inputs[indices].to(self.device)
shuffled_labels = labels[indices].to(self.device)
for i in range(self.batch_size):
if np.random.randint(0, 101) > 100 * self.p or (not act):
                # skip the augmentation for this sample (chance 1 - p, or act=False)
lam_list.append(-1)
continue
lam = np.random.beta(self.alpha, self.alpha)
lam_list.append(lam)
mixup_inputs[i] = ori_inputs[i] * lam + shuffled_inputs[i] * (1 - lam)
labels[i] = labels[i] * lam + shuffled_labels[i] * (1 - lam)
long_label = labels.argmax(dim=1)
return mixup_inputs, labels, long_label
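# Minimal usage sketch for Mixup (illustrative; `imgs` and `label` are assumed tensors, loader batch size 4):
#   aug = Mixup(alpha=2, shuffle_p=0.5, class_num=2, batch_size=4, device='cpu')
#   mixed, soft_labels, long_labels = aug(imgs, label, act=True)
#   # soft_labels is [B, class_num]; a matching soft-label loss is, e.g.,
#   #   loss = -(soft_labels * torch.log_softmax(logits, dim=1)).sum(dim=1).mean()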
class SaliencyMix(object):
def __init__(self, alpha=1, shuffle_p=1.0, class_num=2, batch_size=4, device='cpu'):
"""
SaliencyMix augmentation arXiv:2006.01791
:param alpha: alpha
        :param shuffle_p: chance of triggering the augmentation
:param class_num: number of classification categories
:param batch_size: batch_size of training
:param device: CUDA or CPU
"""
# ori batch_size=128
self.alpha = alpha
self.class_num = class_num
self.batch_size = batch_size
        # trigger chance (no batch-permutation calibration is applied here)
self.p = shuffle_p
self.device = torch.device(device)
def __call__(self, inputs, labels, act=True):
labels = torch.eye(self.class_num).to(self.device)[labels, :] # one-hot hard label
ori_inputs = inputs.clone().to(self.device) # duplicate inputs for ori inputs
saliencymix_inputs = inputs.clone().to(self.device) # duplicate inputs for outputs
lam_list = [] # a list to record operating ratio
indices = torch.randperm(self.batch_size, device=self.device) # shuffle indices
shuffled_inputs = inputs[indices].to(self.device)
shuffled_labels = labels[indices].to(self.device)
for i in range(self.batch_size):
if np.random.randint(0, 101) > 100 * self.p or (not act) or self.alpha <= 0:
                # skip the augmentation for this sample (chance 1 - p, act=False, or alpha <= 0)
lam_list.append(-1)
continue
lam = np.random.beta(self.alpha, self.alpha)
            bbx1, bby1, bbx2, bby2 = saliency_bbox(shuffled_inputs[i], lam)  # saliency-guided bbox on the shuffled source image
saliencymix_inputs[i, :, bbx1:bbx2, bby1:bby2] = \
shuffled_inputs[i, :, bbx1:bbx2, bby1:bby2]
# update the ratio of (area of ori_image on new image) for soft-label
lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (ori_inputs.size()[2] * ori_inputs.size()[3]))
lam_list.append(lam)
labels[i] = labels[i] * lam + shuffled_labels[i] * (1 - lam)
long_label = labels.argmax(dim=1)
return saliencymix_inputs, labels, long_label
class ResizeMix(object):
def __init__(self, shuffle_p=1.0, class_num=2, batch_size=4, device='cpu'):
"""
ResizeMix augmentation arXiv:2012.11101
        :param shuffle_p: chance of triggering the augmentation
:param class_num: number of classification categories
:param batch_size: batch_size of training
:param device: CUDA or CPU
"""
# ori batch_size=512
self.class_num = class_num
self.batch_size = batch_size
        # trigger chance (no batch-permutation calibration is applied here)
self.p = shuffle_p
self.device = torch.device(device)
def __call__(self, inputs, labels, alpha=0.1, beta=0.8, act=True):
labels = torch.eye(self.class_num).to(self.device)[labels, :] # one-hot hard label
ori_inputs = inputs.clone().to(self.device) # duplicate inputs for ori inputs
resizemix_inputs = inputs.clone().to(self.device) # duplicate inputs for outputs
lam_list = [] # a list to record operating ratio
indices = torch.randperm(self.batch_size, device=self.device) # shuffle indices
shuffled_inputs = inputs[indices].to(self.device)
shuffled_labels = labels[indices].to(self.device)
for i in range(self.batch_size):
if np.random.randint(0, 101) > 100 * self.p or (not act):
                # skip the augmentation for this sample (chance 1 - p, or act=False)
lam_list.append(-1)
continue
lam = np.random.uniform(alpha, beta)
# lam = 1 - lam
bbx1, bby1, bbx2, bby2 = rand_bbox(ori_inputs.size(), lam) # get random bbox
# resizer by torchvision
torch_resize = Resize([bbx2 - bbx1, bby2 - bby1])
# Tensor -> PIL -> resize -> Tensor
re_pil_image = torch_resize(ToPILImage()(shuffled_inputs[i]))
resizemix_inputs[i, :, bbx1:bbx2, bby1:bby2] = ToTensor()(re_pil_image)
# update the ratio of (area of ori_image on new image) for soft-label
lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (ori_inputs.size()[2] * ori_inputs.size()[3]))
lam_list.append(lam)
labels[i] = labels[i] * lam + shuffled_labels[i] * (1 - lam)
long_label = labels.argmax(dim=1)
return resizemix_inputs, labels, long_label
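# Note on the ResizeMix resize step: Resize takes (height, width) and the ToPILImage/ToTensor round trip
# assumes float tensors in [0, 1]. A pure-tensor alternative (a sketch, not the original implementation):
#   patch = F.interpolate(shuffled_inputs[i:i + 1], size=(bbx2 - bbx1, bby2 - bby1),
#                         mode='bilinear', align_corners=False)[0]
#   resizemix_inputs[i, :, bbx1:bbx2, bby1:bby2] = patch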
class FMix(FMixBase):
def __init__(self, shuffle_p=1.0, class_num=2, batch_size=4, decay_power=3, alpha=1, size=(32, 32),
max_soft=0.0, reformulate=False, device='cpu'):
"""
FMix augmentation arXiv:2002.12047
        :param shuffle_p: chance of triggering the augmentation
:param class_num: number of classification categories
:param batch_size: batch_size of training
:param decay_power: decay_power
:param alpha: alpha
        :param size: spatial size of the sampled mask (should match the input image size)
:param max_soft: max_soft
:param reformulate: reformulate
:param device: CUDA or CPU
"""
# ori batch_size=128
super().__init__(decay_power, alpha, size, max_soft, reformulate)
self.class_num = class_num
self.batch_size = batch_size
self.p = shuffle_p
self.device = torch.device(device)
def __call__(self, inputs, labels, alpha=1, act=True):
# Sample mask and generate random permutation
lam, mask = sample_mask(self.alpha, self.decay_power, self.size, self.max_soft, self.reformulate)
mask = torch.from_numpy(mask).float().to(self.device)
labels = torch.eye(self.class_num).to(self.device)[labels, :] # one-hot hard label
ori_inputs = inputs.clone().to(self.device)
fmix_inputs = inputs.clone().to(self.device) # duplicate inputs for outputs
lam_list = [] # a list to record operating ratio
indices = torch.randperm(self.batch_size, device=self.device) # shuffle indices
shuffled_inputs = inputs[indices].to(self.device)
shuffled_labels = labels[indices].to(self.device)
for i in range(self.batch_size):
if np.random.randint(0, 101) > 100 * self.p or (not act):
                # skip the augmentation for this sample (chance 1 - p, or act=False)
lam_list.append(-1)
continue
x1 = mask * ori_inputs[i]
x2 = (1 - mask) * shuffled_inputs[i]
fmix_inputs[i] = x1 + x2
lam_list.append(lam)
labels[i] = labels[i] * lam + shuffled_labels[i] * (1 - lam)
long_label = labels.argmax(dim=1)
# print('lam:', lam)
return fmix_inputs, labels, long_label
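# Note: FMix samples a single mask for the whole batch at resolution `size`, so `size` must match the input
# spatial resolution (the factory below passes size=(edge_size, edge_size)); the `alpha` argument of __call__
# is unused here, only the constructor's alpha feeds sample_mask.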
# CellMix
class CellMix(object):
def __init__(self, shuffle_p=1.0, class_num=2, strategy='In-place', group_shuffle_size=-1, device='cpu'):
"""
CellMix augmentation arXiv:2301.11513
        :param shuffle_p: chance of triggering the augmentation
        :param class_num: number of classification categories
        :param strategy: 'In-place' or 'Random', how to shuffle the relation patches within the batch
        :param group_shuffle_size: size of each shuffling group within the batch, -1 for the whole batch
:param device: CUDA or CPU
"""
self.p = shuffle_p
self.CLS = class_num # classification category number of the task
self.device = device
self.strategy = strategy # 'In-place' or 'Random'
self.group_shuffle_size = group_shuffle_size # -1 for whole batch
def __call__(self, inputs, labels, fix_position_ratio=0.5, puzzle_patch_size=32, act=True):
"""
        Fixed-position in-place shuffling.
        A cross-sample random selection fixes some patches in every image of the batch.
        The fixed patches are kept in place; the remaining patches are shuffled in-place across the batch
        and then regrouped with the fixed patches.
        Cross-sample selection: argsort random noise along dim 1, applied identically to every image in the batch.
        In-place batch-wise shuffle: argsort random noise along dim 0.
        Grouped in-place batch-wise shuffle: argsort random noise along the batch dimension within each group.
        :param inputs: input image tensor, size of [B, 3, H, W]
        :param labels: hard labels, size of [B]
        :param fix_position_ratio: float, ratio of patches that remain fixed in place
        :param puzzle_patch_size: int, patch size used for shuffling
        :param act: False forces CellMix off (e.g. in validation); True triggers it with chance p
        output: x, soft_label, long_label
            x : [B, 3, H, W] re-grouped image after CellMix augmentation
            soft_label : [B, CLS] soft label of the class distribution
            long_label : [B] hard long label for general use
"""
if np.random.randint(0, 101) > 100 * self.p or (not act):
soft_label = torch.eye(self.CLS).to(self.device)[labels, :] # one-hot hard label
return inputs, soft_label, labels
# Break img into puzzle patches with the size of puzzle_patch_size [B, num_patches, D]
inputs = patchify(inputs, puzzle_patch_size)
B, num_patches, D = inputs.shape
        # generate the pseudo mask: along the class dim, each patch carries a one-hot vector of its sample's label
mask = torch.zeros([B, num_patches, self.CLS], device=inputs.device, requires_grad=False) # no grad
        # patch mask shape: (B, num_patches, CLS); the CLS vector of every patch is one-hot on the sample's label
        # turn the hard labels into this pseudo mask
B_idx = range(B)
mask[B_idx, :, labels] = 1
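        # illustrative example: with B = 2 and labels = tensor([1, 0]) in a 2-class task, every patch row of
        # sample 0 becomes [0, 1] and every patch row of sample 1 becomes [1, 0] after this assignment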
# num of fix_position puzzle patches
len_fix_position = int(num_patches * fix_position_ratio)
# create a noise tensor to prepare shuffle idx of puzzle patches
noise = torch.rand(1, num_patches, device=self.device)
noise = torch.repeat_interleave(noise, repeats=B, dim=0)
        # argsort the noise to obtain a shuffle index matrix over the patch positions
        ids_shuffle = torch.argsort(noise, dim=1)  # ascending: the leading indices are kept fixed, the rest are puzzled
        # argsort again to recover the indices that restore the original patch order
        ids_restore = torch.argsort(ids_shuffle, dim=1)
# keep the first subset
ids_fix = ids_shuffle[:, :len_fix_position] # [B,num_patches] -> [B,fix_patches]
ids_puzzle = ids_shuffle[:, len_fix_position:] # [B,num_patches] -> [B,puzzle_patches]
# set puzzle patch
# ids_?.unsqueeze(-1).repeat(1, 1, D)
# [B,?_patches] -> [B,?_patches,1] (at each place with the idx of ori patch) -> [B,?_patches,D]
        # torch.gather selects the patch groups: x_fixed of [B,fix_patches,D] and x_puzzle of [B,puzzle_patches,D]
x_fixed = torch.gather(inputs, dim=1, index=ids_fix.unsqueeze(-1).repeat(1, 1, D))
x_puzzle = torch.gather(inputs, dim=1, index=ids_puzzle.unsqueeze(-1).repeat(1, 1, D))
mask_fixed = torch.gather(mask, dim=1, index=ids_fix.unsqueeze(-1).repeat(1, 1, self.CLS))
mask_puzzle = torch.gather(mask, dim=1, index=ids_puzzle.unsqueeze(-1).repeat(1, 1, self.CLS))
if self.strategy == 'In-place' or self.strategy == 'Random':
            # shuffle the selected relation patches in place (same patch location) across the batch dimension
B, num_shuffle_patches, D = x_puzzle.shape
# create a noise tensor to prepare shuffle idx of puzzle patches
# [B, num_shuffle_patches] noise in [0, 1]
noise = torch.rand(B, num_shuffle_patches, device=self.device)
if self.group_shuffle_size == -1 or self.group_shuffle_size == B: # CellMix-Split
# sort the noise matrix, obtain a index assignment for shuffle,
# shuffle dim 0 of entire noise (among all the batch)
in_place_shuffle_indices = torch.argsort(noise, dim=0)
else: # CellMix-Group
assert B > self.group_shuffle_size > 0 and B % self.group_shuffle_size == 0
grouped_indices_list = []
for group_idx in range(B // self.group_shuffle_size):
# group the noise by self.group_shuffle_size: [group_shuffle_size,N]
grouped_noise = noise[group_idx * self.group_shuffle_size:
group_idx * self.group_shuffle_size + self.group_shuffle_size, :]
# sort each grouped_noise matrix, obtain a index assignment for shuffle,
# now the shuffle dim is 0 (among the batch within the group)
grouped_indices = torch.argsort(grouped_noise, dim=0)
# put grouped_noise matrix into the list
grouped_indices_list.append(grouped_indices + self.group_shuffle_size * group_idx)
# stack(cat) the group indices(from list) back to tensor
in_place_shuffle_indices = torch.cat(grouped_indices_list, dim=0)
            # torch.gather performs the shuffle (gathering along dim 0 with the shuffled indices)
x_puzzle = torch.gather(x_puzzle, dim=0, index=in_place_shuffle_indices.unsqueeze(-1).repeat(1, 1, D))
mask_puzzle = torch.gather(mask_puzzle, dim=0,
index=in_place_shuffle_indices.unsqueeze(-1).repeat(1, 1, self.CLS))
else:
print('not a valid CellMix strategy')
# pack up all puzzle patches
inputs = torch.cat([x_fixed, x_puzzle], dim=1)
mask = torch.cat([mask_fixed, mask_puzzle], dim=1)
# unshuffle to restore the fixed positions
inputs = torch.gather(inputs, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, D))
# torch.gather to generate restored binary mask
mask = torch.gather(mask, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, self.CLS))
        # the Random strategy additionally shuffles the patch positions within each image (after the in-place batch shuffle)
if self.strategy == 'Random':
B, num_patches, D = inputs.shape
# create a noise tensor to prepare shuffle idx of puzzle patches
            noise = torch.rand(B, num_patches, device=self.device)  # [B, num_patches] noise in [0, 1]
            # sort the noise matrix to obtain a shuffle index assignment; the shuffle dim is now 1 (within each sample)
all_shuffle_indices = torch.argsort(noise, dim=1)
            # all_shuffle_indices has shape [B, N]; each row is a permutation of the patch indices
# torch.gather to shuffle
inputs = torch.gather(inputs, dim=1, index=all_shuffle_indices.unsqueeze(-1).repeat(1, 1, D))
            # no need to shuffle the mask: this is a patch-wise shuffle within each sample, so the class mix per sample is unchanged
else: # when strategy == 'In-place'
pass
# unpatchify to obtain puzzle images and their mask
inputs = unpatchify(inputs, puzzle_patch_size) # restore to image size:B,3,224,224/ B,3,384,384
        # transform the soft mask into a soft label
        # calculate the composed label as the per-class share of patches
        # [B, num_patches, CLS] -> (B, CLS)
soft_label = mask.sum(dim=1) # (B, CLS)
soft_label = soft_label / num_patches
        # hard long label (argmax of the soft label), kept for interfaces that expect a hard label
long_label = soft_label.argmax(dim=1)
return inputs, soft_label, long_label
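# Worked example of the CellMix soft label (illustrative numbers): a 224x224 image with puzzle_patch_size=32
# gives 7 x 7 = 49 patches; with fix_position_ratio=0.5, int(49 * 0.5) = 24 patches stay fixed. If all 25
# shuffled-in patches come from a different-class sample, soft_label for that image is [24/49, 25/49] in a
# 2-class task, and long_label is its argmax.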
# factory: return the requested online augmentation object
def get_online_augmentation(augmentation_name, p=0.5, class_num=2, batch_size=4, edge_size=224, device='cpu'):
"""
:param augmentation_name: name of data-augmentation method
:param p: chance of triggering
:param class_num: classification task num
:param batch_size: batch size
:param edge_size: edge size of img
:param device: cpu or cuda
    augmentation_name, class_num, batch_size, and edge_size must be provided.
"""
if augmentation_name == 'CellMix-Group':
Augmentation = CellMix(shuffle_p=p, class_num=class_num, strategy='In-place', group_shuffle_size=2,
device=device)
return Augmentation
elif augmentation_name == 'CellMix-Group4':
Augmentation = CellMix(shuffle_p=p, class_num=class_num, strategy='In-place', group_shuffle_size=4,
device=device)
return Augmentation
elif augmentation_name == 'CellMix-Split':
Augmentation = CellMix(shuffle_p=p, class_num=class_num, strategy='In-place', group_shuffle_size=-1,
device=device)
return Augmentation
elif augmentation_name == 'CellMix-Random':
Augmentation = CellMix(shuffle_p=p, class_num=class_num, strategy='Random', group_shuffle_size=2,
device=device)
return Augmentation
elif augmentation_name == 'CellMix-Random4':
Augmentation = CellMix(shuffle_p=p, class_num=class_num, strategy='Random', group_shuffle_size=4,
device=device)
return Augmentation
elif augmentation_name == 'CellMix-Self':
Augmentation = CellMix(shuffle_p=p, class_num=class_num, strategy='Random', group_shuffle_size=1,
device=device)
return Augmentation
elif augmentation_name == 'CellMix-All':
Augmentation = CellMix(shuffle_p=p, class_num=class_num, strategy='Random', group_shuffle_size=-1,
device=device)
return Augmentation
elif augmentation_name == 'Cutout':
Augmentation = Cutout(alpha=2, shuffle_p=p, class_num=class_num, batch_size=batch_size, device=device)
return Augmentation
elif augmentation_name == 'CutMix':
Augmentation = CutMix(alpha=2, shuffle_p=p, class_num=class_num, batch_size=batch_size, device=device)
return Augmentation
elif augmentation_name == 'Mixup':
Augmentation = Mixup(alpha=2, shuffle_p=p, class_num=class_num, batch_size=batch_size, device=device)
return Augmentation
elif augmentation_name == 'SaliencyMix':
Augmentation = SaliencyMix(alpha=1, shuffle_p=p, class_num=class_num, batch_size=batch_size,
                                   device=device)  # this alpha corresponds to beta in the original SaliencyMix code
return Augmentation
elif augmentation_name == 'ResizeMix':
Augmentation = ResizeMix(shuffle_p=p, class_num=class_num, batch_size=batch_size, device=device)
return Augmentation
elif augmentation_name == 'FMix':
        # FMix uses p=1.0 because the trigger chance is handled inside its own design
Augmentation = FMix(shuffle_p=1.0, class_num=class_num, batch_size=batch_size,
size=(edge_size, edge_size), device=device)
return Augmentation
elif augmentation_name == 'PuzzleMix':
return None
        # fixme: all related parts have been taken out separately
# Augmentation = PuzzleMix(alpha=2, shuffle_p=p, class_num=class_num, batch_size=batch_size, device=device)
# return Augmentation
elif augmentation_name == 'CoMix':
# TODO CoMix
return None
elif augmentation_name == 'RandomMix':
# TODO RandomMix
return None
else:
        print('no valid counterpart augmentation selected')
return None
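# Typical training-loop integration (a sketch; `dataloader`, `model`, and `logits` are assumed names,
# and the loader batch size must equal batch_size):
#   Augmentation = get_online_augmentation('CutMix', p=0.5, class_num=2, batch_size=4, device='cuda')
#   for imgs, label in dataloader:
#       imgs, soft_labels, long_labels = Augmentation(imgs.to('cuda'), label.to('cuda'), act=True)
#       logits = model(imgs)
#       loss = -(soft_labels * torch.log_softmax(logits, dim=1)).sum(dim=1).mean()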
if __name__ == '__main__':
'''
Augmentation = get_online_augmentation('CellMix-Split', p=0.5, class_num=2)
output, labels, GT_labels = Augmentation(x, label, fix_position_ratio=0.5, puzzle_patch_size=32, act=True)
print(labels, GT_labels)
'''
x = torch.load("./temp-tensors/warwick.pt")
# print(x.shape)
label = torch.load("./temp-tensors/warwick_labels.pt")
# print(label)
# Augmentation = get_online_augmentation('ResizeMix', p=0.5, class_num=2)
# output, labels, GT_labels = Augmentation(x, label, act=True)
Augmentation = get_online_augmentation('CellMix-Group', p=1, class_num=2)
output, labels, GT_labels = Augmentation(x, label, fix_position_ratio=0.5, puzzle_patch_size=32, act=True)
print(labels, GT_labels)
composed_img = ToPILImage()(output[0])
composed_img.show()
composed_img = ToPILImage()(output[1])
composed_img.show()
composed_img = ToPILImage()(output[2])
composed_img.show()
composed_img = ToPILImage()(output[3])
composed_img.show()