asr / toolbox /k2_sherpa /nn_models.py
HoneyTian's picture
update
d03c698
raw
history blame
4.65 kB
#!/usr/bin/python3
# -*- coding: utf-8 -*-
from enum import Enum
from functools import lru_cache
import os
import huggingface_hub
import sherpa
import sherpa_onnx
class EnumDecodingMethod(Enum):
    """Decoding strategies accepted by the recognizer loaders below."""
    greedy_search = "greedy_search"
    modified_beam_search = "modified_beam_search"
# Registry of downloadable ASR models, keyed by display language.
# Each entry names the Hugging Face repo, the model/tokens filenames inside it,
# the repo subfolder, and which loader function in this module builds the
# recognizer ("loader" is dispatched on by name in load_recognizer).
model_map = {
    "Chinese": [
        {
            "repo_id": "csukuangfj/wenet-chinese-model",
            "nn_model_file": "final.zip",
            "tokens_file": "units.txt",
            "sub_folder": ".",
            "loader": "load_sherpa_offline_recognizer",
        },
        {
            "repo_id": "csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28",
            "nn_model_file": "model.int8.onnx",
            "tokens_file": "tokens.txt",
            "sub_folder": ".",
            "loader": "load_sherpa_offline_recognizer_from_paraformer",
        }
    ]
}
def download_model(repo_id: str,
                   nn_model_file: str,
                   tokens_file: str,
                   sub_folder: str,
                   local_model_dir: str,
                   ):
    """Fetch a model file and its tokens file from the Hugging Face Hub.

    Both files are downloaded from ``repo_id`` (under ``sub_folder``) into
    ``local_model_dir``. Returns a tuple ``(nn_model_path, tokens_path)`` of
    the local file paths reported by ``hf_hub_download``.
    """
    # Same download call for both files — only the filename differs.
    nn_model_path, tokens_path = (
        huggingface_hub.hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            subfolder=sub_folder,
            local_dir=local_model_dir,
        )
        for filename in (nn_model_file, tokens_file)
    )
    return nn_model_path, tokens_path
def load_sherpa_offline_recognizer(nn_model_file: str,
                                   tokens_file: str,
                                   sample_rate: int = 16000,
                                   num_active_paths: int = 2,
                                   decoding_method: str = "greedy_search",
                                   num_mel_bins: int = 80,
                                   frame_dither: int = 0,
                                   ):
    """Build a CPU-only sherpa offline recognizer.

    Configures fbank feature extraction (samples are not normalized) and
    wraps the model/tokens files in a ``sherpa.OfflineRecognizer``.
    """
    # Feature extraction settings: raw (unnormalized) samples, fbank features.
    features = sherpa.FeatureConfig(normalize_samples=False)
    features.fbank_opts.frame_opts.samp_freq = sample_rate
    features.fbank_opts.mel_opts.num_bins = num_mel_bins
    features.fbank_opts.frame_opts.dither = frame_dither

    recognizer_config = sherpa.OfflineRecognizerConfig(
        nn_model=nn_model_file,
        tokens=tokens_file,
        use_gpu=False,
        feat_config=features,
        decoding_method=decoding_method,
        num_active_paths=num_active_paths,
    )
    return sherpa.OfflineRecognizer(recognizer_config)
def load_sherpa_offline_recognizer_from_paraformer(nn_model_file: str,
                                                   tokens_file: str,
                                                   sample_rate: int = 16000,
                                                   decoding_method: str = "greedy_search",
                                                   feature_dim: int = 80,
                                                   num_threads: int = 2,
                                                   ):
    """Build a sherpa-onnx offline recognizer from a paraformer ONNX model."""
    # Collect the factory options up front, then hand them to sherpa-onnx.
    options = dict(
        paraformer=nn_model_file,
        tokens=tokens_file,
        num_threads=num_threads,
        sample_rate=sample_rate,
        feature_dim=feature_dim,
        decoding_method=decoding_method,
        debug=False,
    )
    return sherpa_onnx.OfflineRecognizer.from_paraformer(**options)
def load_recognizer(repo_id: str,
                    nn_model_file: str,
                    tokens_file: str,
                    sub_folder: str,
                    local_model_dir: str,
                    loader: str,
                    decoding_method: str = "greedy_search",
                    num_active_paths: int = 4,
                    ):
    """Download a model (if not present locally) and build its recognizer.

    :param repo_id: Hugging Face repo to fetch from.
    :param nn_model_file: model filename inside the repo.
    :param tokens_file: tokens filename inside the repo.
    :param sub_folder: repo subfolder holding both files.
    :param local_model_dir: local directory the files live in / download to.
    :param loader: name of the loader function in this module to dispatch to.
    :param decoding_method: decoding strategy passed to the loader.
    :param num_active_paths: beam width (sherpa loader only).
    :raises NotImplementedError: if ``loader`` names an unknown loader.
    """
    if not os.path.exists(local_model_dir):
        # BUG FIX: the original discarded download_model's return value and
        # passed the bare repo filenames to the loaders. Use the actual
        # local paths of the downloaded files instead.
        nn_model_file, tokens_file = download_model(
            repo_id=repo_id,
            nn_model_file=nn_model_file,
            tokens_file=tokens_file,
            sub_folder=sub_folder,
            local_model_dir=local_model_dir,
        )
    else:
        # Files were downloaded on an earlier run: rebuild the local paths.
        # NOTE(review): assumes hf_hub_download placed the files at
        # local_model_dir/sub_folder/<filename> — confirm against the
        # installed huggingface_hub version's local_dir layout.
        nn_model_file = os.path.normpath(
            os.path.join(local_model_dir, sub_folder, nn_model_file))
        tokens_file = os.path.normpath(
            os.path.join(local_model_dir, sub_folder, tokens_file))
    if loader == "load_sherpa_offline_recognizer":
        recognizer = load_sherpa_offline_recognizer(
            nn_model_file=nn_model_file,
            tokens_file=tokens_file,
            decoding_method=decoding_method,
            num_active_paths=num_active_paths,
        )
    elif loader == "load_sherpa_offline_recognizer_from_paraformer":
        recognizer = load_sherpa_offline_recognizer_from_paraformer(
            nn_model_file=nn_model_file,
            tokens_file=tokens_file,
            decoding_method=decoding_method,
        )
    else:
        raise NotImplementedError("loader not support: {}".format(loader))
    return recognizer
if __name__ == "__main__":
    # Library module: no CLI behavior; guard kept for direct execution.
    pass