Spaces:
Runtime error
Runtime error
File size: 2,552 Bytes
e79b770 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/data_module.py
from pytorch_lightning import LightningDataModule
from AR.data.bucket_sampler import DistributedBucketSampler
from AR.data.dataset import Text2SemanticDataset
from torch.utils.data import DataLoader
class Text2SemanticDataModule(LightningDataModule):
def __init__(self, config, train_semantic_path, train_phoneme_path,dev_semantic_path=None, dev_phoneme_path=None):
super().__init__()
self.config = config
self.train_semantic_path = train_semantic_path
self.train_phoneme_path = train_phoneme_path
self.dev_semantic_path = dev_semantic_path
self.dev_phoneme_path = dev_phoneme_path
self.num_workers = self.config['data']['num_workers']
def prepare_data(self):
pass
def setup(self, stage=None, output_logs=False):
self._train_dataset = Text2SemanticDataset(
phoneme_path=self.train_phoneme_path,
semantic_path=self.train_semantic_path,
max_sec=self.config['data']['max_sec'],
pad_val=self.config['data']['pad_val'])
self._dev_dataset = self._train_dataset
# self._dev_dataset = Text2SemanticDataset(
# phoneme_path=self.dev_phoneme_path,
# semantic_path=self.dev_semantic_path,
# max_sample=self.config['data']['max_eval_sample'],
# max_sec=self.config['data']['max_sec'],
# pad_val=self.config['data']['pad_val'])
def train_dataloader(self):
batch_size = self.config['train']['batch_size']
sampler = DistributedBucketSampler(
self._train_dataset, batch_size=batch_size)
return DataLoader(
self._train_dataset,
batch_size=batch_size,
sampler=sampler,
collate_fn=self._train_dataset.collate,
num_workers=self.num_workers,
persistent_workers=True,
prefetch_factor=16
)
def val_dataloader(self):
return DataLoader(
self._dev_dataset,
batch_size=1,
shuffle=False,
collate_fn=self._train_dataset.collate,
num_workers=max(self.num_workers,12),
persistent_workers=True,
prefetch_factor=16
)
# 这个会使用到嘛?
def test_dataloader(self):
return DataLoader(
self._dev_dataset,
batch_size=1,
shuffle=False,
collate_fn=self._train_dataset.collate)
|