# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

import torch.distributions

from efficientvit.apps.data_provider.augment import rand_bbox
from efficientvit.models.utils.random import torch_randint, torch_shuffle

__all__ = ["apply_mixup", "mixup", "cutmix"]


def apply_mixup(
    images: torch.Tensor,
    labels: torch.Tensor,
    lam: float,
    mix_type="mixup",
) -> tuple[torch.Tensor, torch.Tensor]:
    # Dispatch to the requested mixing strategy.
    if mix_type == "mixup":
        return mixup(images, labels, lam)
    elif mix_type == "cutmix":
        return cutmix(images, labels, lam)
    else:
        raise NotImplementedError


def mixup(
    images: torch.Tensor,
    target: torch.Tensor,
    lam: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Pair each sample with a randomly chosen partner from the same batch and
    # blend both the images and the targets with coefficient `lam`.
    rand_index = torch_shuffle(list(range(0, images.shape[0])))
    flipped_images = images[rand_index]
    flipped_target = target[rand_index]
    return (
        lam * images + (1 - lam) * flipped_images,
        lam * target + (1 - lam) * flipped_target,
    )


def cutmix(
    images: torch.Tensor,
    target: torch.Tensor,
    lam: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Pair each sample with a randomly chosen partner from the same batch.
    rand_index = torch_shuffle(list(range(0, images.shape[0])))
    flipped_images = images[rand_index]
    flipped_target = target[rand_index]

    b, _, h, w = images.shape
    lam_list = []
    for i in range(b):
        # Sample a bounding box whose area is roughly (1 - lam) of the image,
        # then paste the partner's patch into that region in place.
        bbx1, bby1, bbx2, bby2 = rand_bbox(
            h=h,
            w=w,
            lam=lam,
            rand_func=torch_randint,
        )
        images[i, :, bby1:bby2, bbx1:bbx2] = flipped_images[i, :, bby1:bby2, bbx1:bbx2]
        # Recompute lam per sample from the actual box area.
        lam_list.append(1 - ((bbx2 - bbx1) * (bby2 - bby1) / (h * w)))
    lam = torch.Tensor(lam_list).to(images.device).view(b, 1)
    return images, lam * target + (1 - lam) * flipped_target
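

# A minimal usage sketch, assuming `lam` is drawn from a Beta(alpha, alpha)
# distribution as in standard mixup/cutmix; the batch shapes, alpha value, and
# one-hot labels below are illustrative, not taken from the EfficientViT
# training configuration.
if __name__ == "__main__":
    alpha = 0.2
    lam = float(torch.distributions.Beta(alpha, alpha).sample())
    images = torch.randn(8, 3, 224, 224)
    labels = torch.nn.functional.one_hot(torch.randint(0, 10, (8,)), num_classes=10).float()
    mixed_images, mixed_labels = apply_mixup(images, labels, lam, mix_type="cutmix")
    print(mixed_images.shape, mixed_labels.shape)  # torch.Size([8, 3, 224, 224]) torch.Size([8, 10])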