asr / toolbox /k2_sherpa /nn_models.py
HoneyTian's picture
update
e43afda
raw
history blame
21 kB
#!/usr/bin/python3
# -*- coding: utf-8 -*-
from enum import Enum
from functools import lru_cache
import logging
import os
import platform
from pathlib import Path
import huggingface_hub
import sherpa
import sherpa_onnx
# Module logger. It is looked up by the name "main" rather than __name__.
# NOTE(review): presumably the application configures handlers on "main" — confirm.
main_logger = logging.getLogger("main")


class EnumDecodingMethod(Enum):
    """Closed set of decoding-method names.

    Member values are the plain strings (e.g. ``"greedy_search"``) that the
    loader functions in this module accept as ``decoding_method``.
    """

    greedy_search = "greedy_search"
    modified_beam_search = "modified_beam_search"
# Registry of downloadable ASR models, grouped by a human-readable language label.
#
# Each entry is passed as **kwargs to both ``download_model`` and ``load_recognizer``.
# Keys used by those functions:
#   repo_id                       Hugging Face Hub repository id to download from.
#   loader                        name of the loader function ``load_recognizer`` dispatches to.
#   nn_model_file / _sub_folder   single model file (sherpa and paraformer loaders).
#   encoder_model_file, decoder_model_file, joiner_model_file (+ _sub_folder)
#                                 split model files (transducer loader uses all three,
#                                 whisper loader uses encoder + decoder).
#   tokens_file / _sub_folder     token table. A value starting with "./" is NOT downloaded
#                                 and is passed to the loader as-is (a relative path, so it
#                                 resolves against the current working directory).
#   normalize_samples             optional; ``load_recognizer`` forwards it to the loader, and
#                                 only ``load_sherpa_offline_recognizer`` accepts it.
model_map = {
    "Chinese": [
        {
            "repo_id": "csukuangfj/wenet-chinese-model",
            "nn_model_file": "final.zip",
            "nn_model_file_sub_folder": ".",
            "tokens_file": "units.txt",
            "tokens_file_sub_folder": ".",
            "normalize_samples": False,
            "loader": "load_sherpa_offline_recognizer",
        },
        {
            "repo_id": "csukuangfj/sherpa-onnx-paraformer-zh-2024-03-09",
            "nn_model_file": "model.int8.onnx",
            "nn_model_file_sub_folder": ".",
            "tokens_file": "tokens.txt",
            "tokens_file_sub_folder": ".",
            "loader": "load_sherpa_offline_recognizer_from_paraformer",
        },
        {
            "repo_id": "csukuangfj/sherpa-onnx-paraformer-zh-small-2024-03-09",
            "nn_model_file": "model.int8.onnx",
            "nn_model_file_sub_folder": ".",
            "tokens_file": "tokens.txt",
            "tokens_file_sub_folder": ".",
            "loader": "load_sherpa_offline_recognizer_from_paraformer",
        },
        {
            "repo_id": "luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2",
            "nn_model_file": "cpu_jit_epoch_10_avg_2_torch_1.7.1.pt",
            "nn_model_file_sub_folder": "exp",
            "tokens_file": "tokens.txt",
            "tokens_file_sub_folder": "data/lang_char",
            "normalize_samples": True,
            "loader": "load_sherpa_offline_recognizer",
        },
        {
            "repo_id": "zrjin/sherpa-onnx-zipformer-multi-zh-hans-2023-9-2",
            "encoder_model_file": "encoder-epoch-20-avg-1.onnx",
            "encoder_model_file_sub_folder": ".",
            "decoder_model_file": "decoder-epoch-20-avg-1.onnx",
            "decoder_model_file_sub_folder": ".",
            "joiner_model_file": "joiner-epoch-20-avg-1.onnx",
            "joiner_model_file_sub_folder": ".",
            "tokens_file": "tokens.txt",
            "tokens_file_sub_folder": ".",
            "loader": "load_sherpa_offline_recognizer_from_transducer",
        },
        {
            "repo_id": "zrjin/icefall-asr-aishell-zipformer-large-2023-10-24",
            "encoder_model_file": "encoder-epoch-56-avg-23.onnx",
            "encoder_model_file_sub_folder": "exp",
            "decoder_model_file": "decoder-epoch-56-avg-23.onnx",
            "decoder_model_file_sub_folder": "exp",
            "joiner_model_file": "joiner-epoch-56-avg-23.onnx",
            "joiner_model_file_sub_folder": "exp",
            "tokens_file": "tokens.txt",
            "tokens_file_sub_folder": "data/lang_char",
            "loader": "load_sherpa_offline_recognizer_from_transducer",
        },
        {
            "repo_id": "zrjin/icefall-asr-aishell-zipformer-small-2023-10-24",
            "encoder_model_file": "encoder-epoch-55-avg-21.onnx",
            "encoder_model_file_sub_folder": "exp",
            "decoder_model_file": "decoder-epoch-55-avg-21.onnx",
            "decoder_model_file_sub_folder": "exp",
            "joiner_model_file": "joiner-epoch-55-avg-21.onnx",
            "joiner_model_file_sub_folder": "exp",
            "tokens_file": "tokens.txt",
            "tokens_file_sub_folder": "data/lang_char",
            "loader": "load_sherpa_offline_recognizer_from_transducer",
        },
        {
            "repo_id": "zrjin/icefall-asr-aishell-zipformer-2023-10-24",
            "encoder_model_file": "encoder-epoch-55-avg-17.onnx",
            "encoder_model_file_sub_folder": "exp",
            "decoder_model_file": "decoder-epoch-55-avg-17.onnx",
            "decoder_model_file_sub_folder": "exp",
            "joiner_model_file": "joiner-epoch-55-avg-17.onnx",
            "joiner_model_file_sub_folder": "exp",
            "tokens_file": "tokens.txt",
            "tokens_file_sub_folder": "data/lang_char",
            "loader": "load_sherpa_offline_recognizer_from_transducer",
        },
        {
            "repo_id": "desh2608/icefall-asr-alimeeting-pruned-transducer-stateless7",
            "nn_model_file": "cpu_jit.pt",
            "nn_model_file_sub_folder": "exp",
            "tokens_file": "tokens.txt",
            "tokens_file_sub_folder": "data/lang_char",
            "normalize_samples": True,
            "loader": "load_sherpa_offline_recognizer",
        },
        {
            "repo_id": "yuekai/icefall-asr-aishell2-pruned-transducer-stateless5-A-2022-07-12",
            "nn_model_file": "cpu_jit.pt",
            "nn_model_file_sub_folder": "exp",
            "tokens_file": "tokens.txt",
            "tokens_file_sub_folder": "data/lang_char",
            "normalize_samples": True,
            "loader": "load_sherpa_offline_recognizer",
        },
        {
            "repo_id": "yuekai/icefall-asr-aishell2-pruned-transducer-stateless5-B-2022-07-12",
            "nn_model_file": "cpu_jit.pt",
            "nn_model_file_sub_folder": "exp",
            "tokens_file": "tokens.txt",
            "tokens_file_sub_folder": "data/lang_char",
            "normalize_samples": True,
            "loader": "load_sherpa_offline_recognizer",
        },
        {
            "repo_id": "luomingshuang/icefall_asr_aidatatang-200zh_pruned_transducer_stateless2",
            "nn_model_file": "cpu_jit_torch.1.7.1.pt",
            "nn_model_file_sub_folder": "exp",
            "tokens_file": "tokens.txt",
            "tokens_file_sub_folder": "data/lang_char",
            "normalize_samples": True,
            "loader": "load_sherpa_offline_recognizer",
        },
        {
            "repo_id": "luomingshuang/icefall_asr_alimeeting_pruned_transducer_stateless2",
            "nn_model_file": "cpu_jit_torch_1.7.1.pt",
            "nn_model_file_sub_folder": "exp",
            "tokens_file": "tokens.txt",
            "tokens_file_sub_folder": "data/lang_char",
            "normalize_samples": True,
            "loader": "load_sherpa_offline_recognizer",
        },
    ],
    # NOTE(review): the English "load_sherpa_offline_recognizer" entries omit
    # normalize_samples, so the loader default (False) applies, whereas the Chinese
    # icefall entries above set True — confirm this difference is intended.
    "English": [
        {
            "repo_id": "csukuangfj/sherpa-onnx-whisper-tiny.en",
            "encoder_model_file": "tiny.en-encoder.int8.onnx",
            "encoder_model_file_sub_folder": ".",
            "decoder_model_file": "tiny.en-decoder.int8.onnx",
            "decoder_model_file_sub_folder": ".",
            "tokens_file": "tiny.en-tokens.txt",
            "tokens_file_sub_folder": ".",
            "loader": "load_sherpa_offline_recognizer_from_whisper",
        },
        {
            "repo_id": "csukuangfj/sherpa-onnx-whisper-base.en",
            "encoder_model_file": "base.en-encoder.int8.onnx",
            "encoder_model_file_sub_folder": ".",
            "decoder_model_file": "base.en-decoder.int8.onnx",
            "decoder_model_file_sub_folder": ".",
            "tokens_file": "base.en-tokens.txt",
            "tokens_file_sub_folder": ".",
            "loader": "load_sherpa_offline_recognizer_from_whisper",
        },
        {
            "repo_id": "csukuangfj/sherpa-onnx-whisper-small.en",
            "encoder_model_file": "small.en-encoder.int8.onnx",
            "encoder_model_file_sub_folder": ".",
            "decoder_model_file": "small.en-decoder.int8.onnx",
            "decoder_model_file_sub_folder": ".",
            "tokens_file": "small.en-tokens.txt",
            "tokens_file_sub_folder": ".",
            "loader": "load_sherpa_offline_recognizer_from_whisper",
        },
        {
            "repo_id": "csukuangfj/sherpa-onnx-paraformer-en-2024-03-09",
            "nn_model_file": "model.int8.onnx",
            "nn_model_file_sub_folder": ".",
            "tokens_file": "tokens.txt",
            "tokens_file_sub_folder": ".",
            "loader": "load_sherpa_offline_recognizer_from_paraformer",
        },
        {
            "repo_id": "yfyeung/icefall-asr-gigaspeech-zipformer-2023-10-17",
            "encoder_model_file": "encoder-epoch-30-avg-9.onnx",
            "encoder_model_file_sub_folder": "exp",
            "decoder_model_file": "decoder-epoch-30-avg-9.onnx",
            "decoder_model_file_sub_folder": "exp",
            "joiner_model_file": "joiner-epoch-30-avg-9.onnx",
            "joiner_model_file_sub_folder": "exp",
            "tokens_file": "tokens.txt",
            "tokens_file_sub_folder": "data/lang_bpe_500",
            "loader": "load_sherpa_offline_recognizer_from_transducer",
        },
        # The "./" prefix marks a local token table: it is not downloaded from the repo
        # and must exist relative to the working directory.
        # NOTE(review): confirm giga-tokens.txt ships alongside the application.
        {
            "repo_id": "wgb14/icefall-asr-gigaspeech-pruned-transducer-stateless2",
            "nn_model_file": "cpu_jit-iter-3488000-avg-20.pt",
            "nn_model_file_sub_folder": "exp",
            "tokens_file": "./giga-tokens.txt",
            "tokens_file_sub_folder": ".",
            "loader": "load_sherpa_offline_recognizer",
        },
        {
            "repo_id": "yfyeung/icefall-asr-multidataset-pruned_transducer_stateless7-2023-05-04",
            "nn_model_file": "cpu_jit-epoch-30-avg-4.pt",
            "nn_model_file_sub_folder": "exp",
            "tokens_file": "tokens.txt",
            "tokens_file_sub_folder": "data/lang_bpe_500",
            "loader": "load_sherpa_offline_recognizer",
        },
        {
            "repo_id": "yfyeung/icefall-asr-finetune-mux-pruned_transducer_stateless7-2023-05-19",
            "nn_model_file": "cpu_jit-epoch-20-avg-5.pt",
            "nn_model_file_sub_folder": "exp",
            "tokens_file": "tokens.txt",
            "tokens_file_sub_folder": "data/lang_bpe_500",
            "loader": "load_sherpa_offline_recognizer",
        },
        {
            "repo_id": "WeijiZhuang/icefall-asr-librispeech-pruned-transducer-stateless8-2022-12-02",
            "nn_model_file": "cpu_jit-torch-1.10.pt",
            "nn_model_file_sub_folder": "exp",
            "tokens_file": "tokens.txt",
            "tokens_file_sub_folder": "data/lang_bpe_500",
            "loader": "load_sherpa_offline_recognizer",
        },
        {
            "repo_id": "csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless8-2022-11-14",
            "nn_model_file": "cpu_jit.pt",
            "nn_model_file_sub_folder": "exp",
            "tokens_file": "tokens.txt",
            "tokens_file_sub_folder": "data/lang_bpe_500",
            "loader": "load_sherpa_offline_recognizer",
        },
        {
            "repo_id": "csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11",
            "nn_model_file": "cpu_jit-torch-1.10.0.pt",
            "nn_model_file_sub_folder": "exp",
            "tokens_file": "tokens.txt",
            "tokens_file_sub_folder": "data/lang_bpe_500",
            "loader": "load_sherpa_offline_recognizer",
        },
        {
            "repo_id": "csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13",
            "nn_model_file": "cpu_jit.pt",
            "nn_model_file_sub_folder": "exp",
            "tokens_file": "tokens.txt",
            "tokens_file_sub_folder": "data/lang_bpe_500",
            "loader": "load_sherpa_offline_recognizer",
        },
    ],
    "Chinese+English": [
        {
            "repo_id": "csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28",
            "nn_model_file": "model.int8.onnx",
            "nn_model_file_sub_folder": ".",
            "tokens_file": "tokens.txt",
            "tokens_file_sub_folder": ".",
            "loader": "load_sherpa_offline_recognizer_from_paraformer",
        },
    ],
    "Chinese+Cantonese+English": [
        {
            "repo_id": "csukuangfj/sherpa-onnx-paraformer-trilingual-zh-cantonese-en",
            "nn_model_file": "model.int8.onnx",
            "nn_model_file_sub_folder": ".",
            "tokens_file": "tokens.txt",
            "tokens_file_sub_folder": ".",
            "loader": "load_sherpa_offline_recognizer_from_paraformer",
        },
    ]
}
def download_model(local_model_dir: str,
                   **kwargs,
                   ):
    """Download the files described by one ``model_map`` entry from the Hugging Face Hub.

    Files are fetched with ``huggingface_hub.hf_hub_download`` into ``local_model_dir``,
    keeping each file's sub-folder inside that directory.

    :param local_model_dir: target directory (passed as ``local_dir``).
    :param kwargs: one ``model_map`` entry. ``repo_id`` is required. For each of
        ``nn_model_file``, ``encoder_model_file``, ``decoder_model_file``,
        ``joiner_model_file`` and ``tokens_file`` that is present, the matching
        ``<key>_sub_folder`` key must be present too. Files are fetched in that order.
    :raises KeyError: if ``repo_id`` or a required ``*_sub_folder`` key is missing.

    A ``tokens_file`` starting with ``"./"`` denotes a local file and is not downloaded.
    """
    repo_id = kwargs["repo_id"]

    file_keys = (
        "nn_model_file",
        "encoder_model_file",
        "decoder_model_file",
        "joiner_model_file",
        "tokens_file",
    )
    for key in file_keys:
        if key not in kwargs:
            continue
        filename = kwargs[key]
        subfolder = kwargs["{}_sub_folder".format(key)]

        if key == "tokens_file" and filename.startswith("./"):
            # Local token table that is not part of the repo: nothing to fetch.
            main_logger.info("skip download tokens_file (local file). filename: %s", filename)
            continue

        main_logger.info("download %s. filename: %s, subfolder: %s", key, filename, subfolder)
        huggingface_hub.hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            subfolder=subfolder,
            local_dir=local_model_dir,
        )
def load_sherpa_offline_recognizer(nn_model_file: str,
                                   tokens_file: str,
                                   sample_rate: int = 16000,
                                   num_active_paths: int = 2,
                                   decoding_method: str = "greedy_search",
                                   num_mel_bins: int = 80,
                                   frame_dither: float = 0,
                                   normalize_samples: bool = False,
                                   use_gpu: bool = False,
                                   ):
    """Build a k2-fsa ``sherpa.OfflineRecognizer`` from a single model file.

    ``model_map`` uses this loader for the ``*.pt`` / ``final.zip`` model files.

    :param nn_model_file: path to the model file; must exist.
    :param tokens_file: path to the token table, passed through as ``tokens``.
    :param sample_rate: written to ``fbank_opts.frame_opts.samp_freq``.
    :param num_active_paths: passed through as ``num_active_paths``
        (presumably only relevant for ``modified_beam_search`` — confirm in sherpa docs).
    :param decoding_method: passed through as ``decoding_method``.
    :param num_mel_bins: written to ``fbank_opts.mel_opts.num_bins``.
    :param frame_dither: written to ``fbank_opts.frame_opts.dither``; may be fractional.
    :param normalize_samples: passed to ``sherpa.FeatureConfig``.
    :param use_gpu: passed through as ``use_gpu``; defaults to CPU-only as before.
    :return: the constructed ``sherpa.OfflineRecognizer``.
    :raises AssertionError: if ``nn_model_file`` does not exist.
    """
    # Validate before constructing any sherpa objects so a bad path fails fast.
    if not os.path.exists(nn_model_file):
        raise AssertionError("nn_model_file not found. nn_model_file: {}".format(nn_model_file))

    feat_config = sherpa.FeatureConfig(normalize_samples=normalize_samples)
    feat_config.fbank_opts.frame_opts.samp_freq = sample_rate
    feat_config.fbank_opts.mel_opts.num_bins = num_mel_bins
    feat_config.fbank_opts.frame_opts.dither = frame_dither

    config = sherpa.OfflineRecognizerConfig(
        nn_model=nn_model_file,
        tokens=tokens_file,
        use_gpu=use_gpu,
        feat_config=feat_config,
        decoding_method=decoding_method,
        num_active_paths=num_active_paths,
    )
    return sherpa.OfflineRecognizer(config)
def load_sherpa_offline_recognizer_from_paraformer(nn_model_file: str,
                                                   tokens_file: str,
                                                   sample_rate: int = 16000,
                                                   decoding_method: str = "greedy_search",
                                                   feature_dim: int = 80,
                                                   num_threads: int = 2,
                                                   ):
    """Build a ``sherpa_onnx.OfflineRecognizer`` for a Paraformer model.

    :param nn_model_file: path to the Paraformer model, passed as ``paraformer``.
    :param tokens_file: path to the token table, passed as ``tokens``.
    :param sample_rate: passed through as ``sample_rate``.
    :param decoding_method: passed through as ``decoding_method``.
    :param feature_dim: passed through as ``feature_dim``.
    :param num_threads: passed through as ``num_threads``.
    :return: the recognizer from ``OfflineRecognizer.from_paraformer`` (debug disabled).
    """
    options = dict(
        paraformer=nn_model_file,
        tokens=tokens_file,
        num_threads=num_threads,
        sample_rate=sample_rate,
        feature_dim=feature_dim,
        decoding_method=decoding_method,
        debug=False,
    )
    return sherpa_onnx.OfflineRecognizer.from_paraformer(**options)
def load_sherpa_offline_recognizer_from_transducer(encoder_model_file: str,
                                                   decoder_model_file: str,
                                                   joiner_model_file: str,
                                                   tokens_file: str,
                                                   sample_rate: int = 16000,
                                                   decoding_method: str = "greedy_search",
                                                   feature_dim: int = 80,
                                                   num_threads: int = 2,
                                                   num_active_paths: int = 2,
                                                   ):
    """Build a ``sherpa_onnx.OfflineRecognizer`` for a split encoder/decoder/joiner transducer.

    :param encoder_model_file: path passed as ``encoder``.
    :param decoder_model_file: path passed as ``decoder``.
    :param joiner_model_file: path passed as ``joiner``.
    :param tokens_file: path to the token table, passed as ``tokens``.
    :param sample_rate: passed through as ``sample_rate``.
    :param decoding_method: passed through as ``decoding_method``.
    :param feature_dim: passed through as ``feature_dim``.
    :param num_threads: passed through as ``num_threads``.
    :param num_active_paths: passed to sherpa-onnx under the name ``max_active_paths``.
    :return: the recognizer from ``OfflineRecognizer.from_transducer``.
    """
    model_files = dict(
        encoder=encoder_model_file,
        decoder=decoder_model_file,
        joiner=joiner_model_file,
        tokens=tokens_file,
    )
    runtime_options = dict(
        num_threads=num_threads,
        sample_rate=sample_rate,
        feature_dim=feature_dim,
        decoding_method=decoding_method,
        max_active_paths=num_active_paths,
    )
    return sherpa_onnx.OfflineRecognizer.from_transducer(**model_files, **runtime_options)
def load_sherpa_offline_recognizer_from_whisper(encoder_model_file: str,
                                                decoder_model_file: str,
                                                tokens_file: str,
                                                num_threads: int = 2,
                                                ):
    """Build a ``sherpa_onnx.OfflineRecognizer`` for a Whisper encoder/decoder pair.

    Only the model paths and thread count are set here; every other option is
    left at the sherpa-onnx default.

    :param encoder_model_file: path passed as ``encoder``.
    :param decoder_model_file: path passed as ``decoder``.
    :param tokens_file: path to the token table, passed as ``tokens``.
    :param num_threads: passed through as ``num_threads``.
    :return: the recognizer from ``OfflineRecognizer.from_whisper``.
    """
    options = {
        "encoder": encoder_model_file,
        "decoder": decoder_model_file,
        "tokens": tokens_file,
        "num_threads": num_threads,
    }
    return sherpa_onnx.OfflineRecognizer.from_whisper(**options)
def _resolve_model_files(local_model_dir: Path, entry: dict):
    """Map a ``model_map`` entry's file keys to loader kwargs; also list the files that get downloaded.

    :return: ``(loader_kwargs, downloaded_paths)``. ``loader_kwargs`` holds the file-path
        keyword arguments (plus ``normalize_samples`` when present). ``downloaded_paths``
        lists the resolved paths that ``download_model`` is expected to fetch. A
        ``"./"``-prefixed tokens file is local, so it is excluded from that list.
    """
    loader_kwargs = dict()
    downloaded_paths = list()

    for key in ("nn_model_file", "encoder_model_file", "decoder_model_file", "joiner_model_file"):
        if key in entry:
            path = (local_model_dir / entry["{}_sub_folder".format(key)] / entry[key]).as_posix()
            loader_kwargs[key] = path
            downloaded_paths.append(path)

    if "tokens_file" in entry:
        tokens_file: str = entry["tokens_file"]
        if not tokens_file.startswith("./"):
            tokens_file = (local_model_dir / entry["tokens_file_sub_folder"] / tokens_file).as_posix()
            downloaded_paths.append(tokens_file)
        loader_kwargs["tokens_file"] = tokens_file

    if "normalize_samples" in entry:
        loader_kwargs["normalize_samples"] = entry["normalize_samples"]

    return loader_kwargs, downloaded_paths


def load_recognizer(local_model_dir: Path,
                    decoding_method: str = "greedy_search",
                    num_active_paths: int = 4,
                    **kwargs
                    ):
    """Download a ``model_map`` entry if needed, then build its recognizer.

    :param local_model_dir: directory holding the model files (``str`` is accepted too).
    :param decoding_method: forwarded to every loader except the whisper one.
    :param num_active_paths: forwarded to the sherpa and transducer loaders.
    :param kwargs: one ``model_map`` entry (``repo_id``, ``loader``, file names, ...).
    :return: the recognizer built by the loader named in ``kwargs["loader"]``.
    :raises NotImplementedError: if ``kwargs["loader"]`` is not a supported loader.
    :raises KeyError: if ``loader`` or a required file/sub-folder key is missing.
    """
    local_model_dir = Path(local_model_dir)
    loader = kwargs["loader"]

    # loader name -> (function, extra keyword arguments it accepts from this call)
    dispatch = {
        "load_sherpa_offline_recognizer": (
            load_sherpa_offline_recognizer,
            {"decoding_method": decoding_method, "num_active_paths": num_active_paths},
        ),
        "load_sherpa_offline_recognizer_from_paraformer": (
            load_sherpa_offline_recognizer_from_paraformer,
            {"decoding_method": decoding_method},
        ),
        "load_sherpa_offline_recognizer_from_transducer": (
            load_sherpa_offline_recognizer_from_transducer,
            # Forward num_active_paths so modified_beam_search honours the caller's
            # beam width instead of silently falling back to the loader default.
            {"decoding_method": decoding_method, "num_active_paths": num_active_paths},
        ),
        "load_sherpa_offline_recognizer_from_whisper": (
            load_sherpa_offline_recognizer_from_whisper,
            {},
        ),
    }
    # Reject unsupported loaders before downloading anything.
    if loader not in dispatch:
        raise NotImplementedError("loader not support: {}".format(loader))
    loader_fn, extra_kwargs = dispatch[loader]

    loader_kwargs, downloaded_paths = _resolve_model_files(local_model_dir, kwargs)

    # Also (re)download when the directory exists but an expected file is missing,
    # e.g. after an interrupted earlier download.
    if not local_model_dir.exists() or any(not Path(p).exists() for p in downloaded_paths):
        download_model(
            local_model_dir=local_model_dir.as_posix(),
            **kwargs,
        )

    return loader_fn(**extra_kwargs, **loader_kwargs)
if __name__ == "__main__":
    # This module only defines the model registry and loader helpers;
    # running it directly does nothing.
    pass