import os
from pathlib import Path
from typing import Union, Callable

import torch
import torch.nn as nn
from omegaconf import DictConfig, ListConfig
from silero.utils import Decoder

from AIModels import NeuralASR
from ModelInterfaces import IASRModel
from constants import MODEL_NAME_DEFAULT, language_not_implemented, app_logger, sample_rate_start, silero_versions_dict


default_speaker_dict = {
    "de": {"speaker": "karlsson", "model_id": "v3_de", "sample_rate": sample_rate_start},
    "en": {"speaker": "en_0", "model_id": "v3_en", "sample_rate": sample_rate_start},
}


def getASRModel(language: str, model_name: str = MODEL_NAME_DEFAULT) -> IASRModel:
    """Wrapper function to get the ASR model based on the model name and language.

    Currently supported models are 'whisper', 'faster_whisper', and 'silero'.

    Args:
        language (str): The language of the model.
        model_name (str): The name of the model to use. Defaults to MODEL_NAME_DEFAULT.

    Returns:
        IASRModel: The ASR model instance.
    """
    models_dict = {
        "whisper": __get_model_whisper,
        "faster_whisper": __get_model_faster_whisper,
        "silero": __get_model_silero,
    }
    if model_name in models_dict:
        fn = models_dict[model_name]
        return fn(language)
    models_supported = ", ".join(models_dict.keys())
    raise ValueError(f"Model '{model_name}' not implemented. Supported models: {models_supported}.")


def __get_model_whisper(language: str) -> IASRModel:
    from whisper_wrapper import WhisperASRModel
    return WhisperASRModel(language=language)


def __get_model_faster_whisper(language: str) -> IASRModel:
    from faster_whisper_wrapper import FasterWhisperASRModel
    return FasterWhisperASRModel(language=language)


def __get_model_silero(language: str) -> IASRModel:
    import tempfile
    tmp_dir = tempfile.gettempdir()
    if language == "de":
        model, decoder, _ = __silero_stt(
            language="de", version="v4", jit_model="jit_large", output_folder=tmp_dir
        )
        return __eval_apply_neural_asr(model, decoder, language)
    elif language == "en":
        model, decoder, _ = __silero_stt(language="en", output_folder=tmp_dir)
        return __eval_apply_neural_asr(model, decoder, language)
    raise ValueError(language_not_implemented.format(language))


def __eval_apply_neural_asr(model: nn.Module, decoder: Decoder, language: str) -> NeuralASR:
    app_logger.info(f"LOADED silero model language: {language}, version: '{silero_versions_dict[language]}'")
    model.eval()
    app_logger.info(f"EVALUATED silero model language: {language}, version: '{silero_versions_dict[language]}'")
    return NeuralASR(model, decoder)


def getTranslationModel(language: str) -> tuple:
    """Wrapper function to get the translation model and tokenizer based on the language.

    Args:
        language (str): The language of the model.

    Returns:
        tuple: The (model, tokenizer) pair for the given language.
    """
    from transformers import AutoTokenizer
    from transformers import AutoModelForSeq2SeqLM
    if language == 'de':
        model = AutoModelForSeq2SeqLM.from_pretrained(
            "Helsinki-NLP/opus-mt-de-en")
        tokenizer = AutoTokenizer.from_pretrained(
            "Helsinki-NLP/opus-mt-de-en")
        # Cache models to avoid Hugging Face processing (not needed now)
        # with open('translation_model_de.pickle', 'wb') as handle:
        #     pickle.dump(model, handle)
        # with open('translation_tokenizer_de.pickle', 'wb') as handle:
        #     pickle.dump(tokenizer, handle)
    else:
        raise ValueError(language_not_implemented.format(language))
    return model, tokenizer
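

# Minimal usage sketch (standard transformers seq2seq API; the sample sentence
# is illustrative only):
#
#   model, tokenizer = getTranslationModel("de")
#   batch = tokenizer(["Der Hund läuft im Park."], return_tensors="pt")
#   tokens = model.generate(**batch)
#   print(tokenizer.batch_decode(tokens, skip_special_tokens=True))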


def __silero_tts(language: str = "en", version: str = "latest", output_folder: Path | str = None, **kwargs) -> tuple:
    """
    Modified function to create instances of Silero Text-To-Speech Models.
    Please see https://github.com/snakers4/silero-models?tab=readme-ov-file#text-to-speech for usage examples.

    Args:
        language (str): Language of the model. Available options are ['ru', 'en', 'de', 'es', 'fr']. Default is 'en'.
        version (str): Version of the model to use. Default is 'latest'.
        output_folder (Path | str): Path to the folder where the model will be saved. Default is None.
        **kwargs: Additional keyword arguments.

    Returns:
        tuple: Depending on the model version, either
            - (model, example_text, speaker, sample_rate) for the packaged v2/v3/v4 models, or
            - (model, symbols, sample_rate, example_text, apply_tts, model_id) for older JIT models,
              where symbols is the set of symbols used by the model and apply_tts is the function
              that runs text-to-speech.

    Raises:
        ValueError: If the specified language has no default speaker configuration.
        AssertionError: If the specified language is not in the supported list.
    """
    if language not in default_speaker_dict:
        raise ValueError(language_not_implemented.format(language))
    if output_folder is None:
        # same fallback folder as __get_models
        output_folder = Path(os.path.dirname(__file__)).parent.parent
    output_folder = Path(output_folder)
    current_model_lang = default_speaker_dict[language]
    app_logger.info(f"model speaker current_model_lang: {current_model_lang} ...")
    model_id = current_model_lang["model_id"]
    models = __get_models(language, output_folder, version, model_type="tts_models")
    available_languages = list(models.tts_models.keys())
    assert (
        language in available_languages
    ), f"Language not in the supported list {available_languages}"
    tts_models_lang = models.tts_models[language]
    model_conf = tts_models_lang[model_id]
    model_conf_latest = model_conf[version]
    app_logger.info(f"model_conf: {model_conf_latest} ...")
    if "_v2" in model_id or "_v3" in model_id or "v3_" in model_id or "v4_" in model_id:
        from torch import package

        model_url = model_conf_latest.package
        model_dir = output_folder / "model"
        os.makedirs(model_dir, exist_ok=True)
        model_path = output_folder / os.path.basename(model_url)
        if not os.path.isfile(model_path):
            torch.hub.download_url_to_file(model_url, model_path, progress=True)
        imp = package.PackageImporter(model_path)
        model = imp.load_pickle("tts_models", "model")
        app_logger.info(
            f"current model_conf_latest.sample_rate: {model_conf_latest.sample_rate} ..."
        )
        sample_rate = current_model_lang["sample_rate"]
        return (
            model,
            model_conf_latest.example,
            current_model_lang["speaker"],
            sample_rate,
        )
    else:
        from silero.tts_utils import apply_tts, init_jit_model as init_jit_model_tts

        model = init_jit_model_tts(model_conf_latest.jit)
        symbols = model_conf_latest.tokenset
        example_text = model_conf_latest.example
        sample_rate = model_conf_latest.sample_rate
        return model, symbols, sample_rate, example_text, apply_tts, model_id
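

# Usage sketch for the packaged (v3/v4) branch above, following the silero-models
# README; older JIT models instead return the 6-tuple including `apply_tts`:
#
#   model, example_text, speaker, sample_rate = __silero_tts("en", output_folder="/tmp")
#   audio = model.apply_tts(text=example_text, speaker=speaker, sample_rate=sample_rate)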


def __get_models(language: str, output_folder: str | Path, version: str, model_type: str) -> Union[DictConfig, ListConfig]:
    """
    Retrieve and load the model configuration for a specified language and model type.

    Args:
        language (str): The language for which the model is required.
        output_folder (str or Path): The folder where the model configuration file should be saved.
        version (str): The version of the model.
        model_type (str): The type of the model.

    Returns:
        DictConfig | ListConfig: The loaded model configuration.

    Raises:
        AssertionError: If the model configuration file does not exist after attempting to download it.

    Notes:
        If the model configuration file does not exist in the specified output folder, it will be downloaded
        from a predefined URL and saved in the output folder.
    """
    from omegaconf import OmegaConf
    output_folder = (
        Path(output_folder)
        if output_folder is not None
        else Path(os.path.dirname(__file__)).parent.parent
    )
    models_list_file = output_folder / f"latest_silero_model_{language}.yml"
    app_logger.info(f"models_list_file: {models_list_file}.")
    if not os.path.exists(models_list_file):
        app_logger.info(
            f"model {model_type} yml for '{language}' language, '{version}' version not found, downloading it into folder {output_folder}..."
        )
        torch.hub.download_url_to_file(
            "https://raw.githubusercontent.com/snakers4/silero-models/master/models.yml",
            str(models_list_file),
            progress=False,
        )
    assert os.path.exists(models_list_file)
    return OmegaConf.load(models_list_file)
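

# The returned OmegaConf tree mirrors silero's models.yml layout, so entries can
# be read with attribute access (sketch; the keys follow the access pattern used
# in __get_latest_stt_model below):
#
#   models = __get_models("en", "/tmp", "latest", model_type="stt_models")
#   jit_url = models.stt_models.en.latest.jit  # URL later consumed by init_jit_model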


def __get_latest_stt_model(language: str, output_folder: str | Path, version: str, model_type: str, jit_model: str, **kwargs) -> tuple[nn.Module, Decoder]:
    """
    Retrieve the latest Speech-to-Text (STT) model for a given language and model type.

    Args:
        language (str): The language for which the STT model is required.
        output_folder (str or Path): The directory where the model will be saved.
        version (str): The version of the model to retrieve.
        model_type (str): The type of the model (e.g., 'stt_models').
        jit_model (str): The specific JIT model to use (e.g., 'jit', 'jit_large').
        **kwargs: Additional keyword arguments to pass to the model initialization function.

    Returns:
        tuple: A tuple containing the model and the decoder.

    Raises:
        AssertionError: If the specified language is not available in the model type.
    """
    models = __get_models(language, output_folder, version, model_type)
    available_languages = list(models[model_type].keys())
    assert language in available_languages, f"Language not in the supported list {available_languages}"
    model, decoder = init_jit_model(
        model_url=models[model_type].get(language).get(version).get(jit_model),
        output_folder=output_folder,
        **kwargs,
    )
    return model, decoder


def init_jit_model(
    model_url: str,
    device: torch.device = torch.device("cpu"),
    output_folder: Path | str = None,
) -> tuple[torch.nn.Module, Decoder]:
    """
    Initialize a JIT model from a given URL.

    Args:
        model_url (str): The URL to download the model from.
        device (torch.device, optional): The device to load the model on. Defaults to CPU.
        output_folder (Path | str, optional): The folder to save the downloaded model.
            If None, defaults to the torch hub cache directory.

    Returns:
        tuple[torch.nn.Module, Decoder]: The loaded JIT model and its corresponding decoder.
    """
    torch.set_grad_enabled(False)
    app_logger.info(
        f"is model output_folder None? '{output_folder is None}' => '{output_folder}' ..."
    )
    model_dir = (
        Path(output_folder)
        if output_folder is not None
        else Path(torch.hub.get_dir())
    )
    os.makedirs(model_dir, exist_ok=True)
    app_logger.info(f"downloading the models to model_dir: '{model_dir}' ...")
    model_path = model_dir / os.path.basename(model_url)
    app_logger.info(
        f"model_path exists? '{os.path.isfile(model_path)}' => '{model_path}' ..."
    )
    if not os.path.isfile(model_path):
        app_logger.info(f"downloading model_path: '{model_path}' ...")
        torch.hub.download_url_to_file(model_url, str(model_path), progress=True)
        app_logger.info(f"model_path {model_path} downloaded!")
    model = torch.jit.load(model_path, map_location=device)
    model.eval()
    return model, Decoder(model.labels)
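

# Direct-use sketch; the URL below is a placeholder (real URLs come from the
# models.yml configuration via __get_latest_stt_model):
#
#   model, decoder = init_jit_model(
#       "https://models.silero.ai/models/en/<some_model>.jit", output_folder="/tmp"
#   )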


def __silero_stt(
    language: str = "en",
    version: str = "latest",
    jit_model: str = "jit",
    output_folder: Path | str = None,
    **kwargs,
) -> tuple[nn.Module, Decoder, tuple[Callable, Callable, Callable, Callable]]:
    """
    Modified function to create instances of Silero Speech-To-Text Model(s).
    Please see https://github.com/snakers4/silero-models?tab=readme-ov-file#speech-to-text for usage examples.

    Args:
        language (str): Language of the model. Available options are ['en', 'de', 'es'].
        version (str): Version of the model to use. Default is "latest".
        jit_model (str): Type of JIT model to use. Default is "jit".
        output_folder (Path | str, optional): Output folder, needed in case of a Docker build. Default is None.
        **kwargs: Additional keyword arguments.

    Returns:
        tuple: A tuple containing the model, the decoder object, and a tuple of utility functions
            (read_batch, split_into_batches, read_audio, prepare_model_input).
    """
    from silero.utils import (
        read_audio,
        read_batch,
        split_into_batches,
        prepare_model_input,
    )
    model, decoder = __get_latest_stt_model(
        language,
        output_folder,
        version,
        model_type="stt_models",
        jit_model=jit_model,
        **kwargs,
    )
    utils = (read_batch, split_into_batches, read_audio, prepare_model_input)
    return model, decoder, utils
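

# End-to-end sketch adapted from the silero-models README (the wav file name is
# a placeholder):
#
#   model, decoder, utils = __silero_stt(language="en", output_folder="/tmp")
#   read_batch, split_into_batches, read_audio, prepare_model_input = utils
#   batches = split_into_batches(["speech_sample.wav"], batch_size=10)
#   model_input = prepare_model_input(read_batch(batches[0]))
#   for example in model(model_input):
#       print(decoder(example.cpu()))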