File size: 6,474 Bytes
9b2107c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import importlib
import re

from coqpit import Coqpit


def to_camel(text):
    """Convert a snake_case identifier to CamelCase.

    The string is first passed through ``str.capitalize`` (first character
    upper-cased, the rest lower-cased). Then every ``_`` that is not the very
    first character and is followed by an ASCII letter is removed, and that
    letter is upper-cased.

    Example: ``"hifigan_generator"`` -> ``"HifiganGenerator"``.
    """

    def _upper_letter(match):
        # Keep only the captured letter, upper-cased; the underscore is dropped.
        return match.group(1).upper()

    pattern = r"(?!^)_([a-zA-Z])"
    return re.sub(pattern, _upper_letter, text.capitalize())


def setup_model(config: Coqpit):
    """Load models directly from configuration.

    If the config defines both ``discriminator_model`` and ``generator_model``,
    the ``GAN`` class from ``TTS.vocoder.models.gan`` is used. Otherwise the
    module ``TTS.vocoder.models.<config.model lower-cased>`` is imported. The
    class is taken from a small table of explicit names or, failing that, from
    ``to_camel(config.model)``.

    Args:
        config: vocoder configuration; must expose ``model`` when it is not a
            generator/discriminator GAN config.

    Returns:
        The result of ``<ModelClass>.init_from_config(config)``.

    Raises:
        ValueError: if the imported module has no class with the resolved name.
    """
    if "discriminator_model" in config and "generator_model" in config:
        MyModel = getattr(importlib.import_module("TTS.vocoder.models.gan"), "GAN")
    else:
        model_name = config.model.lower()
        module = importlib.import_module("TTS.vocoder.models." + model_name)
        # Explicit class names for these models; everything else is CamelCased.
        explicit_class_names = {"wavernn": "Wavernn", "gan": "GAN", "wavegrad": "Wavegrad"}
        class_name = explicit_class_names.get(model_name)
        if class_name is None:
            class_name = to_camel(config.model)
        try:
            MyModel = getattr(module, class_name)
        except AttributeError as e:
            # getattr raises AttributeError (not ModuleNotFoundError) for a
            # missing class, so this is the error to translate.
            raise ValueError(f"Model {config.model} not exist!") from e
    print(" > Vocoder Model: {}".format(config.model))
    return MyModel.init_from_config(config)


def setup_generator(c):
    """Instantiate the vocoder generator named by ``c.generator_model``.

    TODO: use config object as arguments

    The name is matched case-insensitively and exactly against the supported
    generators. The module ``TTS.vocoder.models.<name>`` is imported, and the
    class ``to_camel(c.generator_model)`` is instantiated with the arguments
    listed below.

    Args:
        c: config exposing ``generator_model`` (str), ``generator_model_params``
            (mapping) and ``audio`` (mapping read for ``"num_mels"``).

    Returns:
        The constructed generator model.

    Raises:
        ValueError: if the renamed ``melgan_fb_generator`` is requested.
        NotImplementedError: if the generator name is not supported here.
    """
    print(" > Generator Model: {}".format(c.generator_model))
    name = c.generator_model.lower()
    # Checked before importing so the user sees the rename hint instead of
    # whatever the import/class lookup for the old name would raise.
    if name == "melgan_fb_generator":
        raise ValueError("melgan_fb_generator is now fullband_melgan_generator")

    def melgan_family(model_cls, out_channels, base_channels):
        # Constructor arguments shared by the MelGAN-style generators.
        return model_cls(
            in_channels=c.audio["num_mels"],
            out_channels=out_channels,
            proj_kernel=7,
            base_channels=base_channels,
            upsample_factors=c.generator_model_params["upsample_factors"],
            res_kernel=3,
            num_res_blocks=c.generator_model_params["num_res_blocks"],
        )

    def parallel_wavegan(model_cls):
        return model_cls(
            in_channels=1,
            out_channels=1,
            kernel_size=3,
            num_res_blocks=c.generator_model_params["num_res_blocks"],
            stacks=c.generator_model_params["stacks"],
            res_channels=64,
            gate_channels=128,
            skip_channels=64,
            aux_channels=c.audio["num_mels"],
            dropout=0.0,
            bias=True,
            use_weight_norm=True,
            upsample_factors=c.generator_model_params["upsample_factors"],
        )

    # Exact (lower-cased) name -> builder taking the resolved model class.
    builders = {
        "hifigan_generator": lambda cls: cls(
            in_channels=c.audio["num_mels"], out_channels=1, **c.generator_model_params
        ),
        "melgan_generator": lambda cls: melgan_family(cls, out_channels=1, base_channels=512),
        "multiband_melgan_generator": lambda cls: melgan_family(cls, out_channels=4, base_channels=384),
        "fullband_melgan_generator": lambda cls: melgan_family(cls, out_channels=1, base_channels=512),
        "parallel_wavegan_generator": parallel_wavegan,
        "univnet_generator": lambda cls: cls(**c.generator_model_params),
    }
    # Reject unsupported names before importing anything.
    if name not in builders:
        raise NotImplementedError(f"Model {c.generator_model} not implemented!")

    module = importlib.import_module("TTS.vocoder.models." + name)
    MyModel = getattr(module, to_camel(c.generator_model))
    return builders[name](MyModel)


def setup_discriminator(c):
    """Instantiate the vocoder discriminator named by ``c.discriminator_model``.

    TODO: use config object as arguments

    The name is matched case-insensitively and exactly against the supported
    discriminators. Names containing ``parallel_wavegan`` are loaded from
    ``TTS.vocoder.models.parallel_wavegan_discriminator``; all others from
    ``TTS.vocoder.models.<name>``. The class ``to_camel(<name>)`` is then
    instantiated with the arguments listed below.

    Args:
        c: config exposing ``discriminator_model`` (str),
            ``discriminator_model_params`` (mapping) and ``audio`` (mapping read
            for ``"num_mels"`` and ``"hop_length"``).

    Returns:
        The constructed discriminator model.

    Raises:
        NotImplementedError: if the discriminator name is not supported here.
    """
    print(" > Discriminator Model: {}".format(c.discriminator_model))
    name = c.discriminator_model.lower()
    params = lambda key: c.discriminator_model_params[key]  # noqa: E731

    # Exact (lower-cased) name -> builder taking the resolved model class.
    builders = {
        "hifigan_discriminator": lambda cls: cls(),
        "random_window_discriminator": lambda cls: cls(
            cond_channels=c.audio["num_mels"],
            hop_length=c.audio["hop_length"],
            # Spelling ("donwsample") must match the constructor keyword and the
            # config key as used by the original code; do not "fix" it here.
            uncond_disc_donwsample_factors=params("uncond_disc_donwsample_factors"),
            cond_disc_downsample_factors=params("cond_disc_downsample_factors"),
            cond_disc_out_channels=params("cond_disc_out_channels"),
            window_sizes=params("window_sizes"),
        ),
        "melgan_multiscale_discriminator": lambda cls: cls(
            in_channels=1,
            out_channels=1,
            kernel_sizes=(5, 3),
            base_channels=params("base_channels"),
            max_channels=params("max_channels"),
            downsample_factors=params("downsample_factors"),
        ),
        "residual_parallel_wavegan_discriminator": lambda cls: cls(
            in_channels=1,
            out_channels=1,
            kernel_size=3,
            num_layers=params("num_layers"),
            stacks=params("stacks"),
            res_channels=64,
            gate_channels=128,
            skip_channels=64,
            dropout=0.0,
            bias=True,
            nonlinear_activation="LeakyReLU",
            nonlinear_activation_params={"negative_slope": 0.2},
        ),
        "parallel_wavegan_discriminator": lambda cls: cls(
            in_channels=1,
            out_channels=1,
            kernel_size=3,
            num_layers=params("num_layers"),
            conv_channels=64,
            dilation_factor=1,
            nonlinear_activation="LeakyReLU",
            nonlinear_activation_params={"negative_slope": 0.2},
            bias=True,
        ),
        "univnet_discriminator": lambda cls: cls(),
    }
    # Reject unsupported names up front; previously an importable but
    # unhandled name fell through every branch and failed on an unbound `model`.
    if name not in builders:
        raise NotImplementedError(f"Model {c.discriminator_model} not implemented!")

    # Both parallel-WaveGAN discriminators are looked up in the same module.
    module_name = "parallel_wavegan_discriminator" if "parallel_wavegan" in name else name
    module = importlib.import_module("TTS.vocoder.models." + module_name)
    MyModel = getattr(module, to_camel(name))
    return builders[name](MyModel)