# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Xiaoyu Chen, Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script tests the ONNX encoder and decoder exported by
export_onnx_gpu.py. The exported ONNX models only support batch offline
ASR inference. The script requires a Python-wrapped C++ CTC decoder;
please install it by following:
https://github.com/Slyne/ctc_decoder.git
"""
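
# Example invocation (the script name and every file path below are
# illustrative placeholders, not artifacts shipped with this script;
# substitute your own config, data, dict and exported models):
#
#   python recognize_onnx_gpu.py \
#       --config train.yaml \
#       --test_data data.list \
#       --dict units.txt \
#       --encoder_onnx encoder.onnx \
#       --decoder_onnx decoder.onnx \
#       --result_file results.txt \
#       --mode attention_rescoring \
#       --gpu 0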

from __future__ import print_function

import argparse
import copy
import logging
import multiprocessing
import os
import sys

import numpy as np
import onnxruntime as rt
import torch
import yaml
from torch.utils.data import DataLoader

from wenet.dataset.dataset import Dataset
from wenet.utils.common import IGNORE_ID
from wenet.utils.config import override_config
from wenet.utils.init_tokenizer import init_tokenizer

try:
    from swig_decoders import map_batch, \
        ctc_beam_search_decoder_batch, \
        TrieVector, PathTrie
except ImportError:
    print('Please install the CTC decoders first by referring to\n'
          'https://github.com/Slyne/ctc_decoder.git')
    sys.exit(1)


def get_args():
    parser = argparse.ArgumentParser(description='recognize with your model')
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--test_data', required=True, help='test data file')
    parser.add_argument('--data_type',
                        default='raw',
                        choices=['raw', 'shard'],
                        help='train and cv data type')
    parser.add_argument('--gpu',
                        type=int,
                        default=-1,
                        help='gpu id for this rank, -1 for cpu')
    parser.add_argument('--dict', required=True, help='dict file')
    parser.add_argument('--encoder_onnx',
                        required=True,
                        help='encoder onnx file')
    parser.add_argument('--decoder_onnx',
                        required=True,
                        help='decoder onnx file')
    parser.add_argument('--result_file',
                        required=True,
                        help='asr result file')
    parser.add_argument('--batch_size',
                        type=int,
                        default=32,
                        help='inference batch size')
    parser.add_argument('--mode',
                        choices=[
                            'ctc_greedy_search', 'ctc_prefix_beam_search',
                            'attention_rescoring'
                        ],
                        default='attention_rescoring',
                        help='decoding mode')
    parser.add_argument('--bpe_model',
                        default=None,
                        type=str,
                        help='bpe model for english part')
    parser.add_argument('--override_config',
                        action='append',
                        default=[],
                        help="override yaml config")
    parser.add_argument('--fp16',
                        action='store_true',
                        help='whether the exported onnx model is fp16, '
                        'default false')
    args = parser.parse_args()
    return args


def main():
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    with open(args.config, 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    if len(args.override_config) > 0:
        configs = override_config(configs, args.override_config)

    reverse_weight = configs["model_conf"].get("reverse_weight", 0.0)
    special_tokens = configs.get('tokenizer_conf',
                                 {}).get('special_tokens', None)
    # Disable augmentation and filtering so that evaluation sees every
    # utterance exactly once, in order.
    test_conf = copy.deepcopy(configs['dataset_conf'])
    test_conf['filter_conf']['max_length'] = 102400
    test_conf['filter_conf']['min_length'] = 0
    test_conf['filter_conf']['token_max_length'] = 102400
    test_conf['filter_conf']['token_min_length'] = 0
    test_conf['filter_conf']['max_output_input_ratio'] = 102400
    test_conf['filter_conf']['min_output_input_ratio'] = 0
    test_conf['speed_perturb'] = False
    test_conf['spec_aug'] = False
    test_conf['spec_sub'] = False
    test_conf['spec_trim'] = False
    test_conf['shuffle'] = False
    test_conf['sort'] = False
    test_conf['fbank_conf']['dither'] = 0.0
    test_conf['batch_conf']['batch_type'] = "static"
    test_conf['batch_conf']['batch_size'] = args.batch_size

    tokenizer = init_tokenizer(configs)
    test_dataset = Dataset(args.data_type,
                           args.test_data,
                           tokenizer,
                           test_conf,
                           partition=False)
    test_data_loader = DataLoader(test_dataset, batch_size=None,
                                  num_workers=0)

    # Init ONNX runtime sessions from the exported models
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    if use_cuda:
        EP_list = ['CUDAExecutionProvider', 'CPUExecutionProvider']
    else:
        EP_list = ['CPUExecutionProvider']
    encoder_ort_session = rt.InferenceSession(args.encoder_onnx,
                                              providers=EP_list)
    decoder_ort_session = None
    if args.mode == "attention_rescoring":
        decoder_ort_session = rt.InferenceSession(args.decoder_onnx,
                                                  providers=EP_list)

    # Load dict: each line is "<token> <id>"
    vocabulary = []
    char_dict = {}
    with open(args.dict, 'r') as fin:
        for line in fin:
            arr = line.strip().split()
            assert len(arr) == 2
            char_dict[int(arr[1])] = arr[0]
            vocabulary.append(arr[0])
    vocab_size = len(char_dict)
    sos = (vocab_size - 1 if special_tokens is None else
           special_tokens.get("<sos>", vocab_size - 1))
    eos = (vocab_size - 1 if special_tokens is None else
           special_tokens.get("<eos>", vocab_size - 1))

    with torch.no_grad(), open(args.result_file, 'w') as fout:
        for _, batch in enumerate(test_data_loader):
            keys = batch['keys']
            feats = batch['feats']
            feats_lengths = batch['feats_lengths']
            feats, feats_lengths = feats.numpy(), feats_lengths.numpy()
            if args.fp16:
                feats = feats.astype(np.float16)
            ort_inputs = {
                encoder_ort_session.get_inputs()[0].name: feats,
                encoder_ort_session.get_inputs()[1].name: feats_lengths
            }
            # The exported encoder emits the encoder states, their lengths,
            # the full CTC log probabilities, and the per-frame top-beam_size
            # log probabilities together with their token indices.
            ort_outs = encoder_ort_session.run(None, ort_inputs)
            encoder_out, encoder_out_lens, ctc_log_probs, \
                beam_log_probs, beam_log_probs_idx = ort_outs
            beam_size = beam_log_probs.shape[-1]
            batch_size = beam_log_probs.shape[0]
            num_processes = min(multiprocessing.cpu_count(), batch_size)
            if args.mode == 'ctc_greedy_search':
                # beam_log_probs_idx has shape (batch, time, beam_size), so
                # the top-1 index per frame is always the first entry on the
                # last axis, including when beam_size == 1.
                log_probs_idx = beam_log_probs_idx[:, :, 0]
                batch_sents = []
                for idx, seq in enumerate(log_probs_idx):
                    batch_sents.append(seq[0:encoder_out_lens[idx]].tolist())
                hyps = map_batch(batch_sents, vocabulary, num_processes, True,
                                 0)
            elif args.mode in ('ctc_prefix_beam_search',
                               'attention_rescoring'):
                batch_log_probs_seq_list = beam_log_probs.tolist()
                batch_log_probs_idx_list = beam_log_probs_idx.tolist()
                batch_len_list = encoder_out_lens.tolist()
                batch_log_probs_seq = []
                batch_log_probs_ids = []
                batch_start = []  # only effective in streaming deployment
                batch_root = TrieVector()
                root_dict = {}
                for i in range(len(batch_len_list)):
                    num_sent = batch_len_list[i]
                    batch_log_probs_seq.append(
                        batch_log_probs_seq_list[i][0:num_sent])
                    batch_log_probs_ids.append(
                        batch_log_probs_idx_list[i][0:num_sent])
                    root_dict[i] = PathTrie()
                    batch_root.append(root_dict[i])
                    batch_start.append(True)
                score_hyps = ctc_beam_search_decoder_batch(
                    batch_log_probs_seq, batch_log_probs_ids, batch_root,
                    batch_start, beam_size, num_processes, 0, -2, 0.99999)
                if args.mode == 'ctc_prefix_beam_search':
                    hyps = []
                    for cand_hyps in score_hyps:
                        hyps.append(cand_hyps[0][1])
                    hyps = map_batch(hyps, vocabulary, num_processes, False,
                                     0)
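                # For attention rescoring, pad every beam candidate to
                # (batch_size, beam_size, max_len + 2) with <sos>/<eos>
                # attached and IGNORE_ID as filler, build a time-reversed
                # copy for the right-to-left decoder (only consumed when
                # reverse_weight > 0), then let the decoder ONNX model pick
                # the best candidate per utterance.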
                if args.mode == 'attention_rescoring':
                    ctc_score, all_hyps = [], []
                    max_len = 0
                    for hyps in score_hyps:
                        cur_len = len(hyps)
                        if len(hyps) < beam_size:
                            # Pad missing candidates with an empty hypothesis
                            # carrying a -inf CTC score.
                            hyps += (beam_size - cur_len) * [(-float('inf'),
                                                              (0, ))]
                        cur_ctc_score = []
                        for hyp in hyps:
                            cur_ctc_score.append(hyp[0])
                            all_hyps.append(list(hyp[1]))
                            if len(hyp[1]) > max_len:
                                max_len = len(hyp[1])
                        ctc_score.append(cur_ctc_score)
                    if args.fp16:
                        ctc_score = np.array(ctc_score, dtype=np.float16)
                    else:
                        ctc_score = np.array(ctc_score, dtype=np.float32)
                    hyps_pad_sos_eos = np.ones(
                        (batch_size, beam_size, max_len + 2),
                        dtype=np.int64) * IGNORE_ID
                    r_hyps_pad_sos_eos = np.ones(
                        (batch_size, beam_size, max_len + 2),
                        dtype=np.int64) * IGNORE_ID
                    hyps_lens_sos = np.ones((batch_size, beam_size),
                                            dtype=np.int32)
                    k = 0
                    for i in range(batch_size):
                        for j in range(beam_size):
                            cand = all_hyps[k]
                            l = len(cand) + 2
                            hyps_pad_sos_eos[i][j][0:l] = [sos] + cand + [eos]
                            r_hyps_pad_sos_eos[i][j][0:l] = \
                                [sos] + cand[::-1] + [eos]
                            # Length including <sos> but excluding <eos>
                            hyps_lens_sos[i][j] = len(cand) + 1
                            k += 1
                    decoder_ort_inputs = {
                        decoder_ort_session.get_inputs()[0].name:
                        encoder_out,
                        decoder_ort_session.get_inputs()[1].name:
                        encoder_out_lens,
                        decoder_ort_session.get_inputs()[2].name:
                        hyps_pad_sos_eos,
                        decoder_ort_session.get_inputs()[3].name:
                        hyps_lens_sos,
                        decoder_ort_session.get_inputs()[-1].name: ctc_score
                    }
                    if reverse_weight > 0:
                        r_hyps_pad_sos_eos_name = \
                            decoder_ort_session.get_inputs()[4].name
                        decoder_ort_inputs[
                            r_hyps_pad_sos_eos_name] = r_hyps_pad_sos_eos
                    best_index = decoder_ort_session.run(
                        None, decoder_ort_inputs)[0]
                    best_sents = []
                    k = 0
                    for idx in best_index:
                        cur_best_sent = all_hyps[k:k + beam_size][idx]
                        best_sents.append(cur_best_sent)
                        k += beam_size
                    hyps = map_batch(best_sents, vocabulary, num_processes)

            for i, key in enumerate(keys):
                content = hyps[i]
                logging.info('{} {}'.format(key, content))
                fout.write('{} {}\n'.format(key, content))


if __name__ == '__main__':
    main()