Spaces:
Paused
Paused
from typing import Dict, List, Optional, NoReturn | |
import torch | |
import lightning.pytorch as pl | |
from torch.utils.data import DataLoader | |
from data.audiotext_dataset import AudioTextDataset | |
class DataModule(pl.LightningDataModule):
    r"""Lightning data module wrapping a single, already-constructed train dataset.

    Only a training loader is provided; the validation and test hooks are stubs
    that return ``None``.
    """

    def __init__(
        self,
        train_dataset: object,
        batch_size: int,
        num_workers: int
    ):
        r"""Data module. To get one batch of data:

        .. code-block:: python

            data_module.setup()

            for batch_data_dict in data_module.train_dataloader():
                print(batch_data_dict.keys())
                break

        Args:
            train_dataset: Dataset object served by ``train_dataloader``.
            batch_size: int, number of samples per training mini-batch.
            num_workers: int, number of DataLoader worker subprocesses.
        """
        super().__init__()
        # Kept private until ``setup`` publishes it as ``self.train_dataset``.
        self._train_dataset = train_dataset
        self.num_workers = num_workers
        self.batch_size = batch_size
        # Module-level ``collate_fn`` from this file. The name is looked up
        # when ``__init__`` runs, so defining it below the class is fine.
        self.collate_fn = collate_fn

    def prepare_data(self):
        # download, split, etc...
        # only called on 1 GPU/TPU in distributed
        pass

    def setup(self, stage: Optional[str] = None) -> None:
        r"""Called on every device/process; publishes the train dataset.

        Args:
            stage: Lightning stage name (e.g. ``'fit'``); unused here.
        """
        # make assignments here (val/train/test split)
        # called on every process in DDP
        self.train_dataset = self._train_dataset

    def train_dataloader(self) -> torch.utils.data.DataLoader:
        r"""Get train loader.

        Returns:
            A shuffling ``DataLoader`` over ``self.train_dataset`` that batches
            samples with ``self.collate_fn``. ``setup`` must have run first,
            since that is where ``self.train_dataset`` is assigned.
        """
        train_loader = DataLoader(
            dataset=self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=self.collate_fn,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=False,
            shuffle=True
        )
        return train_loader

    def val_dataloader(self):
        # No validation split is configured; this stub returns None.
        # NOTE(review): confirm the installed Lightning version tolerates a
        # None result from an overridden ``val_dataloader``.
        return None

    def test_dataloader(self):
        # No test split is configured; this stub returns None.
        return None

    def teardown(self, stage: Optional[str] = None):
        r"""Clean up after fit or test; called on every process in DDP.

        Lightning calls this hook as ``teardown(stage=...)``, so ``stage`` must
        be accepted. It defaults to ``None`` so direct ``teardown()`` calls
        keep working. Nothing needs releasing, so this is a no-op.

        Args:
            stage: Lightning stage name (e.g. ``'fit'``); unused here.
        """
        pass
def collate_fn(list_data_dict):
    r"""Collate mini-batch data to inputs and targets for training.

    Only samples whose ``'modality'`` is ``'audio_text'`` are collated; all
    other samples are dropped.

    Args:
        list_data_dict: e.g., [
            {
                'text': 'a sound of dog',
                'waveform': (1, samples),
                'modality': 'audio_text'
            }
            ...
        ]

    Returns:
        data_dict: e.g.
            'audio_text': {
                'text': ['a sound of dog', ...],
                'waveform': (batch_size, 1, samples),
                'modality': ['audio_text', ...]
            }
        The inner dict is empty when no sample has modality 'audio_text'.
    """
    at_samples = [
        sample for sample in list_data_dict
        if sample['modality'] == 'audio_text'
    ]

    at_data_dict = {}
    if at_samples:
        # The first sample decides which keys are collated. Every other sample
        # must provide the same keys, otherwise indexing raises KeyError.
        for key in at_samples[0]:
            values = [sample[key] for sample in at_samples]
            if key == 'waveform':
                # Stack per-sample waveforms along a new leading batch axis.
                at_data_dict[key] = torch.stack(values)
            else:
                # 'text', 'modality' and any other keys stay as plain lists.
                at_data_dict[key] = values

    return {'audio_text': at_data_dict}