|
""" Datasets for core experimental results """ |
|
from functools import partial |
|
from pathlib import Path |
|
import torch |
|
import torchaudio.functional as TF |
|
import torchvision |
|
from einops import rearrange |
|
|
|
from ..utils.util import is_list |
|
|
|
|
|
def deprecated(cls_or_func):
    """Wrap a class or function so each call first prints a deprecation notice.

    Args:
        cls_or_func: the class or callable being deprecated.

    Returns:
        A wrapper that prints ``"<cls_or_func> is deprecated"`` and then calls
        ``cls_or_func`` with all given positional and keyword arguments,
        returning its result. The wrapper carries the wrapped object's
        ``__name__``, ``__qualname__``, ``__doc__`` and ``__module__`` and
        exposes it as ``__wrapped__``.
    """
    from functools import wraps

    # updated=() skips merging the wrapped object's __dict__ into the wrapper;
    # for a decorated class that dict holds methods/descriptors, which do not
    # belong on a plain function.
    @wraps(cls_or_func, updated=())
    def _deprecated(*args, **kwargs):
        print(f"{cls_or_func} is deprecated")
        return cls_or_func(*args, **kwargs)

    return _deprecated
|
|
|
|
|
|
|
# Default root for raw datasets: a "raw_data" directory under the absolute path
# reached by three `.parent` steps from this file.
default_data_path = Path(__file__).parent.parent.parent.absolute() / "raw_data"
|
|
|
|
|
class DefaultCollateMixin:
    """Controls collating in the DataLoader.

    The CollateMixin classes instantiate a dataloader by separating collate arguments from the rest of the
    dataloader arguments. Subclasses should override the callback functions as desired and set the
    ``collate_args`` list. The class then provides a ``_dataloader()`` method which takes a dataset and
    loader arguments, builds a ``collate_fn`` bound to the collate arguments, and passes the remaining
    arguments to the loader constructor.
    """

    @classmethod
    def _collate_callback(cls, x, *args, **kwargs):
        """
        Modify the behavior of the default _collate method.

        Called on every stacked tensor produced by ``_collate``; the default is the identity.
        """
        return x

    # Names for the auxiliary items a dataset returns beyond (x, y); consumed by _return_callback.
    _collate_arg_names = []

    @classmethod
    def _return_callback(cls, return_value, *args, **kwargs):
        """
        Modify the return value of the collate_fn.

        Splits ``return_value`` into ``x, y, *z`` and returns ``(x, y, {name: item})``, assigning a
        name from ``cls._collate_arg_names`` to each element beyond the (x, y) pair.
        See InformerSequenceDataset for an example of this being used.

        Raises:
            AssertionError: if the number of auxiliary items differs from ``len(cls._collate_arg_names)``.
        """
        x, y, *z = return_value
        assert len(z) == len(cls._collate_arg_names), "Specify a name for each auxiliary data item returned by dataset"
        return x, y, {k: v for k, v in zip(cls._collate_arg_names, z)}

    @classmethod
    def _collate(cls, batch, *args, **kwargs):
        """
        Combine a sequence of per-example values into one batch.

        Tensors are stacked along a new dim 0 and then passed through ``_collate_callback``
        together with ``*args, **kwargs``. Any other values are converted with
        ``torch.tensor(batch)``; the callback is not applied to them.
        """
        elem = batch[0]
        if isinstance(elem, torch.Tensor):
            out = None
            if torch.utils.data.get_worker_info() is not None:
                # Inside a DataLoader worker: stack straight into shared memory so the batch
                # does not need an extra copy when it is handed to the main process.
                numel = sum(x.numel() for x in batch)
                storage = elem.storage()._new_shared(numel)
                # Give `out` its final (len(batch), *elem.shape) shape up front. Otherwise
                # torch.stack must resize a non-empty 1-D `out`, which PyTorch reports with a
                # deprecation warning on every batch.
                out = elem.new(storage).resize_(len(batch), *elem.size())
            x = torch.stack(batch, dim=0, out=out)
            return cls._collate_callback(x, *args, **kwargs)
        else:
            return torch.tensor(batch)

    @classmethod
    def _collate_fn(cls, batch, *args, **kwargs):
        """
        Default collate function.

        Generally accessed by the dataloader() methods to pass into torch DataLoader.

        Arguments:
            batch: list of per-example tuples ``(x, y, *extras)``
            args, kwargs: extra arguments passed into _collate_callback (for x only) and
                _return_callback

        Returns:
            Whatever ``_return_callback`` builds from the collated ``(x, y, *extras)``.
        """
        x, y, *z = zip(*batch)

        x = cls._collate(x, *args, **kwargs)
        # y and the extras are collated without the collate arguments.
        y = cls._collate(y)
        z = [cls._collate(z_) for z_ in z]

        return_value = (x, y, *z)
        return cls._return_callback(return_value, *args, **kwargs)

    # Names of loader arguments that are routed into the collate function instead of the loader.
    collate_args = []

    def _dataloader(self, dataset, **loader_args):
        """
        Build a loader for ``dataset``.

        Keys of ``loader_args`` listed in ``self.collate_args`` are bound into ``collate_fn``. An
        optional ``"_name_"`` key selects the loader class from the module-level ``loader_registry``
        (missing means the ``None`` entry). Everything else is passed to the loader constructor.
        """
        collate_args = {k: loader_args[k] for k in loader_args if k in self.collate_args}
        loader_args = {k: loader_args[k] for k in loader_args if k not in self.collate_args}
        loader_cls = loader_registry[loader_args.pop("_name_", None)]
        return loader_cls(
            dataset=dataset,
            collate_fn=partial(self._collate_fn, **collate_args),
            **loader_args,
        )
|
|
|
|
|
class SequenceResolutionCollateMixin(DefaultCollateMixin):
    """Collate mixin whose collate_fn(resolution=...) subsamples the sequence axes of each batch."""

    @classmethod
    def _collate_callback(cls, x, resolution=None):
        """Subsample the stacked batch ``x`` according to ``resolution``.

        - ``None``: ``x`` is returned unchanged.
        - a list (per ``is_list``): squeeze the trailing dim, keep every ``resolution[0]``-th step
          along dim 1, then for each further rate ``r`` resample with
          ``torchaudio.functional.resample`` to length ``L // r`` (``L`` = original length), and
          restore the trailing dim. Presumably ``x`` is (B, L, 1) here — TODO confirm.
        - anything else (an integer stride): keep every ``resolution``-th element along each axis
          after the batch axis except the last (just axis 1 when ``x`` is 2-D).
        """
        if resolution is None:
            return x

        if is_list(resolution):
            seq = x.squeeze(-1)
            full_len = seq.size(1)
            seq = seq[:, ::resolution[0]]
            cur_len = full_len // resolution[0]
            for rate in resolution[1:]:
                target_len = full_len // rate
                seq = TF.resample(seq, cur_len, target_len)
                cur_len = target_len
            return seq.unsqueeze(-1)

        assert x.ndim >= 2
        n_axes = max(1, x.ndim - 2)

        # Split each subsampled axis into (l_i, res_i), move every res_i in front of the batch
        # axis, then index 0 on each: equivalent to a stride-`resolution` slice on those axes.
        grouped = " ".join(f"(l{i} res{i})" for i in range(n_axes))
        res_axes = " ".join(f"res{i}" for i in range(n_axes))
        len_axes = " ".join(f"l{i}" for i in range(n_axes))
        pattern = f"b {grouped} ... -> {res_axes} b {len_axes} ..."
        sizes = {f'res{i}': resolution for i in range(n_axes)}
        x = rearrange(x, pattern, **sizes)
        return x[(0,) * n_axes]

    @classmethod
    def _return_callback(cls, return_value, resolution=None):
        """Append ``{"rate": resolution}`` to the collated values."""
        return tuple(return_value) + ({"rate": resolution},)

    collate_args = ['resolution']
|
|
|
|
|
class ImageResolutionCollateMixin(SequenceResolutionCollateMixin):
    """self.collate_fn(resolution, img_size) produces a collate function that resizes inputs to size img_size/resolution"""

    # Interpolation settings passed to torchvision's resize; subclasses may override them.
    _interpolation = torchvision.transforms.InterpolationMode.BILINEAR
    _antialias = True

    @classmethod
    def _collate_callback(cls, x, resolution=None, img_size=None, channels_last=True):
        """
        Resize a stacked image batch so each spatial side is ``round(img_size / resolution)``.

        Falls back to SequenceResolutionCollateMixin's subsampling when ``x`` has fewer than
        4 dims, when ``img_size`` is None, or when ``resolution`` is None. With
        ``resolution=None`` that fallback leaves ``x`` unchanged instead of dividing by None.

        Args:
            x: stacked batch; channels on the last axis when ``channels_last``, else on axis 1.
            resolution: downsampling factor applied to ``img_size``.
            img_size: reference side length; the output side is ``round(img_size / resolution)``.
            channels_last: whether ``x`` stores channels on its last axis.
        """
        if x.ndim < 4 or img_size is None or resolution is None:
            return super()._collate_callback(x, resolution=resolution)

        # torchvision resizes the trailing two axes, so move channels ahead of the spatial axes.
        if channels_last:
            x = rearrange(x, 'b ... c -> b c ...')
        _size = round(img_size / resolution)
        x = torchvision.transforms.functional.resize(
            x,
            size=[_size, _size],
            interpolation=cls._interpolation,
            antialias=cls._antialias,
        )
        if channels_last:
            x = rearrange(x, 'b c ... -> b ... c')
        return x

    @classmethod
    def _return_callback(cls, return_value, resolution=None, img_size=None, channels_last=True):
        """Append ``{"rate": resolution}`` to the collated values."""
        return (*return_value, {"rate": resolution})

    collate_args = ['resolution', 'img_size', 'channels_last']
|
|
|
|
|
class TBPTTDataLoader(torch.utils.data.DataLoader):
    """
    DataLoader that yields each batch as consecutive chunks along dim 1, for truncated
    backpropagation through time (TBPTT).

    Each batch from the parent DataLoader must unpack as ``(x, y, z)`` with ``z`` a dict.
    ``x``, ``y`` and every entry of ``z`` with more than one dimension are left-padded along
    dim 1 with ``overlap_len - 1`` positions (``x`` with ``self.zero``, the others with 0) and
    then cut into windows. Entries of ``z`` with ``ndim <= 1`` are dropped. Each yielded item is
    ``(x_chunk, y_chunk, {**z_chunk, "reset": reset})``; ``reset`` is True only for the first
    chunk of a batch.

    Adapted from https://github.com/deepsound-project/samplernn-pytorch
    """

    def __init__(
        self,
        dataset,
        batch_size,
        chunk_len,
        overlap_len,
        *args,
        **kwargs
    ):
        """
        Args:
            dataset: forwarded to torch.utils.data.DataLoader; if it has a ``zero`` attribute,
                that value is used to pad ``x``.
            batch_size: forwarded to DataLoader.
            chunk_len: step between consecutive chunk starts along dim 1.
            overlap_len: each ``x``/``z`` chunk also contains the ``overlap_len - 1`` positions
                preceding its start.
            *args, **kwargs: forwarded to DataLoader.
        """
        super().__init__(dataset, batch_size, *args, **kwargs)
        # NOTE(review): validated only after DataLoader.__init__ has already run.
        assert chunk_len is not None and overlap_len is not None, "TBPTTDataLoader: chunk_len and overlap_len must be specified."

        # Padding value for x: the dataset's `zero` if it defines one, otherwise 0.
        self.zero = dataset.zero if hasattr(dataset, "zero") else 0

        # Distance between consecutive chunk starts along dim 1.
        self.chunk_len = chunk_len

        # Context size: x/z chunks begin overlap_len - 1 positions before their chunk start.
        self.overlap_len = overlap_len

    def __iter__(self):
        for batch in super().__iter__():
            x, y, z = batch

            # Left-pad dim 1 with (overlap_len - 1) entries equal to `val`, so the first chunk
            # has a full window of preceding context.
            # NOTE(review): overlap_len == 0 makes the pad size -1, which torch rejects, even
            # though the `overlap_len > 0` guard below suggests 0 was meant to be supported.
            pad = lambda x, val: torch.cat([x.new_zeros((x.shape[0], self.overlap_len - 1, *x.shape[2:])) + val, x], dim=1)
            x = pad(x, self.zero)
            y = pad(y, 0)
            # Only entries with a sequence axis (ndim > 1) are padded and kept; others are dropped.
            z = { k: pad(v, 0) for k, v in z.items() if v.ndim > 1 }
            _, seq_len, *_ = x.shape

            # True only for the first chunk of each batch (presumably tells the model to reset
            # its recurrent state — confirm with the consumer).
            reset = True

            # Chunk starts are every chunk_len positions, beginning right after the padding.
            # NOTE(review): [:-1] always discards the final start. When the unpadded length is a
            # multiple of chunk_len, that is a full chunk of data. An unpadded length <= chunk_len
            # yields no chunks at all. Confirm this is intended.
            for seq_begin in list(range(self.overlap_len - 1, seq_len, self.chunk_len))[:-1]:
                # x/z windows include the overlap_len - 1 positions before seq_begin as context.
                from_index = seq_begin - self.overlap_len + 1
                to_index = seq_begin + self.chunk_len

                # Clip to_index to seq_len - (L % overlap_len), where L = seq_len - overlap_len + 1
                # is the unpadded length, i.e. never include the last L % overlap_len positions.
                # The guard avoids a modulo by zero.
                if self.overlap_len > 0:
                    to_index = min(to_index, seq_len - ((seq_len - self.overlap_len + 1) % self.overlap_len))

                x_chunk = x[:, from_index:to_index]
                # 3-D targets are sliced without the context prefix; any other target is passed
                # whole with every chunk.
                if len(y.shape) == 3:
                    y_chunk = y[:, seq_begin:to_index]
                else:
                    y_chunk = y
                z_chunk = {k: v[:, from_index:to_index] for k, v in z.items() if len(v.shape) > 1}

                yield (x_chunk, y_chunk, {**z_chunk, "reset": reset})

                reset = False

    def __len__(self):
        # The number of chunks per batch depends on that batch's sequence length, so no
        # length is reported.
        raise NotImplementedError()
|
|
|
|
|
|
|
|
|
|
|
class SequenceDataset(DefaultCollateMixin):
    """Base class for datasets that own train/val/test splits and build their loaders.

    Every subclass records itself in the shared ``registry`` dict under its ``_name_``. A subclass
    that does not set ``_name_`` inherits the placeholder NotImplementedError instance and is
    registered under that object.
    """

    registry = {}
    _name_ = NotImplementedError("Dataset must have shorthand name")

    @property
    def init_defaults(self):
        """Default attribute values for a new instance; entries in ``dataset_cfg`` take precedence."""
        return {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        cls.registry[cls._name_] = cls

    def __init__(self, _name_, data_dir=None, **dataset_cfg):
        assert _name_ == self._name_
        self.data_dir = None if data_dir is None else Path(data_dir).absolute()

        # Each default and config entry becomes an instance attribute; config overrides defaults.
        settings = {**self.init_defaults, **dataset_cfg}
        for attr, value in settings.items():
            setattr(self, attr, value)

        self.dataset_train = self.dataset_val = self.dataset_test = None

        self.init()

    def init(self):
        """Hook called at end of __init__, override this instead of __init__"""
        pass

    def setup(self):
        """This method should set self.dataset_train, self.dataset_val, and self.dataset_test."""
        raise NotImplementedError

    def split_train_val(self, val_split):
        """
        Randomly split self.dataset_train into a new (self.dataset_train, self.dataset_val) pair.

        ``val_split`` is the fraction held out for validation. The split is seeded by
        ``self.seed`` if present, else 42.
        """
        full = self.dataset_train
        n_total = len(full)
        n_train = int(n_total * (1.0 - val_split))
        rng = torch.Generator().manual_seed(getattr(self, "seed", 42))
        self.dataset_train, self.dataset_val = torch.utils.data.random_split(
            full,
            (n_train, n_total - n_train),
            generator=rng,
        )

    def train_dataloader(self, **kwargs):
        return self._train_dataloader(self.dataset_train, **kwargs)

    def _train_dataloader(self, dataset, **kwargs):
        if dataset is None:
            return None
        # Shuffle only when no sampler is supplied (presumably because DataLoader rejects both).
        kwargs['shuffle'] = 'sampler' not in kwargs
        return self._dataloader(dataset, **kwargs)

    def val_dataloader(self, **kwargs):
        return self._eval_dataloader(self.dataset_val, **kwargs)

    def test_dataloader(self, **kwargs):
        return self._eval_dataloader(self.dataset_test, **kwargs)

    def _eval_dataloader(self, dataset, **kwargs):
        if dataset is None:
            return None
        return self._dataloader(dataset, **kwargs)

    def __str__(self):
        return self._name_
|
|
|
|
|
class ResolutionSequenceDataset(SequenceDataset, SequenceResolutionCollateMixin):
    """SequenceDataset whose loaders subsample sequences through the ``resolution`` collate argument.

    Training uses a single resolution; evaluation builds one loader per requested resolution.
    """

    def _train_dataloader(self, dataset, train_resolution=None, eval_resolutions=None, **kwargs):
        """Build the training loader at ``train_resolution`` (default 1).

        ``eval_resolutions`` is accepted and ignored so the same kwargs can be passed to the
        train and eval builders.
        """
        if train_resolution is None:
            train_resolution = [1]
        if not is_list(train_resolution):
            train_resolution = [train_resolution]
        assert len(train_resolution) == 1, "Only one train resolution supported for now."
        return super()._train_dataloader(dataset, resolution=train_resolution[0], **kwargs)

    def _eval_dataloader(self, dataset, train_resolution=None, eval_resolutions=None, **kwargs):
        """Build one eval loader per entry of ``eval_resolutions`` (default ``[1]``).

        Returns None when ``dataset`` is None. Otherwise returns a dict mapping ``None`` (for
        resolution 1) or ``str(res)`` to the loader for that resolution. ``train_resolution`` is
        accepted and ignored.
        """
        if dataset is None:
            return None
        if eval_resolutions is None:
            eval_resolutions = [1]
        if not is_list(eval_resolutions):
            eval_resolutions = [eval_resolutions]

        # Bind the parent method first: zero-argument super() does not work inside a
        # comprehension before Python 3.12.
        make_loader = super()._eval_dataloader
        return {
            None if res == 1 else str(res): make_loader(dataset, resolution=res, **kwargs)
            for res in eval_resolutions
        }
|
|
|
|
|
class ImageResolutionSequenceDataset(ResolutionSequenceDataset, ImageResolutionCollateMixin):
    """ResolutionSequenceDataset whose collation comes from ImageResolutionCollateMixin.

    Through the MRO, ``collate_args`` and ``_collate_callback`` resolve to the image mixin, so
    loaders also accept ``img_size`` and ``channels_last`` to resize 4-D (or higher) inputs.
    """
|
|
|
|
|
|
|
# Loader classes selectable through the "_name_" loader argument popped in
# DefaultCollateMixin._dataloader; a missing "_name_" maps to the None entry.
loader_registry = {"tbptt": TBPTTDataLoader}
loader_registry[None] = torch.utils.data.DataLoader
|
|
|
|