# Source: OSUM / wenet/bin/recognize_onnx_gpu.py
# Last commit 568e264 by tomxxie: "适配zeroGPU" (adapt for ZeroGPU); file size 13.1 kB
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Xiaoyu Chen, Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script tests the ONNX encoder and decoder exported by
export_onnx_gpu.py. The exported ONNX models only support batched offline
ASR inference.
It requires a Python-wrapped C++ CTC decoder.
Please install it by following the instructions at:
https://github.com/Slyne/ctc_decoder.git
"""
from __future__ import print_function
import argparse
import copy
import logging
import os
import sys
import torch
import yaml
from torch.utils.data import DataLoader
from wenet.dataset.dataset import Dataset
from wenet.utils.common import IGNORE_ID
from wenet.utils.config import override_config
from wenet.utils.init_tokenizer import init_tokenizer
import onnxruntime as rt
import multiprocessing
import numpy as np
try:
    from swig_decoders import (
        map_batch,
        ctc_beam_search_decoder_batch,
        TrieVector,
        PathTrie,
    )
except ImportError:
    # The CTC decoding helpers come from the C++ decoder repository below and
    # every decoding mode needs them, so stop early with a pointer to it.
    # sys.exit(str) prints the message to stderr (not stdout) and exits
    # with status 1, matching the previous exit code.
    sys.exit('Please install ctc decoders first by refering to\n' +
             'https://github.com/Slyne/ctc_decoder.git')
def get_args(argv=None):
    """Parse the command-line options for ONNX batch recognition.

    Args:
        argv: optional list of argument strings to parse. When None (the
            default) argparse reads ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description='recognize with your model')
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--test_data', required=True, help='test data file')
    parser.add_argument('--data_type',
                        default='raw',
                        choices=['raw', 'shard'],
                        help='test data type')
    parser.add_argument('--gpu',
                        type=int,
                        default=-1,
                        help='gpu id for this rank, -1 for cpu')
    parser.add_argument('--dict', required=True, help='dict file')
    parser.add_argument('--encoder_onnx',
                        required=True,
                        help='encoder onnx file')
    parser.add_argument('--decoder_onnx',
                        required=True,
                        help='decoder onnx file')
    parser.add_argument('--result_file', required=True, help='asr result file')
    parser.add_argument('--batch_size',
                        type=int,
                        default=32,
                        help='batch size for decoding')
    parser.add_argument('--mode',
                        choices=[
                            'ctc_greedy_search', 'ctc_prefix_beam_search',
                            'attention_rescoring'
                        ],
                        default='attention_rescoring',
                        help='decoding mode')
    parser.add_argument('--bpe_model',
                        default=None,
                        type=str,
                        help='bpe model for english part')
    # May be repeated; each occurrence is appended to a list.
    parser.add_argument('--override_config',
                        action='append',
                        default=[],
                        help="override yaml config")
    parser.add_argument('--fp16',
                        action='store_true',
                        help='whether the onnx models were exported in fp16, '
                        'default false')
    return parser.parse_args(argv)
def _build_test_conf(configs, batch_size):
    """Return an evaluation copy of ``configs['dataset_conf']``.

    The filter_conf length/ratio limits are opened wide, the speed_perturb,
    spec_aug, spec_sub, spec_trim, shuffle and sort flags are switched off,
    fbank dither is zeroed, and batching is made static with ``batch_size``
    items per batch. ``configs`` itself is left untouched (deep copy).
    """
    test_conf = copy.deepcopy(configs['dataset_conf'])
    filter_conf = test_conf['filter_conf']
    filter_conf['max_length'] = 102400
    filter_conf['min_length'] = 0
    filter_conf['token_max_length'] = 102400
    filter_conf['token_min_length'] = 0
    filter_conf['max_output_input_ratio'] = 102400
    filter_conf['min_output_input_ratio'] = 0
    test_conf['speed_perturb'] = False
    test_conf['spec_aug'] = False
    test_conf['spec_sub'] = False
    test_conf['spec_trim'] = False
    test_conf['shuffle'] = False
    test_conf['sort'] = False
    test_conf['fbank_conf']['dither'] = 0.0
    test_conf['batch_conf']['batch_type'] = "static"
    test_conf['batch_conf']['batch_size'] = batch_size
    return test_conf


def _load_vocabulary(dict_path):
    """Read a dictionary file holding one ``<token> <id>`` pair per line.

    Returns:
        (vocabulary, char_dict): ``vocabulary`` lists the tokens in file
        order; ``char_dict`` maps each integer id to its token.

    Raises:
        ValueError: if a line does not contain exactly two fields.
    """
    vocabulary = []
    char_dict = {}
    # Explicit UTF-8: token units may be non-ASCII and the platform default
    # encoding is not guaranteed to be UTF-8.
    with open(dict_path, 'r', encoding='utf-8') as fin:
        for line_no, line in enumerate(fin, start=1):
            arr = line.strip().split()
            if len(arr) != 2:
                raise ValueError('{}:{}: expected "<token> <id>", got {!r}'
                                 .format(dict_path, line_no, line))
            char_dict[int(arr[1])] = arr[0]
            vocabulary.append(arr[0])
    # NOTE(review): map_batch receives ``vocabulary`` as a list, which
    # presumably assumes the ids in the file are 0..N-1 in ascending order.
    return vocabulary, char_dict


def _ctc_greedy_search(beam_log_probs_idx, encoder_out_lens, vocabulary,
                       num_processes):
    """Decode by taking the top-1 token id of every frame.

    Column 0 of the top-k index tensor is used for every beam size. The
    previous code only assigned it when beam_size != 1, which raised
    UnboundLocalError for a beam of one. Frames past each utterance's
    encoder length are dropped before mapping ids to text.
    """
    log_probs_idx = beam_log_probs_idx[:, :, 0]
    batch_sents = [
        seq[0:encoder_out_lens[idx]].tolist()
        for idx, seq in enumerate(log_probs_idx)
    ]
    return map_batch(batch_sents, vocabulary, num_processes, True, 0)


def _ctc_prefix_beam_search(beam_log_probs, beam_log_probs_idx,
                            encoder_out_lens, beam_size, num_processes):
    """Run the batched C++ CTC prefix beam search.

    Returns the decoder's per-utterance candidate lists. Each candidate is
    indexed as ``[0]`` = score and ``[1]`` = token-id sequence by the callers.
    """
    batch_log_probs_seq_list = beam_log_probs.tolist()
    batch_log_probs_idx_list = beam_log_probs_idx.tolist()
    batch_len_list = encoder_out_lens.tolist()
    batch_log_probs_seq = []
    batch_log_probs_ids = []
    batch_start = []  # only effective in streaming deployment
    batch_root = TrieVector()
    # Keep Python references to every PathTrie for the duration of the
    # decoder call, in case TrieVector does not own them.
    root_dict = {}
    for i, num_sent in enumerate(batch_len_list):
        batch_log_probs_seq.append(batch_log_probs_seq_list[i][0:num_sent])
        batch_log_probs_ids.append(batch_log_probs_idx_list[i][0:num_sent])
        root_dict[i] = PathTrie()
        batch_root.append(root_dict[i])
        batch_start.append(True)
    # The trailing positional arguments (0, -2, 0.99999) are passed through
    # unchanged; see swig_decoders for their meaning (the first is presumably
    # the blank id, matching the blank id 0 given to map_batch — TODO confirm).
    return ctc_beam_search_decoder_batch(batch_log_probs_seq,
                                         batch_log_probs_ids, batch_root,
                                         batch_start, beam_size,
                                         num_processes, 0, -2, 0.99999)


def _attention_rescoring(decoder_ort_session, score_hyps, encoder_out,
                         encoder_out_lens, beam_size, batch_size, sos, eos,
                         reverse_weight, fp16):
    """Pick one CTC beam candidate per utterance with the ONNX decoder.

    Each utterance's candidate list is padded to ``beam_size`` entries with
    a dummy ``(-inf, (0,))`` hypothesis. Every candidate is then packed as
    ``[sos] + tokens + [eos]``, plus the reversed-token variant, into
    IGNORE_ID-padded int64 arrays. The session's first output is read as
    the chosen beam index for each utterance.

    Returns:
        List holding the selected token-id list for each utterance.
    """
    ctc_score, all_hyps = [], []
    max_len = 0
    for cand_hyps in score_hyps:
        if len(cand_hyps) < beam_size:
            # Build a new list instead of extending the decoder's output
            # in place.
            cand_hyps = list(cand_hyps) + (beam_size - len(cand_hyps)) * [
                (float('-inf'), (0, ))
            ]
        ctc_score.append([hyp[0] for hyp in cand_hyps])
        for hyp in cand_hyps:
            tokens = list(hyp[1])
            all_hyps.append(tokens)
            max_len = max(max_len, len(tokens))
    # The score dtype must match the decoder model's precision.
    score_dtype = np.float16 if fp16 else np.float32
    ctc_score = np.array(ctc_score, dtype=score_dtype)

    pad_shape = (batch_size, beam_size, max_len + 2)  # +2: <sos> and <eos>
    hyps_pad_sos_eos = np.full(pad_shape, IGNORE_ID, dtype=np.int64)
    r_hyps_pad_sos_eos = np.full(pad_shape, IGNORE_ID, dtype=np.int64)
    hyps_lens_sos = np.ones((batch_size, beam_size), dtype=np.int32)
    for i in range(batch_size):
        for j in range(beam_size):
            cand = all_hyps[i * beam_size + j]
            length = len(cand) + 2
            hyps_pad_sos_eos[i, j, :length] = [sos] + cand + [eos]
            r_hyps_pad_sos_eos[i, j, :length] = [sos] + cand[::-1] + [eos]
            hyps_lens_sos[i, j] = len(cand) + 1

    decoder_inputs = decoder_ort_session.get_inputs()
    decoder_ort_inputs = {
        decoder_inputs[0].name: encoder_out,
        decoder_inputs[1].name: encoder_out_lens,
        decoder_inputs[2].name: hyps_pad_sos_eos,
        decoder_inputs[3].name: hyps_lens_sos,
        decoder_inputs[-1].name: ctc_score,
    }
    # The reversed hypotheses (input 4) are only fed when reverse_weight > 0.
    if reverse_weight > 0:
        decoder_ort_inputs[decoder_inputs[4].name] = r_hyps_pad_sos_eos
    best_index = decoder_ort_session.run(None, decoder_ort_inputs)[0]
    return [
        all_hyps[i * beam_size:(i + 1) * beam_size][idx]
        for i, idx in enumerate(best_index)
    ]


def main():
    """Decode a test set with exported ONNX encoder/decoder models.

    Options are read via ``get_args()``, and an evaluation data pipeline is
    built from the YAML config. The encoder session runs on each batch, and
    the selected ``--mode`` decodes its outputs. One ``<key> <text>`` line
    per utterance is written to ``--result_file`` and also logged at INFO
    level.
    """
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
    # Pin CUDA to the requested device (the --gpu help says -1 means CPU).
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

    # NOTE(review): yaml.load with FullLoader is less strict than
    # yaml.safe_load; only point --config at trusted files.
    with open(args.config, 'r', encoding='utf-8') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    if args.override_config:
        configs = override_config(configs, args.override_config)

    reverse_weight = configs["model_conf"].get("reverse_weight", 0.0)
    special_tokens = configs.get('tokenizer_conf',
                                 {}).get('special_tokens', None)

    test_conf = _build_test_conf(configs, args.batch_size)
    tokenizer = init_tokenizer(configs)
    test_dataset = Dataset(args.data_type,
                           args.test_data,
                           tokenizer,
                           test_conf,
                           partition=False)
    # batch_size=None: batching is done by the dataset (static batch_conf),
    # so the loader must not batch again.
    test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)

    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    if use_cuda:
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
    else:
        providers = ['CPUExecutionProvider']
    encoder_ort_session = rt.InferenceSession(args.encoder_onnx,
                                              providers=providers)
    # The decoder model is only needed to rescore beam candidates.
    decoder_ort_session = None
    if args.mode == "attention_rescoring":
        decoder_ort_session = rt.InferenceSession(args.decoder_onnx,
                                                  providers=providers)

    vocabulary, char_dict = _load_vocabulary(args.dict)
    vocab_size = len(char_dict)
    # Without explicit special tokens, <sos> and <eos> share the last id.
    default_sos_eos = vocab_size - 1
    if special_tokens is None:
        sos = eos = default_sos_eos
    else:
        sos = special_tokens.get("<sos>", default_sos_eos)
        eos = special_tokens.get("<eos>", default_sos_eos)

    with torch.no_grad(), open(args.result_file, 'w',
                               encoding='utf-8') as fout:
        for batch in test_data_loader:
            keys = batch['keys']
            feats = batch['feats'].numpy()
            feats_lengths = batch['feats_lengths'].numpy()
            if args.fp16:
                # Presumably an fp16-exported encoder expects half features.
                feats = feats.astype(np.float16)
            encoder_inputs = encoder_ort_session.get_inputs()
            ort_inputs = {
                encoder_inputs[0].name: feats,
                encoder_inputs[1].name: feats_lengths,
            }
            (encoder_out, encoder_out_lens, _ctc_log_probs, beam_log_probs,
             beam_log_probs_idx) = encoder_ort_session.run(None, ort_inputs)
            # beam_log_probs appears to be (batch, frames, beam) top-k
            # scores; beam_log_probs_idx holds the matching token ids.
            beam_size = beam_log_probs.shape[-1]
            batch_size = beam_log_probs.shape[0]
            num_processes = min(multiprocessing.cpu_count(), batch_size)

            if args.mode == 'ctc_greedy_search':
                hyps = _ctc_greedy_search(beam_log_probs_idx,
                                          encoder_out_lens, vocabulary,
                                          num_processes)
            else:
                # Both remaining modes start from the CTC prefix beam search.
                score_hyps = _ctc_prefix_beam_search(beam_log_probs,
                                                     beam_log_probs_idx,
                                                     encoder_out_lens,
                                                     beam_size, num_processes)
                if args.mode == 'ctc_prefix_beam_search':
                    best_sents = [cand_hyps[0][1] for cand_hyps in score_hyps]
                    hyps = map_batch(best_sents, vocabulary, num_processes,
                                     False, 0)
                else:
                    best_sents = _attention_rescoring(
                        decoder_ort_session, score_hyps, encoder_out,
                        encoder_out_lens, beam_size, batch_size, sos, eos,
                        reverse_weight, args.fp16)
                    hyps = map_batch(best_sents, vocabulary, num_processes)

            for i, key in enumerate(keys):
                content = hyps[i]
                logging.info('%s %s', key, content)
                fout.write('{} {}\n'.format(key, content))
if __name__ == "__main__":
    # Only run recognition when executed as a script, never on import.
    main()