import importlib
import re

from coqpit import Coqpit


def to_camel(text):
    text = text.capitalize()
    return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
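# Illustrative examples (not part of the original module) of how to_camel maps
# config model names to class names:
#   to_camel("hifigan_generator")          -> "HifiganGenerator"
#   to_camel("multiband_melgan_generator") -> "MultibandMelganGenerator"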


def setup_model(config: Coqpit):
    """Load models directly from configuration."""
    if "discriminator_model" in config and "generator_model" in config:
        MyModel = importlib.import_module("TTS.vocoder.models.gan")
        MyModel = getattr(MyModel, "GAN")
    else:
        MyModel = importlib.import_module("TTS.vocoder.models." + config.model.lower())
        if config.model.lower() == "wavernn":
            MyModel = getattr(MyModel, "Wavernn")
        elif config.model.lower() == "gan":
            MyModel = getattr(MyModel, "GAN")
        elif config.model.lower() == "wavegrad":
            MyModel = getattr(MyModel, "Wavegrad")
        else:
            try:
                MyModel = getattr(MyModel, to_camel(config.model))
            except AttributeError as e:
                # getattr raises AttributeError when the class is missing from the module
                raise ValueError(f"Model {config.model} does not exist!") from e
    print(" > Vocoder Model: {}".format(config.model))
    return MyModel.init_from_config(config)
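# Usage sketch (illustrative, not part of the original file). It assumes a
# GAN-based vocoder config such as TTS.vocoder.configs.HifiganConfig, which
# defines both `generator_model` and `discriminator_model`, so setup_model
# resolves to TTS.vocoder.models.gan.GAN:
#
#     from TTS.vocoder.configs import HifiganConfig
#
#     config = HifiganConfig()
#     vocoder = setup_model(config)  # calls GAN.init_from_config(config)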


def setup_generator(c):
    """TODO: use config object as arguments"""
    print(" > Generator Model: {}".format(c.generator_model))
    MyModel = importlib.import_module("TTS.vocoder.models." + c.generator_model.lower())
    MyModel = getattr(MyModel, to_camel(c.generator_model))
    # build each generator with its model-specific arguments
    if c.generator_model.lower() == "hifigan_generator":
        model = MyModel(in_channels=c.audio["num_mels"], out_channels=1, **c.generator_model_params)
    elif c.generator_model.lower() == "melgan_generator":
        model = MyModel(
            in_channels=c.audio["num_mels"],
            out_channels=1,
            proj_kernel=7,
            base_channels=512,
            upsample_factors=c.generator_model_params["upsample_factors"],
            res_kernel=3,
            num_res_blocks=c.generator_model_params["num_res_blocks"],
        )
    elif c.generator_model.lower() == "melgan_fb_generator":
        raise ValueError("melgan_fb_generator is now fullband_melgan_generator")
    elif c.generator_model.lower() == "multiband_melgan_generator":
        model = MyModel(
            in_channels=c.audio["num_mels"],
            out_channels=4,
            proj_kernel=7,
            base_channels=384,
            upsample_factors=c.generator_model_params["upsample_factors"],
            res_kernel=3,
            num_res_blocks=c.generator_model_params["num_res_blocks"],
        )
    elif c.generator_model.lower() == "fullband_melgan_generator":
        model = MyModel(
            in_channels=c.audio["num_mels"],
            out_channels=1,
            proj_kernel=7,
            base_channels=512,
            upsample_factors=c.generator_model_params["upsample_factors"],
            res_kernel=3,
            num_res_blocks=c.generator_model_params["num_res_blocks"],
        )
    elif c.generator_model.lower() == "parallel_wavegan_generator":
        model = MyModel(
            in_channels=1,
            out_channels=1,
            kernel_size=3,
            num_res_blocks=c.generator_model_params["num_res_blocks"],
            stacks=c.generator_model_params["stacks"],
            res_channels=64,
            gate_channels=128,
            skip_channels=64,
            aux_channels=c.audio["num_mels"],
            dropout=0.0,
            bias=True,
            use_weight_norm=True,
            upsample_factors=c.generator_model_params["upsample_factors"],
        )
    elif c.generator_model.lower() == "univnet_generator":
        model = MyModel(**c.generator_model_params)
    else:
        raise NotImplementedError(f"Model {c.generator_model} is not implemented!")
    return model
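# Usage sketch (illustrative, not part of the original file). It assumes
# TTS.vocoder.configs.MelganConfig, whose defaults set
# generator_model = "melgan_generator" plus the audio and
# generator_model_params fields read above:
#
#     from TTS.vocoder.configs import MelganConfig
#
#     c = MelganConfig()
#     generator = setup_generator(c)  # -> MelganGenerator instance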


def setup_discriminator(c):
    """TODO: use config object as arguments"""
    print(" > Discriminator Model: {}".format(c.discriminator_model))
    if "parallel_wavegan" in c.discriminator_model:
        MyModel = importlib.import_module("TTS.vocoder.models.parallel_wavegan_discriminator")
    else:
        MyModel = importlib.import_module("TTS.vocoder.models." + c.discriminator_model.lower())
    MyModel = getattr(MyModel, to_camel(c.discriminator_model.lower()))
    if c.discriminator_model == "hifigan_discriminator":
        model = MyModel()
    elif c.discriminator_model == "random_window_discriminator":
        model = MyModel(
            cond_channels=c.audio["num_mels"],
            hop_length=c.audio["hop_length"],
            uncond_disc_donwsample_factors=c.discriminator_model_params["uncond_disc_donwsample_factors"],
            cond_disc_downsample_factors=c.discriminator_model_params["cond_disc_downsample_factors"],
            cond_disc_out_channels=c.discriminator_model_params["cond_disc_out_channels"],
            window_sizes=c.discriminator_model_params["window_sizes"],
        )
    elif c.discriminator_model == "melgan_multiscale_discriminator":
        model = MyModel(
            in_channels=1,
            out_channels=1,
            kernel_sizes=(5, 3),
            base_channels=c.discriminator_model_params["base_channels"],
            max_channels=c.discriminator_model_params["max_channels"],
            downsample_factors=c.discriminator_model_params["downsample_factors"],
        )
    elif c.discriminator_model == "residual_parallel_wavegan_discriminator":
        model = MyModel(
            in_channels=1,
            out_channels=1,
            kernel_size=3,
            num_layers=c.discriminator_model_params["num_layers"],
            stacks=c.discriminator_model_params["stacks"],
            res_channels=64,
            gate_channels=128,
            skip_channels=64,
            dropout=0.0,
            bias=True,
            nonlinear_activation="LeakyReLU",
            nonlinear_activation_params={"negative_slope": 0.2},
        )
    elif c.discriminator_model == "parallel_wavegan_discriminator":
        model = MyModel(
            in_channels=1,
            out_channels=1,
            kernel_size=3,
            num_layers=c.discriminator_model_params["num_layers"],
            conv_channels=64,
            dilation_factor=1,
            nonlinear_activation="LeakyReLU",
            nonlinear_activation_params={"negative_slope": 0.2},
            bias=True,
        )
    elif c.discriminator_model == "univnet_discriminator":
        model = MyModel()
    else:
        raise NotImplementedError(f"Model {c.discriminator_model} is not implemented!")
    return model
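# Usage sketch (illustrative, not part of the original file). It assumes a
# config whose discriminator_model matches one of the branches above, e.g.
# TTS.vocoder.configs.HifiganConfig with
# discriminator_model = "hifigan_discriminator":
#
#     from TTS.vocoder.configs import HifiganConfig
#
#     c = HifiganConfig()
#     discriminator = setup_discriminator(c)  # -> HifiganDiscriminator instance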