# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

import copy
import math
import os
from typing import Any, Optional, Union

import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder

from efficientvit.apps.data_provider import DataProvider
from efficientvit.apps.data_provider.augment import RandAug
from efficientvit.apps.data_provider.random_resolution import MyRandomResizedCrop, get_interpolate
from efficientvit.apps.utils import partial_update_config
from efficientvit.models.utils import val2list

__all__ = ["ImageNetDataProvider"]


class ImageNetDataProvider(DataProvider):
    """Data provider that reads an ImageNet-style folder layout.

    Training images are loaded with ``ImageFolder`` from ``<data_dir>/train``
    and test images from ``<data_dir>/val``.
    """

    name = "imagenet"

    # Default dataset root; can be overridden per instance via ``data_dir``.
    data_dir = "/dataset/imagenet"
    n_classes = 1000

    # Baseline resize/crop settings; ``rrc_config`` passed to ``__init__`` is
    # applied on top of a deep copy of this dict.
    _DEFAULT_RRC_CONFIG = {
        "train_interpolate": "random",
        "test_interpolate": "bicubic",
        "test_crop_ratio": 1.0,
    }

    def __init__(
        self,
        data_dir: Optional[str] = None,
        rrc_config: Optional[dict] = None,
        data_aug: Optional[Union[dict, list[dict]]] = None,
        # --- arguments below are forwarded unchanged to DataProvider.__init__ ---
        train_batch_size=128,
        test_batch_size=128,
        valid_size: Optional[Union[int, float]] = None,
        n_worker=8,
        image_size: Union[int, list[int]] = 224,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        train_ratio: Optional[float] = None,
        drop_last: bool = False,
    ):
        """Store the ImageNet-specific settings, then initialise the base provider.

        Args:
            data_dir: Dataset root. Falls back to the class-level ``data_dir``
                when ``None`` (or any other falsy value) is given.
            rrc_config: Overrides for ``_DEFAULT_RRC_CONFIG``; merged through
                ``partial_update_config`` onto a deep copy of the defaults so
                the class-level dict is never mutated.
            data_aug: One augmentation spec or a list of them. Each spec is a
                dict with a ``"name"`` key (``"randaug"`` or ``"erase"``);
                consumed by ``build_train_transform``.
            train_batch_size, test_batch_size, valid_size, n_worker, image_size,
            num_replicas, rank, train_ratio, drop_last: Passed positionally, in
                this order, to ``DataProvider.__init__``.
        """
        self.data_dir = data_dir or self.data_dir
        self.rrc_config = partial_update_config(
            copy.deepcopy(self._DEFAULT_RRC_CONFIG),
            rrc_config or {},
        )
        self.data_aug = data_aug

        # The instance attributes above are set first because the base
        # initialiser may already need them (NOTE(review): presumably it
        # triggers build_datasets(); confirm against DataProvider).
        super().__init__(
            train_batch_size,
            test_batch_size,
            valid_size,
            n_worker,
            image_size,
            num_replicas,
            rank,
            train_ratio,
            drop_last,
        )

    def build_valid_transform(self, image_size: Optional[tuple[int, int]] = None) -> transforms.Compose:
        """Build the evaluation pipeline: resize, center-crop, to-tensor, normalise.

        Args:
            image_size: Target size; only its first element is used. When
                falsy, ``self.active_image_size`` is used instead.

        Returns:
            A ``transforms.Compose`` that resizes to
            ``ceil(size / test_crop_ratio)``, center-crops to ``size``,
            converts to a tensor and normalises with ``self.mean_std``.
        """
        side = (image_size or self.active_image_size)[0]
        # Resize to a (possibly) larger size first so that the center crop
        # keeps ``test_crop_ratio`` of the resized image.
        resize_to = int(math.ceil(side / self.rrc_config["test_crop_ratio"]))
        interpolation = get_interpolate(self.rrc_config["test_interpolate"])
        return transforms.Compose(
            [
                transforms.Resize(resize_to, interpolation=interpolation),
                transforms.CenterCrop(side),
                transforms.ToTensor(),
                transforms.Normalize(**self.mean_std),
            ]
        )

    def build_train_transform(self, image_size: Optional[tuple[int, int]] = None) -> transforms.Compose:
        """Build the training pipeline.

        Order of operations: random-resized-crop, random horizontal flip, any
        ``"randaug"`` ops, to-tensor, normalise, then any ``"erase"`` ops.

        Args:
            image_size: Accepted for interface compatibility; not used here,
                because ``MyRandomResizedCrop`` is constructed without a size
                argument (NOTE(review): presumably it obtains the active
                resolution on its own; confirm in random_resolution).

        Returns:
            A ``transforms.Compose`` with the steps listed above.

        Raises:
            NotImplementedError: If an augmentation spec has a ``"name"``
                other than ``"randaug"`` or ``"erase"``.
        """
        # Augmentations that run before ToTensor().
        pre_tensor_ops = [
            MyRandomResizedCrop(interpolation=self.rrc_config["train_interpolate"]),
            transforms.RandomHorizontalFlip(),
        ]
        # Augmentations that run after Normalize().
        post_tensor_ops = []

        if self.data_aug is not None:
            for aug_op in val2list(self.data_aug):
                aug_name = aug_op["name"]
                if aug_name == "randaug":
                    pre_tensor_ops.append(RandAug(aug_op, mean=self.mean_std["mean"]))
                elif aug_name == "erase":
                    # Imported lazily so timm is only required when erasing is requested.
                    from timm.data.random_erasing import RandomErasing

                    post_tensor_ops.append(RandomErasing(aug_op["p"], device="cpu"))
                else:
                    raise NotImplementedError(f"unsupported data augmentation: {aug_name!r}")

        return transforms.Compose(
            [
                *pre_tensor_ops,
                transforms.ToTensor(),
                transforms.Normalize(**self.mean_std),
                *post_tensor_ops,
            ]
        )

    def build_datasets(self) -> tuple[Any, Any, Any]:
        """Create the train, validation and test datasets.

        The test set is ``<data_dir>/val`` with the validation transform. The
        validation set is produced by ``self.sample_val_dataset`` from the
        training set, which also returns the training set to use.

        Returns:
            ``(train_dataset, val_dataset, test_dataset)``.
        """
        train_transform = self.build_train_transform()
        valid_transform = self.build_valid_transform()

        train_dataset = ImageFolder(os.path.join(self.data_dir, "train"), train_transform)
        test_dataset = ImageFolder(os.path.join(self.data_dir, "val"), valid_transform)

        # sample_val_dataset returns the (possibly reduced) training set
        # together with the held-out validation set.
        train_dataset, val_dataset = self.sample_val_dataset(train_dataset, valid_transform)
        return train_dataset, val_dataset, test_dataset