# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Xiaoyu Chen, Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
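
# Example invocation (a sketch for orientation; the config, data, and
# checkpoint paths below are placeholders, not files shipped with this repo):
#
#   python recognize.py --config train.yaml \
#       --test_data data/test/data.list \
#       --checkpoint final.pt \
#       --result_dir exp/decode \
#       --modes ctc_greedy_search attention_rescoring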
from __future__ import print_function
import argparse
import copy
import logging
import os
import torch
import yaml
from torch.utils.data import DataLoader
from wenet.dataset.dataset import Dataset
from wenet.utils.config import override_config
from wenet.utils.init_model import init_model
from wenet.utils.init_tokenizer import init_tokenizer
from wenet.utils.context_graph import ContextGraph
from wenet.utils.ctc_utils import get_blank_id
from wenet.utils.common import TORCH_NPU_AVAILABLE  # noqa: side-effect import, checks torch-npu availability


def get_args():
parser = argparse.ArgumentParser(description='recognize with your model')
parser.add_argument('--config', required=True, help='config file')
parser.add_argument('--test_data', required=True, help='test data file')
    parser.add_argument('--data_type',
                        default='raw',
                        # choices=['raw', 'shard'],
                        help='test data type')
parser.add_argument('--gpu',
type=int,
default=-1,
help='gpu id for this rank, -1 for cpu')
parser.add_argument('--device',
type=str,
default="cpu",
choices=["cpu", "npu", "cuda"],
help='accelerator to use')
parser.add_argument('--dtype',
type=str,
default='fp32',
choices=['fp16', 'fp32', 'bf16'],
help='model\'s dtype')
parser.add_argument('--num_workers',
default=0,
type=int,
help='num of subprocess workers for reading')
parser.add_argument('--checkpoint', required=True, help='checkpoint model')
parser.add_argument('--beam_size',
type=int,
default=10,
help='beam size for search')
parser.add_argument('--length_penalty',
type=float,
default=0.0,
help='length penalty')
parser.add_argument('--blank_penalty',
type=float,
default=0.0,
help='blank penalty')
parser.add_argument('--result_dir', required=True, help='asr result file')
    parser.add_argument('--batch_size',
                        type=int,
                        default=16,
                        help='decoding batch size')
parser.add_argument('--modes',
nargs='+',
help="""decoding mode, support the following:
attention
ctc_greedy_search
ctc_prefix_beam_search
attention_rescoring
rnnt_greedy_search
rnnt_beam_search
rnnt_beam_attn_rescoring
ctc_beam_td_attn_rescoring
hlg_onebest
hlg_rescore
paraformer_greedy_search
paraformer_beam_search""")
parser.add_argument('--search_ctc_weight',
type=float,
default=1.0,
help='ctc weight for nbest generation')
parser.add_argument('--search_transducer_weight',
type=float,
default=0.0,
help='transducer weight for nbest generation')
    parser.add_argument('--ctc_weight',
                        type=float,
                        default=0.0,
                        help='ctc weight for rescoring in the attention '
                             'rescoring and transducer attention rescore '
                             'decode modes')
parser.add_argument('--transducer_weight',
type=float,
default=0.0,
help='transducer weight for rescoring weight in '
'transducer attention rescore mode')
parser.add_argument('--attn_weight',
type=float,
default=0.0,
help='attention weight for rescoring weight in '
'transducer attention rescore mode')
parser.add_argument('--decoding_chunk_size',
type=int,
default=-1,
help='''decoding chunk size,
<0: for decoding, use full chunk.
>0: for decoding, use fixed chunk size as set.
0: used for training, it's prohibited here''')
parser.add_argument('--num_decoding_left_chunks',
type=int,
default=-1,
help='number of left chunks for decoding')
parser.add_argument('--simulate_streaming',
action='store_true',
help='simulate streaming inference')
parser.add_argument('--reverse_weight',
type=float,
default=0.0,
help='''right to left weight for attention rescoring
decode mode''')
parser.add_argument('--override_config',
action='append',
default=[],
help="override yaml config")
parser.add_argument('--word',
default='',
type=str,
help='word file, only used for hlg decode')
parser.add_argument('--hlg',
default='',
type=str,
help='hlg file, only used for hlg decode')
parser.add_argument('--lm_scale',
type=float,
default=0.0,
help='lm scale for hlg attention rescore decode')
    parser.add_argument('--decoder_scale',
                        type=float,
                        default=0.0,
                        help='decoder scale for hlg attention rescore decode')
    parser.add_argument('--r_decoder_scale',
                        type=float,
                        default=0.0,
                        help='right-to-left decoder scale for hlg attention '
                             'rescore decode')
parser.add_argument(
'--context_bias_mode',
type=str,
default='',
        help='''Context bias mode, selectable from the following
        options: decoding-graph, deep-biasing''')
parser.add_argument('--context_list_path',
type=str,
default='',
help='Context list path')
parser.add_argument('--context_graph_score',
type=float,
default=0.0,
help='''The higher the score, the greater the degree of
bias using decoding-graph for biasing''')
    parser.add_argument('--use_lora',
                        action='store_true',
                        help='whether to decode with a lora-adapted model')
parser.add_argument("--lora_ckpt_path",
default=None,
type=str,
help="lora checkpoint path.")
    parser.add_argument('--task',
                        type=str,
                        default='asr',
                        help='task tag passed to the model, e.g. asr')
    parser.add_argument('--lang',
                        type=str,
                        default='zh',
                        help='language tag passed to the model, e.g. zh')
args = parser.parse_args()
print(args)
return args


def main():
args = get_args()
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
    if args.gpu != -1:
        # keep the legacy --gpu flag working: a non-negative id implies cuda
        args.device = "cuda"
    if "cuda" in args.device:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
with open(args.config, 'r') as fin:
configs = yaml.load(fin, Loader=yaml.FullLoader)
if len(args.override_config) > 0:
configs = override_config(configs, args.override_config)
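    # Build the test-time dataset config from the training config: relax all
    # filters and disable augmentation/shuffling so every utterance is
    # decoded exactly once, in order.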
test_conf = copy.deepcopy(configs['dataset_conf'])
test_conf['filter_conf']['max_length'] = 102400
test_conf['filter_conf']['min_length'] = 0
test_conf['filter_conf']['token_max_length'] = 102400
test_conf['filter_conf']['token_min_length'] = 0
test_conf['filter_conf']['max_output_input_ratio'] = 102400
test_conf['filter_conf']['min_output_input_ratio'] = 0
test_conf['speed_perturb'] = False
test_conf['spec_aug'] = False
test_conf['spec_sub'] = False
test_conf['spec_trim'] = False
test_conf['shuffle'] = False
test_conf['sort'] = False
test_conf['cycle'] = 1
test_conf['list_shuffle'] = False
if 'fbank_conf' in test_conf:
test_conf['fbank_conf']['dither'] = 0.0
elif 'mfcc_conf' in test_conf:
test_conf['mfcc_conf']['dither'] = 0.0
test_conf['batch_conf']['batch_type'] = "static"
test_conf['batch_conf']['batch_size'] = args.batch_size
tokenizer = init_tokenizer(configs)
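    # The Dataset yields pre-batched tensors, so the DataLoader below uses
    # batch_size=None to pass each batch through unchanged.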
test_dataset = Dataset(args.data_type,
args.test_data,
tokenizer,
test_conf,
partition=False)
test_data_loader = DataLoader(test_dataset,
batch_size=None,
num_workers=args.num_workers)
# Init asr model from configs
args.jit = False
model, configs = init_model(args, configs)
device = torch.device(args.device)
model = model.to(device)
model.eval()
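    # Map the requested precision string onto a torch dtype for autocast.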
dtype = torch.float32
if args.dtype == 'fp16':
dtype = torch.float16
elif args.dtype == 'bf16':
dtype = torch.bfloat16
logging.info("compute dtype is {}".format(dtype))
context_graph = None
if 'decoding-graph' in args.context_bias_mode:
context_graph = ContextGraph(args.context_list_path,
tokenizer.symbol_table,
configs['tokenizer_conf']['bpe_path'],
args.context_graph_score)
_, blank_id = get_blank_id(configs, tokenizer.symbol_table)
logging.info("blank_id is {}".format(blank_id))
# TODO(Dinghao Zhou): Support RNN-T related decoding
# TODO(Lv Xiang): Support k2 related decoding
# TODO(Kaixun Huang): Support context graph
files = {}
for mode in args.modes:
dir_name = os.path.join(args.result_dir, mode)
os.makedirs(dir_name, exist_ok=True)
file_name = os.path.join(dir_name, 'text')
files[mode] = open(file_name, 'w', encoding='utf-8')
max_format_len = max([len(mode) for mode in args.modes])
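    # Decode under autocast so fp16/bf16 take effect, and under no_grad
    # since no gradients are needed at inference time.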
with torch.cuda.amp.autocast(enabled=True,
dtype=dtype,
cache_enabled=False):
with torch.no_grad():
            utt_num = 0
for batch_idx, batch in enumerate(test_data_loader):
keys = batch["keys"]
feats = batch["feats"].to(device)
target = batch["target"].to(device)
feats_lengths = batch["feats_lengths"].to(device)
target_lengths = batch["target_lengths"].to(device)
batch_size = feats.size(0)
                task_list = [args.task for _ in range(batch_size)]
                lang_list = [args.lang for _ in range(batch_size)]
                infos = {"tasks": task_list, "langs": lang_list}
results = model.decode(
args.modes,
feats,
feats_lengths,
args.beam_size,
decoding_chunk_size=args.decoding_chunk_size,
num_decoding_left_chunks=args.num_decoding_left_chunks,
ctc_weight=args.ctc_weight,
simulate_streaming=args.simulate_streaming,
reverse_weight=args.reverse_weight,
context_graph=context_graph,
blank_id=blank_id,
blank_penalty=args.blank_penalty,
length_penalty=args.length_penalty,
infos=infos)
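                # Write one '<key> <hypothesis>' line per utterance to each
                # mode's result file.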
for i, key in enumerate(keys):
utt_num += 1
for mode, hyps in results.items():
tokens = hyps[i].tokens
line = '{} {}'.format(key,
tokenizer.detokenize(tokens)[0])
logging.info('{} {}'.format(mode.ljust(max_format_len),
line))
files[mode].write(line + '\n')
    for mode, f in files.items():
        f.flush()  # force buffered results to disk before closing
        f.close()


if __name__ == '__main__':
    main()