# Source: TTS-Model / viXTTS / TTS/config/__init__.py
# (uploaded by duyv in "Upload 381 files", commit 813828b (verified), 4.45 kB)
import json
import os
import re
from typing import Dict
import fsspec
import yaml
from coqpit import Coqpit
from TTS.config.shared_configs import *
from TTS.utils.generic_utils import find_module
# Alternatives, tried left to right at each position:
#   1. a double-quoted JSON string (kept verbatim, so `//` or `/*` inside string
#      values such as URLs survive),
#   2. a `/* ... */` block comment, which may span several lines,
#   3. a `//` comment running to the end of the line.
_JSON_COMMENT_RE = re.compile(r'("(?:[^"\\]|\\.)*")|(/\*[\s\S]*?\*/)|(//[^\n]*)')


def _strip_json_comments(text: str) -> str:
    """Return ``text`` with `//` and `/* */` comments removed and string literals left untouched."""
    # Only group 1 (a string literal) is kept; both comment forms are replaced by "".
    return _JSON_COMMENT_RE.sub(lambda m: m.group(1) or "", text)


def read_json_with_comments(json_path):
    """Load a JSON file that may contain `//` and `/* */` comments.

    Kept for backward compatibility with config files that are not strict JSON.

    Args:
        json_path (str): path to the file; it is opened with ``fsspec`` as UTF-8 text.

    Returns:
        The value parsed by ``json.loads`` after comments are removed.

    Raises:
        json.JSONDecodeError: the content is still invalid JSON once comments are removed.
    """
    with fsspec.open(json_path, "r", encoding="utf-8") as f:
        input_str = f.read()
    return json.loads(_strip_json_comments(input_str))
def register_config(model_name: str) -> Coqpit:
    """Resolve the config class for ``model_name``.

    The module ``<model_name>_config`` is looked up with ``find_module`` in each
    of the TTS, vocoder, encoder and VC config packages. Every package is
    tried; if several succeed, the match from the last package in the search
    order wins. ``xtts`` is seeded explicitly before the search runs.

    Args:
        model_name (str): Model name (``load_config`` passes the lower-cased
            result of ``_process_model_name``).

    Raises:
        ModuleNotFoundError: No matching config for the model name.

    Returns:
        Coqpit: the config class itself (callers instantiate it).
    """
    module_name = model_name + "_config"
    found = None
    # TODO: fix this -- XTTS is special-cased instead of relying on discovery alone.
    if model_name == "xtts":
        from TTS.tts.configs.xtts_config import XttsConfig

        found = XttsConfig
    search_packages = ["TTS.tts.configs", "TTS.vocoder.configs", "TTS.encoder.configs", "TTS.vc.configs"]
    for package in search_packages:
        try:
            found = find_module(package, module_name)
        except ModuleNotFoundError:
            # Not defined in this package; keep searching the remaining ones.
            continue
    if found is None:
        raise ModuleNotFoundError(f" [!] Config for {model_name} cannot be found.")
    return found
def _process_model_name(config_dict: Dict) -> str:
    """Derive the model name used to look up the config class.

    This is a band-aid for the old `vocoder` model names. The function reads
    ``config_dict["model"]`` when that key exists, and otherwise
    ``config_dict["generator_model"]``. It then removes every ``_generator``
    substring, followed by every ``_discriminator`` substring.

    Args:
        config_dict (Dict): A dictionary including the config fields.

    Returns:
        str: Formatted model name.

    Raises:
        KeyError: neither ``model`` nor ``generator_model`` is present.
    """
    source_key = "model" if "model" in config_dict else "generator_model"
    name = config_dict[source_key]
    for tag in ("_generator", "_discriminator"):
        name = name.replace(tag, "")
    return name
def load_config(config_path: str) -> Coqpit:
"""Import `json` or `yaml` files as TTS configs. First, load the input file as a `dict` and check the model name
to find the corresponding Config class. Then initialize the Config.
Args:
config_path (str): path to the config file.
Raises:
TypeError: given config file has an unknown type.
Returns:
Coqpit: TTS config object.
"""
config_dict = {}
ext = os.path.splitext(config_path)[1]
if ext in (".yml", ".yaml"):
with fsspec.open(config_path, "r", encoding="utf-8") as f:
data = yaml.safe_load(f)
elif ext == ".json":
try:
with fsspec.open(config_path, "r", encoding="utf-8") as f:
data = json.load(f)
except json.decoder.JSONDecodeError:
# backwards compat.
data = read_json_with_comments(config_path)
else:
raise TypeError(f" [!] Unknown config file type {ext}")
config_dict.update(data)
model_name = _process_model_name(config_dict)
config_class = register_config(model_name.lower())
config = config_class()
config.from_dict(config_dict)
return config
def check_config_and_model_args(config, arg_name, value):
    """Compare ``arg_name`` against ``value``, preferring ``config.model_args``.

    If ``config`` has ``model_args`` and the argument is in it, that entry is
    compared. Otherwise, if ``config`` has the attribute itself, ``config[arg_name]``
    is compared. If the argument is found in neither place, the result is False.

    This patches up compatibility between models with and without `model_args`.
    TODO: Remove this in the future with a unified approach.
    """
    if hasattr(config, "model_args") and arg_name in config.model_args:
        return config.model_args[arg_name] == value
    if not hasattr(config, arg_name):
        return False
    return config[arg_name] == value
def get_from_config_or_model_args(config, arg_name):
    """Return ``arg_name``, taken from ``config.model_args`` when present there, else from ``config``.

    No fallback is applied: a missing argument raises whatever ``config[arg_name]`` raises.
    """
    use_model_args = hasattr(config, "model_args") and arg_name in config.model_args
    source = config.model_args if use_model_args else config
    return source[arg_name]
def get_from_config_or_model_args_with_default(config, arg_name, def_val):
    """Return ``arg_name`` from ``config.model_args`` or ``config``, falling back to ``def_val``.

    ``config.model_args`` takes precedence when it contains the argument.
    Otherwise ``config[arg_name]`` is used if ``config`` has that attribute,
    and ``def_val`` is returned if it has neither.
    """
    if hasattr(config, "model_args") and arg_name in config.model_args:
        return config.model_args[arg_name]
    return config[arg_name] if hasattr(config, arg_name) else def_val