# OSUM / wenet/bin/recognize4llmasr.py
# tomxxie
# 适配zeroGPU (adapt for ZeroGPU)
# 568e264
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Xiaoyu Chen, Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

import argparse
import contextlib
import copy
import logging
import os
import random

import torch
import yaml
from gxl_ai_utils.utils.utils_model import set_random_seed
from torch.utils.data import DataLoader

from wenet.dataset.dataset import Dataset
from wenet.llm_asr.llmasr_model import LLMASR_Model
from wenet.utils.config import override_config
from wenet.utils.init_model import init_model
from wenet.utils.init_tokenizer import init_tokenizer
from wenet.utils.context_graph import ContextGraph
from wenet.utils.ctc_utils import get_blank_id
from wenet.utils.common import TORCH_NPU_AVAILABLE # noqa just ensure to check torch-npu
def _str2bool(value):
    """Convert a command-line string to a bool for ``argparse``.

    ``type=bool`` is an argparse pitfall: ``bool("False")`` is True, so a
    flag declared that way can never be switched off from the command line.
    Accepts the usual spellings case-insensitively; the empty string maps to
    False, matching ``bool("")``.

    Raises:
        argparse.ArgumentTypeError: for an unrecognised value, which argparse
            turns into a normal usage error.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('', '0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


def get_args():
    """Parse the command-line options for LLM-ASR recognition.

    Many of the search / rescoring options below (CTC, attention, RNN-T,
    HLG weights) are not read directly by ``main`` in this file; some may be
    consumed by ``init_model`` (not visible here).

    Returns:
        argparse.Namespace with the parsed options. The namespace is also
        printed to stdout.
    """
    parser = argparse.ArgumentParser(description='recognize with your model')
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--test_data', required=True, help='test data file')
    parser.add_argument('--data_type',
                        default='raw',
                        # choices=['raw', 'shard'],
                        help='test data type, e.g. raw or shard')
    parser.add_argument('--gpu',
                        type=int,
                        default=-1,
                        help='gpu id for this rank, -1 for cpu')
    parser.add_argument('--device',
                        type=str,
                        default="cpu",
                        choices=["cpu", "npu", "cuda"],
                        help='accelerator to use')
    parser.add_argument('--dtype',
                        type=str,
                        default='fp32',
                        choices=['fp16', 'fp32', 'bf16'],
                        help='model\'s dtype')
    parser.add_argument('--num_workers',
                        default=0,
                        type=int,
                        help='num of subprocess workers for reading')
    parser.add_argument('--checkpoint', required=True, help='checkpoint model')
    parser.add_argument('--beam_size',
                        type=int,
                        default=10,
                        help='beam size for search')
    parser.add_argument('--length_penalty',
                        type=float,
                        default=0.0,
                        help='length penalty')
    parser.add_argument('--blank_penalty',
                        type=float,
                        default=0.0,
                        help='blank penalty')
    parser.add_argument('--result_dir',
                        required=True,
                        help='directory to write asr results to')
    parser.add_argument('--batch_size',
                        type=int,
                        default=16,
                        help='batch size for decoding')
    parser.add_argument('--modes',
                        nargs='+',
                        help="""decoding mode, support the following:
                             attention
                             ctc_greedy_search
                             ctc_prefix_beam_search
                             attention_rescoring
                             rnnt_greedy_search
                             rnnt_beam_search
                             rnnt_beam_attn_rescoring
                             ctc_beam_td_attn_rescoring
                             hlg_onebest
                             hlg_rescore
                             paraformer_greedy_search
                             paraformer_beam_search""")
    parser.add_argument('--search_ctc_weight',
                        type=float,
                        default=1.0,
                        help='ctc weight for nbest generation')
    parser.add_argument('--search_transducer_weight',
                        type=float,
                        default=0.0,
                        help='transducer weight for nbest generation')
    parser.add_argument('--ctc_weight',
                        type=float,
                        default=0.0,
                        help='ctc weight for rescoring weight in attention '
                        'rescoring decode mode; ctc weight for rescoring '
                        'weight in transducer attention rescore decode mode')
    parser.add_argument('--transducer_weight',
                        type=float,
                        default=0.0,
                        help='transducer weight for rescoring weight in '
                        'transducer attention rescore mode')
    parser.add_argument('--attn_weight',
                        type=float,
                        default=0.0,
                        help='attention weight for rescoring weight in '
                        'transducer attention rescore mode')
    parser.add_argument('--decoding_chunk_size',
                        type=int,
                        default=-1,
                        help='''decoding chunk size,
                                <0: for decoding, use full chunk.
                                >0: for decoding, use fixed chunk size as set.
                                0: used for training, it's prohibited here''')
    parser.add_argument('--num_decoding_left_chunks',
                        type=int,
                        default=-1,
                        help='number of left chunks for decoding')
    parser.add_argument('--simulate_streaming',
                        action='store_true',
                        help='simulate streaming inference')
    parser.add_argument('--reverse_weight',
                        type=float,
                        default=0.0,
                        help='''right to left weight for attention rescoring
                                decode mode''')
    parser.add_argument('--override_config',
                        action='append',
                        default=[],
                        help="override yaml config")
    parser.add_argument('--word',
                        default='',
                        type=str,
                        help='word file, only used for hlg decode')
    parser.add_argument('--hlg',
                        default='',
                        type=str,
                        help='hlg file, only used for hlg decode')
    parser.add_argument('--lm_scale',
                        type=float,
                        default=0.0,
                        help='lm scale for hlg attention rescore decode')
    parser.add_argument('--decoder_scale',
                        type=float,
                        default=0.0,
                        help='decoder scale for hlg attention rescore decode')
    parser.add_argument('--r_decoder_scale',
                        type=float,
                        default=0.0,
                        help='right-to-left decoder scale for hlg attention '
                        'rescore decode')
    parser.add_argument(
        '--context_bias_mode',
        type=str,
        default='',
        help='''Context bias mode, selectable from the following
                option: decoding-graph, deep-biasing''')
    parser.add_argument('--context_list_path',
                        type=str,
                        default='',
                        help='Context list path')
    parser.add_argument('--context_graph_score',
                        type=float,
                        default=0.0,
                        help='''The higher the score, the greater the degree of
                                bias using decoding-graph for biasing''')
    # _str2bool instead of type=bool: bool("False") would be True.
    parser.add_argument('--use_lora',
                        type=_str2bool,
                        default=False,
                        help='''Whether to use lora (true/false)''')
    parser.add_argument("--lora_ckpt_path",
                        default=None,
                        type=str,
                        help="lora checkpoint path.")
    parser.add_argument('--task',
                        type=str,
                        default='asr',
                        help='task tag used to pick a prompt from the prompt '
                        'config, e.g. <TRANSCRIBE>')
    parser.add_argument('--lang',
                        type=str,
                        default='zh',
                        help='language tag')
    args = parser.parse_args()
    print(args)
    return args
def _build_test_conf(configs):
    """Derive an inference-time dataset config from ``configs``.

    On a deep copy of ``configs['dataset_conf']`` this relaxes the length and
    ratio filters, turns off augmentation, turns off feature dithering and
    forces static batching with batch size 1.

    Side effect: ``configs['dataset_conf']['filter_conf']
    ['filter_no_extra_info']`` is set to False on the caller's dict before
    the copy is taken.
    """
    configs['dataset_conf']['filter_conf']['filter_no_extra_info'] = False
    test_conf = copy.deepcopy(configs['dataset_conf'])
    test_conf['filter_conf'].update(
        # Original note: Whisper handles at most 30 s of audio
        # (102400 was also mentioned as an alternative limit).
        max_length=3000,
        min_length=0,
        token_max_length=102400,
        token_min_length=0,
        max_output_input_ratio=102400,
        min_output_input_ratio=0,
    )
    test_conf.update(
        speed_perturb=False,
        spec_aug=False,
        spec_sub=False,
        spec_trim=False,
        # Shuffling only changes output order; each line carries its key.
        shuffle=True,
        sort=False,
        cycle=1,
        list_shuffle=True,
    )
    # No dithering at test time, so features are deterministic.
    if 'fbank_conf' in test_conf:
        test_conf['fbank_conf']['dither'] = 0.0
    elif 'mfcc_conf' in test_conf:
        test_conf['mfcc_conf']['dither'] = 0.0
    test_conf['batch_conf']['batch_type'] = "static"
    test_conf['batch_conf']['batch_size'] = 1
    test_conf['split_num'] = 1
    return test_conf


def _decode_to_files(model, data_loader, device, dtype, prompts, files):
    """Run ``model.generate`` on every batch and write ``key<TAB>text`` lines.

    Args:
        model: the LLM-ASR model, already moved to ``device`` and in eval mode.
        data_loader: yields dicts with at least ``keys``, ``feats`` and
            ``feats_lengths``. Only index 0 of each batch is written, which
            relies on the batch size of 1 set by ``_build_test_conf``.
        device: torch device the features are moved to.
        dtype: compute dtype passed to CUDA autocast.
        prompts: candidate prompt strings for the task. One is picked at
            random per batch, so the draw depends on the seeded global RNG.
        files: mapping from decode mode to an open, writable text file.
    """
    from gxl_ai_utils.utils import utils_file
    # Same arguments as the former torch.cuda.amp.autocast(...), which is
    # deprecated in recent PyTorch.
    with torch.autocast(device_type='cuda',
                        enabled=True,
                        dtype=dtype,
                        cache_enabled=False):
        with torch.no_grad():
            for batch_idx, batch in enumerate(data_loader):
                keys = batch["keys"]
                feats = batch["feats"].to(device)
                feats_lengths = batch["feats_lengths"].to(device)
                prompt = prompts[random.randint(0, len(prompts) - 1)]
                # presumably returns one decoded string per utterance
                res_text = model.generate(wavs=feats,
                                          wavs_len=feats_lengths,
                                          prompt=prompt)
                line = "{}\t{}".format(keys[0], res_text[0])
                for f in files.values():
                    f.write(line + '\n')
                utils_file.logging_print('{} {} {}'.format(
                    batch_idx, keys[0], res_text[0]))
                if batch_idx % 100 == 0:
                    # Periodically force buffered results to disk.
                    for f in files.values():
                        f.flush()


def main():
    """Decode ``--test_data`` with an LLM-ASR model and write transcripts.

    Results are written to ``<result_dir>/llmasr_decode/text``, one
    ``<utt_key>\\t<hypothesis>`` line per utterance. The prompt for each
    utterance is drawn from ``conf/prompt_stage4.yaml``, a path relative to
    the current working directory, under the key ``--task``.

    Raises:
        KeyError: if ``--task`` has no entry in the prompt config.
    """
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
    set_random_seed(777)

    if args.gpu != -1:
        # Legacy ``--gpu``: implies CUDA and pins the visible device. Without
        # an explicit id the environment is left alone; exporting "-1" here
        # would hide every GPU and break ``--device cuda``.
        args.device = "cuda"
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

    with open(args.config, 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    if len(args.override_config) > 0:
        configs = override_config(configs, args.override_config)

    test_conf = _build_test_conf(configs)
    tokenizer = init_tokenizer(configs)
    test_dataset = Dataset(args.data_type,
                           args.test_data,
                           tokenizer,
                           test_conf,
                           partition=False)
    test_data_loader = DataLoader(test_dataset,
                                  batch_size=None,
                                  num_workers=args.num_workers)

    # Init asr model from configs
    args.jit = False
    model, configs = init_model(args, configs)
    device = torch.device(args.device)
    model: LLMASR_Model = model.to(device)
    model.eval()

    dtype = {
        'fp16': torch.float16,
        'bf16': torch.bfloat16,
    }.get(args.dtype, torch.float32)
    logging.info("compute dtype is {}".format(dtype))

    # NOTE(review): context_graph and blank_id are computed here but are not
    # consumed by the decode loop below.
    context_graph = None
    if 'decoding-graph' in args.context_bias_mode:
        context_graph = ContextGraph(args.context_list_path,
                                     tokenizer.symbol_table,
                                     configs['tokenizer_conf']['bpe_path'],
                                     args.context_graph_score)
    _, blank_id = get_blank_id(configs, tokenizer.symbol_table)
    logging.info("blank_id is {}".format(blank_id))
    # TODO(Dinghao Zhou): Support RNN-T related decoding
    # TODO(Lv Xiang): Support k2 related decoding
    # TODO(Kaixun Huang): Support context graph

    # Resolve the prompt list once instead of once per batch. Task tags such
    # as "<A><B>" are normalised to "<A> <B>" to match the prompt config keys.
    from gxl_ai_utils.utils import utils_file
    prompt_conf_path = 'conf/prompt_stage4.yaml'
    global_prompt_dict = utils_file.load_dict_from_yaml(prompt_conf_path)
    args.task = args.task.replace('><', '> <')
    if args.task not in global_prompt_dict:
        raise KeyError('task {!r} not found in {}; available tasks: {}'.format(
            args.task, prompt_conf_path, sorted(global_prompt_dict)))
    prompts = global_prompt_dict[args.task]

    modes = ['llmasr_decode']
    # ExitStack closes (and thereby flushes) every output file even if
    # decoding raises part-way through.
    with contextlib.ExitStack() as stack:
        files = {}
        for mode in modes:
            dir_name = os.path.join(args.result_dir, mode)
            os.makedirs(dir_name, exist_ok=True)
            file_name = os.path.join(dir_name, 'text')
            files[mode] = stack.enter_context(
                open(file_name, 'w', encoding='utf-8'))
        _decode_to_files(model, test_data_loader, device, dtype, prompts,
                         files)
if __name__ == "__main__":
    # Script entry point: parse the CLI options and run decoding.
    main()