import copy
import warnings
from typing import Any, Optional, Union

import torch.utils.data
from torch.utils.data.distributed import DistributedSampler

from src.efficientvit.apps.data_provider.random_resolution import RRSController
from src.efficientvit.models.utils import val2tuple

__all__ = ["parse_image_size", "random_drop_data", "DataProvider"]


def parse_image_size(size: Union[int, str]) -> tuple[int, int]:
    if isinstance(size, str):
        # string sizes are encoded as "H-W", e.g. "192-256" -> (192, 256)
        size = [int(val) for val in size.split("-")]
        return size[0], size[1]
    else:
        return val2tuple(size, 2)
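
# Usage sketch (illustrative, not part of the module API):
#
#     parse_image_size("192-256")  # -> (192, 256)
#     parse_image_size(256)        # -> (256, 256), expanded by val2tuple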


def random_drop_data(dataset, drop_size: int, seed: int, keys=("samples",)):
    # draw a deterministic permutation so the same split is reproduced
    # across runs with the same seed
    g = torch.Generator()
    g.manual_seed(seed)
    rand_indexes = torch.randperm(len(dataset), generator=g).tolist()

    dropped_indexes = rand_indexes[:drop_size]
    remaining_indexes = rand_indexes[drop_size:]

    # move the dropped samples into a deep copy of the dataset and keep
    # the remaining samples in the original one
    dropped_dataset = copy.deepcopy(dataset)
    for key in keys:
        setattr(
            dropped_dataset,
            key,
            [getattr(dropped_dataset, key)[idx] for idx in dropped_indexes],
        )
        setattr(dataset, key, [getattr(dataset, key)[idx] for idx in remaining_indexes])
    return dataset, dropped_dataset
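
# Call-pattern sketch (hypothetical dataset with a `samples` list attribute):
#
#     train_set, held_out = random_drop_data(train_set, 5000, seed, ("samples",))
#
# `train_set` keeps len(train_set) - 5000 samples; `held_out` gets the other 5000.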


class DataProvider:
    data_keys = ("samples",)
    mean_std = {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]}
    SUB_SEED = 937162211  # random seed for sampling subsets
    VALID_SEED = 2147483647  # random seed for the validation split

    name: str

    def __init__(
        self,
        train_batch_size: int,
        test_batch_size: Optional[int],
        valid_size: Optional[Union[int, float]],
        n_worker: int,
        image_size: Union[int, list[int], str, list[str]],
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        train_ratio: Optional[float] = None,
        drop_last: bool = False,
    ):
        warnings.filterwarnings("ignore")
        super().__init__()

        # batch sizes & valid size
        self.train_batch_size = train_batch_size
        self.test_batch_size = test_batch_size or self.train_batch_size
        self.valid_size = valid_size

        # image size
        if isinstance(image_size, list):
            # multi-resolution training: sort the candidate sizes and register
            # them with the random-resolution controller
            self.image_size = [parse_image_size(size) for size in image_size]
            self.image_size.sort()  # e.g., 160 -> 224
            RRSController.IMAGE_SIZE_LIST = copy.deepcopy(self.image_size)
            self.active_image_size = RRSController.ACTIVE_SIZE = self.image_size[-1]
        else:
            self.image_size = parse_image_size(image_size)
            RRSController.IMAGE_SIZE_LIST = [self.image_size]
            self.active_image_size = RRSController.ACTIVE_SIZE = self.image_size

        # distributed configs
        self.num_replicas = num_replicas
        self.rank = rank

        # build datasets
        train_dataset, val_dataset, test_dataset = self.build_datasets()

        if train_ratio is not None and train_ratio < 1.0:
            assert 0 < train_ratio < 1
            _, train_dataset = random_drop_data(
                train_dataset,
                int(train_ratio * len(train_dataset)),
                self.SUB_SEED,
                self.data_keys,
            )

        # build data loaders; fall back to the test set when no validation
        # split is available
        self.train = self.build_dataloader(
            train_dataset, train_batch_size, n_worker, drop_last=drop_last, train=True
        )
        self.valid = self.build_dataloader(
            val_dataset, self.test_batch_size, n_worker, drop_last=False, train=False
        )
        self.test = self.build_dataloader(
            test_dataset, self.test_batch_size, n_worker, drop_last=False, train=False
        )
        if self.valid is None:
            self.valid = self.test
        self.sub_train = None
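
    # Construction sketch for a hypothetical subclass (illustrative only):
    #
    #     provider = MyDataProvider(
    #         train_batch_size=128,
    #         test_batch_size=128,
    #         valid_size=0.1,  # hold out 10% of the training set
    #         n_worker=8,
    #         image_size=["160-160", "192-192", "224-224"],  # multi-resolution
    #     )
    #     for images, labels in provider.train:
    #         ...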

    @property
    def data_shape(self) -> tuple[int, ...]:
        return 3, self.active_image_size[0], self.active_image_size[1]

    def build_valid_transform(self, image_size: Optional[tuple[int, int]] = None) -> Any:
        raise NotImplementedError

    def build_train_transform(self, image_size: Optional[tuple[int, int]] = None) -> Any:
        raise NotImplementedError

    def build_datasets(self) -> tuple[Any, Any, Any]:
        raise NotImplementedError
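
    # Subclasses implement the three builders above; build_datasets returns
    # (train_dataset, val_dataset, test_dataset), where val_dataset and
    # test_dataset may be None. A minimal sketch with hypothetical names:
    #
    #     class MyDataProvider(DataProvider):
    #         def build_datasets(self):
    #             train = MyDataset("train", transform=self.build_train_transform())
    #             test = MyDataset("test", transform=self.build_valid_transform())
    #             train, val = self.sample_val_dataset(train, self.build_valid_transform())
    #             return train, val, test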

    def build_dataloader(
        self,
        dataset: Optional[Any],
        batch_size: int,
        n_worker: int,
        drop_last: bool,
        train: bool,
    ):
        if dataset is None:
            return None
        if isinstance(self.image_size, list) and train:
            # multi-resolution training requires the random-resolution loader
            from src.efficientvit.apps.data_provider.random_resolution._data_loader import (
                RRSDataLoader,
            )

            dataloader_class = RRSDataLoader
        else:
            dataloader_class = torch.utils.data.DataLoader
        if self.num_replicas is None:
            return dataloader_class(
                dataset=dataset,
                batch_size=batch_size,
                shuffle=True,
                num_workers=n_worker,
                pin_memory=True,
                drop_last=drop_last,
            )
        else:
            sampler = DistributedSampler(dataset, self.num_replicas, self.rank)
            return dataloader_class(
                dataset=dataset,
                batch_size=batch_size,
                sampler=sampler,
                num_workers=n_worker,
                pin_memory=True,
                drop_last=drop_last,
            )

    def set_epoch(self, epoch: int) -> None:
        # re-seed the random-resolution schedule for this epoch
        RRSController.set_epoch(epoch, len(self.train))
        if isinstance(self.train.sampler, DistributedSampler):
            self.train.sampler.set_epoch(epoch)

    def assign_active_image_size(self, new_size: Union[int, tuple[int, int]]) -> None:
        self.active_image_size = val2tuple(new_size, 2)
        new_transform = self.build_valid_transform(self.active_image_size)
        # change the transform of the valid and test sets
        self.valid.dataset.transform = self.test.dataset.transform = new_transform
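
    # Typical training-loop usage (illustrative):
    #
    #     for epoch in range(num_epochs):
    #         provider.set_epoch(epoch)  # keeps RRS schedule & sampler in sync
    #         for images, labels in provider.train:
    #             ...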

    def sample_val_dataset(self, train_dataset, valid_transform) -> tuple[Any, Any]:
        if self.valid_size is not None:
            if 0 < self.valid_size < 1:
                # a fractional valid_size is interpreted as a ratio of the train set
                valid_size = int(self.valid_size * len(train_dataset))
            else:
                assert self.valid_size >= 1
                valid_size = int(self.valid_size)
            train_dataset, val_dataset = random_drop_data(
                train_dataset,
                valid_size,
                self.VALID_SEED,
                self.data_keys,
            )
            val_dataset.transform = valid_transform
        else:
            val_dataset = None
        return train_dataset, val_dataset

    def build_sub_train_loader(self, n_samples: int, batch_size: int) -> Any:
        # cache one list of batches per active image size
        if self.sub_train is None:
            self.sub_train = {}
        if self.active_image_size in self.sub_train:
            return self.sub_train[self.active_image_size]

        # construct a subset of the training set at the active resolution
        train_dataset = copy.deepcopy(self.train.dataset)
        if n_samples < len(train_dataset):
            _, train_dataset = random_drop_data(
                train_dataset,
                n_samples,
                self.SUB_SEED,
                self.data_keys,
            )
        RRSController.ACTIVE_SIZE = self.active_image_size
        train_dataset.transform = self.build_train_transform(
            image_size=self.active_image_size
        )
        data_loader = self.build_dataloader(
            train_dataset, batch_size, self.train.num_workers, drop_last=True, train=False
        )

        # pre-fetch the batches (repeating each one when n_samples exceeds the
        # dataset size) so that later passes over the subset are cheap
        self.sub_train[self.active_image_size] = [
            data
            for data in data_loader
            for _ in range(max(1, n_samples // len(train_dataset)))
        ]

        return self.sub_train[self.active_image_size]
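
    # e.g. (illustrative): fetch ~2000 training samples at the active
    # resolution for a quick calibration pass:
    #
    #     sub_loader = provider.build_sub_train_loader(n_samples=2000, batch_size=128)
    #     for images, labels in sub_loader:
    #         ...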