# File size: 2,123 Bytes
# Revision: 9b2107c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from typing import List

from coqpit import Coqpit
from torch.utils.data import Dataset

from TTS.utils.audio import AudioProcessor
from TTS.vocoder.datasets.gan_dataset import GANDataset
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset


def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List, verbose: bool) -> Dataset:
    """Build the vocoder dataset that matches ``config.model``.

    The model name is compared case-insensitively:

    * any name containing ``"gan"`` (e.g. ``"gan"``, ``"hifigan"``) -> ``GANDataset``
    * ``"wavegrad"`` -> ``WaveGradDataset``
    * ``"wavernn"`` -> ``WaveRNNDataset``

    Args:
        config: Training configuration. ``model`` and ``seq_len`` are always
            read; the remaining fields depend on the selected branch
            (``pad_short``, ``conv_pad``, ``use_cache``, ``use_noise_augment``,
            optional ``diff_samples_for_G_and_D``, or ``model_params``).
        ap: Audio processor; passed through to the dataset and used for its
            ``hop_length``.
        is_eval: When True the dataset is created with ``is_training=False``.
            For the GAN dataset this also sets ``return_segments=False``.
        data_items: Items forwarded to the dataset as ``items``.
        verbose: Forwarded to the dataset unchanged.

    Returns:
        The constructed dataset instance.

    Raises:
        ValueError: If ``config.model`` does not match any known dataset.
    """
    model_name = config.model.lower()
    is_training = not is_eval

    # The check must be `"gan" in model_name`. The reversed form
    # (`model_name in "gan"`) asks whether the name is a substring of "gan",
    # which accepts "", "g", "an", ... and rejects names such as "hifigan".
    if "gan" in model_name:
        dataset = GANDataset(
            ap=ap,
            items=data_items,
            seq_len=config.seq_len,
            hop_len=ap.hop_length,
            pad_short=config.pad_short,
            conv_pad=config.conv_pad,
            # Optional config field: default to unpaired samples when absent.
            return_pairs=config.diff_samples_for_G_and_D if "diff_samples_for_G_and_D" in config else False,
            is_training=is_training,
            return_segments=is_training,
            use_noise_augment=config.use_noise_augment,
            use_cache=config.use_cache,
            verbose=verbose,
        )
        dataset.shuffle_mapping()
    elif model_name == "wavegrad":
        dataset = WaveGradDataset(
            ap=ap,
            items=data_items,
            seq_len=config.seq_len,
            hop_len=ap.hop_length,
            pad_short=config.pad_short,
            conv_pad=config.conv_pad,
            is_training=is_training,
            # Unlike the GAN branch, segments are returned for eval as well
            # and noise augmentation is always disabled.
            return_segments=True,
            use_noise_augment=False,
            use_cache=config.use_cache,
            verbose=verbose,
        )
    elif model_name == "wavernn":
        dataset = WaveRNNDataset(
            ap=ap,
            items=data_items,
            seq_len=config.seq_len,
            hop_len=ap.hop_length,
            pad=config.model_params.pad,
            mode=config.model_params.mode,
            mulaw=config.model_params.mulaw,
            is_training=is_training,
            verbose=verbose,
        )
    else:
        raise ValueError(f" [!] Dataset for model {model_name} cannot be found.")
    return dataset