Spaces:
Sleeping
Sleeping
# This module is from [WeNet](https://github.com/wenet-e2e/wenet). | |
# ## Citations | |
# ```bibtex | |
# @inproceedings{yao2021wenet, | |
# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit}, | |
# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin}, | |
# booktitle={Proc. Interspeech}, | |
# year={2021}, | |
# address={Brno, Czech Republic }, | |
# organization={IEEE} | |
# } | |
# @article{zhang2022wenet, | |
# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit}, | |
# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei}, | |
# journal={arXiv preprint arXiv:2203.15455}, | |
# year={2022} | |
# } | |
# | |
from __future__ import print_function | |
import argparse | |
import copy | |
import logging | |
import os | |
import sys | |
import torch | |
import yaml | |
from torch.utils.data import DataLoader | |
from wenet.dataset.dataset import Dataset | |
from wenet.utils.common import IGNORE_ID | |
from wenet.utils.file_utils import read_symbol_table | |
from wenet.utils.config import override_config | |
import onnxruntime as rt | |
import multiprocessing | |
import numpy as np | |
try: | |
from swig_decoders import ( | |
map_batch, | |
ctc_beam_search_decoder_batch, | |
TrieVector, | |
PathTrie, | |
) | |
except ImportError: | |
print( | |
"Please install ctc decoders first by refering to\n" | |
+ "https://github.com/Slyne/ctc_decoder.git" | |
) | |
sys.exit(1) | |
def get_args(): | |
parser = argparse.ArgumentParser(description="recognize with your model") | |
parser.add_argument("--config", required=True, help="config file") | |
parser.add_argument("--test_data", required=True, help="test data file") | |
parser.add_argument( | |
"--data_type", | |
default="raw", | |
choices=["raw", "shard"], | |
help="train and cv data type", | |
) | |
parser.add_argument( | |
"--gpu", type=int, default=-1, help="gpu id for this rank, -1 for cpu" | |
) | |
parser.add_argument("--dict", required=True, help="dict file") | |
parser.add_argument("--encoder_onnx", required=True, help="encoder onnx file") | |
parser.add_argument("--decoder_onnx", required=True, help="decoder onnx file") | |
parser.add_argument("--result_file", required=True, help="asr result file") | |
parser.add_argument("--batch_size", type=int, default=32, help="asr result file") | |
parser.add_argument( | |
"--mode", | |
choices=["ctc_greedy_search", "ctc_prefix_beam_search", "attention_rescoring"], | |
default="attention_rescoring", | |
help="decoding mode", | |
) | |
parser.add_argument( | |
"--bpe_model", default=None, type=str, help="bpe model for english part" | |
) | |
parser.add_argument( | |
"--override_config", action="append", default=[], help="override yaml config" | |
) | |
parser.add_argument( | |
"--fp16", | |
action="store_true", | |
help="whether to export fp16 model, default false", | |
) | |
args = parser.parse_args() | |
print(args) | |
return args | |
def main(): | |
args = get_args() | |
logging.basicConfig( | |
level=logging.DEBUG, format="%(asctime)s %(levelname)s %(message)s" | |
) | |
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) | |
with open(args.config, "r") as fin: | |
configs = yaml.load(fin, Loader=yaml.FullLoader) | |
if len(args.override_config) > 0: | |
configs = override_config(configs, args.override_config) | |
reverse_weight = configs["model_conf"].get("reverse_weight", 0.0) | |
symbol_table = read_symbol_table(args.dict) | |
test_conf = copy.deepcopy(configs["dataset_conf"]) | |
test_conf["filter_conf"]["max_length"] = 102400 | |
test_conf["filter_conf"]["min_length"] = 0 | |
test_conf["filter_conf"]["token_max_length"] = 102400 | |
test_conf["filter_conf"]["token_min_length"] = 0 | |
test_conf["filter_conf"]["max_output_input_ratio"] = 102400 | |
test_conf["filter_conf"]["min_output_input_ratio"] = 0 | |
test_conf["speed_perturb"] = False | |
test_conf["spec_aug"] = False | |
test_conf["spec_sub"] = False | |
test_conf["spec_trim"] = False | |
test_conf["shuffle"] = False | |
test_conf["sort"] = False | |
test_conf["fbank_conf"]["dither"] = 0.0 | |
test_conf["batch_conf"]["batch_type"] = "static" | |
test_conf["batch_conf"]["batch_size"] = args.batch_size | |
test_dataset = Dataset( | |
args.data_type, | |
args.test_data, | |
symbol_table, | |
test_conf, | |
args.bpe_model, | |
partition=False, | |
) | |
test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0) | |
# Init asr model from configs | |
use_cuda = args.gpu >= 0 and torch.cuda.is_available() | |
if use_cuda: | |
EP_list = ["CUDAExecutionProvider", "CPUExecutionProvider"] | |
else: | |
EP_list = ["CPUExecutionProvider"] | |
encoder_ort_session = rt.InferenceSession(args.encoder_onnx, providers=EP_list) | |
decoder_ort_session = None | |
if args.mode == "attention_rescoring": | |
decoder_ort_session = rt.InferenceSession(args.decoder_onnx, providers=EP_list) | |
# Load dict | |
vocabulary = [] | |
char_dict = {} | |
with open(args.dict, "r") as fin: | |
for line in fin: | |
arr = line.strip().split() | |
assert len(arr) == 2 | |
char_dict[int(arr[1])] = arr[0] | |
vocabulary.append(arr[0]) | |
eos = sos = len(char_dict) - 1 | |
with torch.no_grad(), open(args.result_file, "w") as fout: | |
for _, batch in enumerate(test_data_loader): | |
keys, feats, _, feats_lengths, _ = batch | |
feats, feats_lengths = feats.numpy(), feats_lengths.numpy() | |
if args.fp16: | |
feats = feats.astype(np.float16) | |
ort_inputs = { | |
encoder_ort_session.get_inputs()[0].name: feats, | |
encoder_ort_session.get_inputs()[1].name: feats_lengths, | |
} | |
ort_outs = encoder_ort_session.run(None, ort_inputs) | |
( | |
encoder_out, | |
encoder_out_lens, | |
ctc_log_probs, | |
beam_log_probs, | |
beam_log_probs_idx, | |
) = ort_outs | |
beam_size = beam_log_probs.shape[-1] | |
batch_size = beam_log_probs.shape[0] | |
num_processes = min(multiprocessing.cpu_count(), batch_size) | |
if args.mode == "ctc_greedy_search": | |
if beam_size != 1: | |
log_probs_idx = beam_log_probs_idx[:, :, 0] | |
batch_sents = [] | |
for idx, seq in enumerate(log_probs_idx): | |
batch_sents.append(seq[0 : encoder_out_lens[idx]].tolist()) | |
hyps = map_batch(batch_sents, vocabulary, num_processes, True, 0) | |
elif args.mode in ("ctc_prefix_beam_search", "attention_rescoring"): | |
batch_log_probs_seq_list = beam_log_probs.tolist() | |
batch_log_probs_idx_list = beam_log_probs_idx.tolist() | |
batch_len_list = encoder_out_lens.tolist() | |
batch_log_probs_seq = [] | |
batch_log_probs_ids = [] | |
batch_start = [] # only effective in streaming deployment | |
batch_root = TrieVector() | |
root_dict = {} | |
for i in range(len(batch_len_list)): | |
num_sent = batch_len_list[i] | |
batch_log_probs_seq.append(batch_log_probs_seq_list[i][0:num_sent]) | |
batch_log_probs_ids.append(batch_log_probs_idx_list[i][0:num_sent]) | |
root_dict[i] = PathTrie() | |
batch_root.append(root_dict[i]) | |
batch_start.append(True) | |
score_hyps = ctc_beam_search_decoder_batch( | |
batch_log_probs_seq, | |
batch_log_probs_ids, | |
batch_root, | |
batch_start, | |
beam_size, | |
num_processes, | |
0, | |
-2, | |
0.99999, | |
) | |
if args.mode == "ctc_prefix_beam_search": | |
hyps = [] | |
for cand_hyps in score_hyps: | |
hyps.append(cand_hyps[0][1]) | |
hyps = map_batch(hyps, vocabulary, num_processes, False, 0) | |
if args.mode == "attention_rescoring": | |
ctc_score, all_hyps = [], [] | |
max_len = 0 | |
for hyps in score_hyps: | |
cur_len = len(hyps) | |
if len(hyps) < beam_size: | |
hyps += (beam_size - cur_len) * [(-float("INF"), (0,))] | |
cur_ctc_score = [] | |
for hyp in hyps: | |
cur_ctc_score.append(hyp[0]) | |
all_hyps.append(list(hyp[1])) | |
if len(hyp[1]) > max_len: | |
max_len = len(hyp[1]) | |
ctc_score.append(cur_ctc_score) | |
if args.fp16: | |
ctc_score = np.array(ctc_score, dtype=np.float16) | |
else: | |
ctc_score = np.array(ctc_score, dtype=np.float32) | |
hyps_pad_sos_eos = ( | |
np.ones((batch_size, beam_size, max_len + 2), dtype=np.int64) | |
* IGNORE_ID | |
) | |
r_hyps_pad_sos_eos = ( | |
np.ones((batch_size, beam_size, max_len + 2), dtype=np.int64) | |
* IGNORE_ID | |
) | |
hyps_lens_sos = np.ones((batch_size, beam_size), dtype=np.int32) | |
k = 0 | |
for i in range(batch_size): | |
for j in range(beam_size): | |
cand = all_hyps[k] | |
l = len(cand) + 2 | |
hyps_pad_sos_eos[i][j][0:l] = [sos] + cand + [eos] | |
r_hyps_pad_sos_eos[i][j][0:l] = [sos] + cand[::-1] + [eos] | |
hyps_lens_sos[i][j] = len(cand) + 1 | |
k += 1 | |
decoder_ort_inputs = { | |
decoder_ort_session.get_inputs()[0].name: encoder_out, | |
decoder_ort_session.get_inputs()[1].name: encoder_out_lens, | |
decoder_ort_session.get_inputs()[2].name: hyps_pad_sos_eos, | |
decoder_ort_session.get_inputs()[3].name: hyps_lens_sos, | |
decoder_ort_session.get_inputs()[-1].name: ctc_score, | |
} | |
if reverse_weight > 0: | |
r_hyps_pad_sos_eos_name = decoder_ort_session.get_inputs()[4].name | |
decoder_ort_inputs[r_hyps_pad_sos_eos_name] = r_hyps_pad_sos_eos | |
best_index = decoder_ort_session.run(None, decoder_ort_inputs)[0] | |
best_sents = [] | |
k = 0 | |
for idx in best_index: | |
cur_best_sent = all_hyps[k : k + beam_size][idx] | |
best_sents.append(cur_best_sent) | |
k += beam_size | |
hyps = map_batch(best_sents, vocabulary, num_processes) | |
for i, key in enumerate(keys): | |
content = hyps[i] | |
logging.info("{} {}".format(key, content)) | |
fout.write("{} {}\n".format(key, content)) | |
if __name__ == "__main__": | |
main() | |