Spaces:
Running
Running
File size: 4,649 Bytes
2267fac d03c698 2267fac 168b5c0 2267fac f0ff987 d03c698 2267fac 168b5c0 2267fac 168b5c0 2267fac 168b5c0 2267fac 168b5c0 2267fac 168b5c0 2267fac 168b5c0 2267fac 8dc832e 4281a4a 2267fac 3194abe 2267fac 8dc832e 2267fac 26dfa9a 2267fac d03c698 168b5c0 2267fac d03c698 3194abe 168b5c0 2267fac 168b5c0 2267fac d03c698 168b5c0 d03c698 168b5c0 d03c698 168b5c0 2267fac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
#!/usr/bin/python3
# -*- coding: utf-8 -*-
from enum import Enum
from functools import lru_cache
import os
import huggingface_hub
import sherpa
import sherpa_onnx
class EnumDecodingMethod(Enum):
    """Names of the decoding methods referenced by this module.

    Each member's value is the same string as its name, which is also the
    string form the loader functions in this file take as ``decoding_method``.
    """

    # Default used by the loader functions in this file.
    greedy_search = "greedy_search"
    # Beam-search variant; see ``num_active_paths`` on the sherpa loader.
    modified_beam_search = "modified_beam_search"
# Registry of downloadable models, keyed by language name.
# Each entry carries the keyword arguments that `load_recognizer` expects
# apart from `local_model_dir`:
#   repo_id        - Hugging Face Hub repository holding the files
#   nn_model_file  - model file name inside the repository
#   tokens_file    - token table file name inside the repository
#   sub_folder     - folder inside the repository ("." is used for both entries)
#   loader         - name of the loader function to dispatch to
model_map = {
    "Chinese": [
        dict(
            repo_id="csukuangfj/wenet-chinese-model",
            nn_model_file="final.zip",
            tokens_file="units.txt",
            sub_folder=".",
            loader="load_sherpa_offline_recognizer",
        ),
        dict(
            repo_id="csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28",
            nn_model_file="model.int8.onnx",
            tokens_file="tokens.txt",
            sub_folder=".",
            loader="load_sherpa_offline_recognizer_from_paraformer",
        ),
    ],
}
def download_model(repo_id: str,
                   nn_model_file: str,
                   tokens_file: str,
                   sub_folder: str,
                   local_model_dir: str,
                   ):
    """Fetch a model file and its tokens file from a Hugging Face Hub repo.

    Both files are requested through ``huggingface_hub.hf_hub_download`` with
    the same repository, sub-folder and ``local_dir``.

    Args:
        repo_id: Hub repository to download from.
        nn_model_file: File name of the model inside the repository.
        tokens_file: File name of the token table inside the repository.
        sub_folder: Passed as ``subfolder`` to ``hf_hub_download``.
        local_model_dir: Passed as ``local_dir`` to ``hf_hub_download``.

    Returns:
        A ``(nn_model, tokens)`` tuple holding the two values returned by
        ``hf_hub_download`` (presumably local file paths - confirm against
        the huggingface_hub documentation).
    """
    # Arguments common to both downloads.
    shared_kwargs = dict(
        repo_id=repo_id,
        subfolder=sub_folder,
        local_dir=local_model_dir,
    )

    # The model is fetched first, then the token table.
    local_nn_model = huggingface_hub.hf_hub_download(
        filename=nn_model_file, **shared_kwargs
    )
    local_tokens = huggingface_hub.hf_hub_download(
        filename=tokens_file, **shared_kwargs
    )
    return local_nn_model, local_tokens
def load_sherpa_offline_recognizer(nn_model_file: str,
                                   tokens_file: str,
                                   sample_rate: int = 16000,
                                   num_active_paths: int = 2,
                                   decoding_method: str = "greedy_search",
                                   num_mel_bins: int = 80,
                                   frame_dither: int = 0,
                                   ):
    """Create a ``sherpa.OfflineRecognizer`` for the given model and tokens.

    The recognizer is configured with ``use_gpu=False`` and with feature
    extraction that does not normalize samples.

    Args:
        nn_model_file: Passed as ``nn_model`` to the recognizer config.
        tokens_file: Passed as ``tokens`` to the recognizer config.
        sample_rate: Written to the fbank frame option ``samp_freq``.
        num_active_paths: Forwarded to the recognizer config.
        decoding_method: Forwarded to the recognizer config.
        num_mel_bins: Written to the fbank mel option ``num_bins``.
        frame_dither: Written to the fbank frame option ``dither``.

    Returns:
        The constructed ``sherpa.OfflineRecognizer``.
    """
    features = sherpa.FeatureConfig(normalize_samples=False)
    # Frame-level options.
    features.fbank_opts.frame_opts.samp_freq = sample_rate
    features.fbank_opts.frame_opts.dither = frame_dither
    # Mel filter-bank options.
    features.fbank_opts.mel_opts.num_bins = num_mel_bins

    recognizer_config = sherpa.OfflineRecognizerConfig(
        nn_model=nn_model_file,
        tokens=tokens_file,
        use_gpu=False,
        feat_config=features,
        decoding_method=decoding_method,
        num_active_paths=num_active_paths,
    )
    return sherpa.OfflineRecognizer(recognizer_config)
def load_sherpa_offline_recognizer_from_paraformer(nn_model_file: str,
                                                   tokens_file: str,
                                                   sample_rate: int = 16000,
                                                   decoding_method: str = "greedy_search",
                                                   feature_dim: int = 80,
                                                   num_threads: int = 2,
                                                   ):
    """Create a recognizer via ``sherpa_onnx.OfflineRecognizer.from_paraformer``.

    Args:
        nn_model_file: Passed as ``paraformer`` (the model file).
        tokens_file: Passed as ``tokens``.
        sample_rate: Forwarded unchanged.
        decoding_method: Forwarded unchanged.
        feature_dim: Forwarded unchanged.
        num_threads: Forwarded unchanged.

    Returns:
        Whatever ``from_paraformer`` returns; ``debug`` is always ``False``.
    """
    paraformer_options = dict(
        paraformer=nn_model_file,
        tokens=tokens_file,
        num_threads=num_threads,
        sample_rate=sample_rate,
        feature_dim=feature_dim,
        decoding_method=decoding_method,
        debug=False,
    )
    return sherpa_onnx.OfflineRecognizer.from_paraformer(**paraformer_options)
def load_recognizer(repo_id: str,
                    nn_model_file: str,
                    tokens_file: str,
                    sub_folder: str,
                    local_model_dir: str,
                    loader: str,
                    decoding_method: str = "greedy_search",
                    num_active_paths: int = 4,
                    ):
    """Ensure the model files are available locally and build a recognizer.

    Args:
        repo_id: Hugging Face Hub repository that holds the files.
        nn_model_file: Model file name inside the repository.
        tokens_file: Token table file name inside the repository.
        sub_folder: Folder inside the repository that contains both files.
        local_model_dir: Local directory the files are downloaded into.
        loader: Name of the loader to dispatch to; one of
            ``"load_sherpa_offline_recognizer"`` or
            ``"load_sherpa_offline_recognizer_from_paraformer"``.
        decoding_method: Forwarded to the selected loader.
        num_active_paths: Forwarded to ``load_sherpa_offline_recognizer``
            only; the paraformer loader does not take it.

    Returns:
        The recognizer object produced by the selected loader.

    Raises:
        NotImplementedError: If ``loader`` is not one of the supported names.
    """
    supported_loaders = (
        "load_sherpa_offline_recognizer",
        "load_sherpa_offline_recognizer_from_paraformer",
    )
    # Fail fast: reject an unknown loader before any download is attempted.
    if loader not in supported_loaders:
        raise NotImplementedError("loader not support: {}".format(loader))

    # Expected on-disk location of each file once downloaded with
    # ``local_dir=local_model_dir`` (repo layout replicated under local_dir).
    # An absolute file name is kept as-is by os.path.join.
    folder = sub_folder or ""
    local_nn_model_file = os.path.normpath(
        os.path.join(local_model_dir, folder, nn_model_file)
    )
    local_tokens_file = os.path.normpath(
        os.path.join(local_model_dir, folder, tokens_file)
    )

    # Check the files themselves rather than only the directory: the directory
    # may exist while one of the files is still missing.
    if not (os.path.exists(local_nn_model_file)
            and os.path.exists(local_tokens_file)):
        # Use the paths reported by the download; the loaders need local
        # paths, not the repo-relative file names.
        local_nn_model_file, local_tokens_file = download_model(
            repo_id=repo_id,
            nn_model_file=nn_model_file,
            tokens_file=tokens_file,
            sub_folder=sub_folder,
            local_model_dir=local_model_dir,
        )

    if loader == "load_sherpa_offline_recognizer":
        recognizer = load_sherpa_offline_recognizer(
            nn_model_file=local_nn_model_file,
            tokens_file=local_tokens_file,
            decoding_method=decoding_method,
            num_active_paths=num_active_paths,
        )
    else:
        # Only the paraformer loader remains after the validation above.
        recognizer = load_sherpa_offline_recognizer_from_paraformer(
            nn_model_file=local_nn_model_file,
            tokens_file=local_tokens_file,
            decoding_method=decoding_method,
        )
    return recognizer
if __name__ == "__main__":
    # No script behaviour is defined here; running the file directly is a no-op.
    pass
|