Spaces:
Running
Running
File size: 2,123 Bytes
9b2107c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
from typing import List
from coqpit import Coqpit
from torch.utils.data import Dataset
from TTS.utils.audio import AudioProcessor
from TTS.vocoder.datasets.gan_dataset import GANDataset
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List, verbose: bool) -> Dataset:
if config.model.lower() in "gan":
dataset = GANDataset(
ap=ap,
items=data_items,
seq_len=config.seq_len,
hop_len=ap.hop_length,
pad_short=config.pad_short,
conv_pad=config.conv_pad,
return_pairs=config.diff_samples_for_G_and_D if "diff_samples_for_G_and_D" in config else False,
is_training=not is_eval,
return_segments=not is_eval,
use_noise_augment=config.use_noise_augment,
use_cache=config.use_cache,
verbose=verbose,
)
dataset.shuffle_mapping()
elif config.model.lower() == "wavegrad":
dataset = WaveGradDataset(
ap=ap,
items=data_items,
seq_len=config.seq_len,
hop_len=ap.hop_length,
pad_short=config.pad_short,
conv_pad=config.conv_pad,
is_training=not is_eval,
return_segments=True,
use_noise_augment=False,
use_cache=config.use_cache,
verbose=verbose,
)
elif config.model.lower() == "wavernn":
dataset = WaveRNNDataset(
ap=ap,
items=data_items,
seq_len=config.seq_len,
hop_len=ap.hop_length,
pad=config.model_params.pad,
mode=config.model_params.mode,
mulaw=config.model_params.mulaw,
is_training=not is_eval,
verbose=verbose,
)
else:
raise ValueError(f" [!] Dataset for model {config.model.lower()} cannot be found.")
return dataset
|