# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Xiaoyu Chen, Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import argparse
import copy
import logging
import os
import random

import torch
import yaml
from torch.utils.data import DataLoader

from gxl_ai_utils.utils import utils_file
from gxl_ai_utils.utils.utils_model import set_random_seed
from wenet.dataset.dataset import Dataset
from wenet.llm_asr.llmasr_model import LLMASR_Model
from wenet.utils.config import override_config
from wenet.utils.init_model import init_model
from wenet.utils.init_tokenizer import init_tokenizer
from wenet.utils.context_graph import ContextGraph
from wenet.utils.ctc_utils import get_blank_id
from wenet.utils.common import TORCH_NPU_AVAILABLE  # noqa: imported to trigger the torch-npu check


def get_args():
    parser = argparse.ArgumentParser(description='recognize with your model')
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--test_data', required=True, help='test data file')
    parser.add_argument('--data_type',
                        default='raw',
                        # choices=['raw', 'shard'],
                        help='data type of the test set')
    parser.add_argument('--gpu',
                        type=int,
                        default=-1,
                        help='gpu id for this rank, -1 for cpu')
    parser.add_argument('--device',
                        type=str,
                        default="cpu",
                        choices=["cpu", "npu", "cuda"],
                        help='accelerator to use')
    parser.add_argument('--dtype',
                        type=str,
                        default='fp32',
                        choices=['fp16', 'fp32', 'bf16'],
                        help="model's dtype")
    parser.add_argument('--num_workers',
                        default=0,
                        type=int,
                        help='num of subprocess workers for reading')
    parser.add_argument('--checkpoint', required=True, help='checkpoint model')
    parser.add_argument('--beam_size',
                        type=int,
                        default=10,
                        help='beam size for search')
    parser.add_argument('--length_penalty',
                        type=float,
                        default=0.0,
                        help='length penalty')
    parser.add_argument('--blank_penalty',
                        type=float,
                        default=0.0,
                        help='blank penalty')
    parser.add_argument('--result_dir',
                        required=True,
                        help='directory to save asr results')
    parser.add_argument('--batch_size',
                        type=int,
                        default=16,
                        help='batch size for decoding')
    parser.add_argument('--modes',
                        nargs='+',
                        help="""decoding mode, support the following:
                             attention
                             ctc_greedy_search
                             ctc_prefix_beam_search
                             attention_rescoring
                             rnnt_greedy_search
                             rnnt_beam_search
                             rnnt_beam_attn_rescoring
                             ctc_beam_td_attn_rescoring
                             hlg_onebest
                             hlg_rescore
                             paraformer_greedy_search
                             paraformer_beam_search""")
    parser.add_argument('--search_ctc_weight',
                        type=float,
                        default=1.0,
                        help='ctc weight for nbest generation')
    parser.add_argument('--search_transducer_weight',
                        type=float,
                        default=0.0,
                        help='transducer weight for nbest generation')
    parser.add_argument('--ctc_weight',
                        type=float,
                        default=0.0,
                        help='ctc weight for rescoring in the attention '
                             'rescoring and transducer attention rescoring '
                             'decode modes')
    parser.add_argument('--transducer_weight',
                        type=float,
                        default=0.0,
                        help='transducer weight for rescoring in '
                             'transducer attention rescore mode')
    parser.add_argument('--attn_weight',
                        type=float,
                        default=0.0,
                        help='attention weight for rescoring in '
                             'transducer attention rescore mode')
    parser.add_argument('--decoding_chunk_size',
                        type=int,
                        default=-1,
                        help='''decoding chunk size,
                                <0: for decoding, use full chunk.
                                >0: for decoding, use fixed chunk size as set.
                                0: used for training, it's prohibited here''')
    parser.add_argument('--num_decoding_left_chunks',
                        type=int,
                        default=-1,
                        help='number of left chunks for decoding')
    parser.add_argument('--simulate_streaming',
                        action='store_true',
                        help='simulate streaming inference')
    parser.add_argument('--reverse_weight',
                        type=float,
                        default=0.0,
                        help='''right-to-left weight for attention rescoring
                                decode mode''')
    parser.add_argument('--override_config',
                        action='append',
                        default=[],
                        help='override yaml config')
    parser.add_argument('--word',
                        default='',
                        type=str,
                        help='word file, only used for hlg decode')
    parser.add_argument('--hlg',
                        default='',
                        type=str,
                        help='hlg file, only used for hlg decode')
    parser.add_argument('--lm_scale',
                        type=float,
                        default=0.0,
                        help='lm scale for hlg attention rescore decode')
    parser.add_argument('--decoder_scale',
                        type=float,
                        default=0.0,
                        help='decoder scale for hlg attention rescore decode')
    parser.add_argument('--r_decoder_scale',
                        type=float,
                        default=0.0,
                        help='reverse decoder scale for hlg attention '
                             'rescore decode')
    parser.add_argument('--context_bias_mode',
                        type=str,
                        default='',
                        help='''Context bias mode, selectable from the
                                following options: decoding-graph,
                                deep-biasing''')
    parser.add_argument('--context_list_path',
                        type=str,
                        default='',
                        help='Context list path')
    parser.add_argument('--context_graph_score',
                        type=float,
                        default=0.0,
                        help='''The higher the score, the greater the degree
                                of bias when using decoding-graph for
                                biasing''')
    parser.add_argument('--use_lora',
                        action='store_true',
                        help='Whether to use LoRA')
    parser.add_argument('--lora_ckpt_path',
                        default=None,
                        type=str,
                        help='lora checkpoint path')
    parser.add_argument('--task',
                        type=str,
                        default='asr',
                        help='task tag used to select a prompt')
    parser.add_argument('--lang',
                        type=str,
                        default='zh',
                        help='language of the test set')
    args = parser.parse_args()
    print(args)
    return args


def main():
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
    set_random_seed(777)
    if args.gpu != -1:
        # keep the legacy behaviour of the --gpu flag
        args.device = "cuda"
    if "cuda" in args.device:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

    with open(args.config, 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    if len(args.override_config) > 0:
        configs = override_config(configs, args.override_config)

    # Disable augmentation/filtering that only makes sense for training
    configs['dataset_conf']['filter_conf']['filter_no_extra_info'] = False
    test_conf = copy.deepcopy(configs['dataset_conf'])
    test_conf['filter_conf']['max_length'] = 3000  # Whisper handles at most 30 s of audio
    test_conf['filter_conf']['min_length'] = 0
    test_conf['filter_conf']['token_max_length'] = 102400
    test_conf['filter_conf']['token_min_length'] = 0
    test_conf['filter_conf']['max_output_input_ratio'] = 102400
    test_conf['filter_conf']['min_output_input_ratio'] = 0
    test_conf['speed_perturb'] = False
    test_conf['spec_aug'] = False
    test_conf['spec_sub'] = False
    test_conf['spec_trim'] = False
    test_conf['shuffle'] = True
    test_conf['sort'] = False
    test_conf['cycle'] = 1
    test_conf['list_shuffle'] = True
    if 'fbank_conf' in test_conf:
        test_conf['fbank_conf']['dither'] = 0.0
    elif 'mfcc_conf' in test_conf:
        test_conf['mfcc_conf']['dither'] = 0.0
    test_conf['batch_conf']['batch_type'] = "static"
    test_conf['batch_conf']['batch_size'] = 1
    test_conf['split_num'] = 1

    tokenizer = init_tokenizer(configs)
    test_dataset = Dataset(args.data_type,
                           args.test_data,
                           tokenizer,
                           test_conf,
                           partition=False)
    test_data_loader = DataLoader(test_dataset,
                                  batch_size=None,
                                  num_workers=args.num_workers)

    # Init asr model from configs
    args.jit = False
    model, configs = init_model(args, configs)
    device = torch.device(args.device)
    model: LLMASR_Model = model.to(device)
    model.eval()

    dtype = torch.float32
    if args.dtype == 'fp16':
        dtype = torch.float16
    elif args.dtype == 'bf16':
        dtype = torch.bfloat16
    logging.info("compute dtype is {}".format(dtype))

    context_graph = None
    if 'decoding-graph' in args.context_bias_mode:
        context_graph = ContextGraph(args.context_list_path,
                                     tokenizer.symbol_table,
                                     configs['tokenizer_conf']['bpe_path'],
                                     args.context_graph_score)

    _, blank_id = get_blank_id(configs, tokenizer.symbol_table)
    logging.info("blank_id is {}".format(blank_id))

    # TODO(Dinghao Zhou): Support RNN-T related decoding
    # TODO(Lv Xiang): Support k2 related decoding
    # TODO(Kaixun Huang): Support context graph
    # Only LLM-ASR decoding is implemented here, so --modes is ignored.
    files = {}
    modes = ['llmasr_decode']
    for mode in modes:
        dir_name = os.path.join(args.result_dir, mode)
        os.makedirs(dir_name, exist_ok=True)
        file_name = os.path.join(dir_name, 'text')
        files[mode] = open(file_name, 'w', encoding='utf-8')

    # Get prompt config; note the path is hardcoded relative to the
    # working directory.
    global_prompt_dict = utils_file.load_dict_from_yaml(
        'conf/prompt_stage4.yaml')

    with torch.cuda.amp.autocast(enabled=True,
                                 dtype=dtype,
                                 cache_enabled=False):
        with torch.no_grad():
            for batch_idx, batch in enumerate(test_data_loader):
                keys = batch["keys"]
                feats = batch["feats"].to(device)
                target = batch["target"].to(device)
                feats_lengths = batch["feats_lengths"].to(device)
                target_lengths = batch["target_lengths"].to(device)
                batch_size = feats.size(0)
                if '><' in args.task:
                    # separate adjacent task tags: '<X><Y>' -> '<X> <Y>'
                    args.task = args.task.replace('><', '> <')
                is_truncation = args.task != ""
                # pick a random prompt registered for this task
                prompt = random.choice(global_prompt_dict[args.task])
                res_text = model.generate(wavs=feats,
                                          wavs_len=feats_lengths,
                                          prompt=prompt)
                for mode in modes:
                    line = "{}\t{}".format(keys[0], res_text[0])
                    files[mode].write(line + '\n')
                utils_file.logging_print(
                    '{} {} {}'.format(batch_idx, keys[0], res_text[0]))
                if batch_idx % 100 == 0:
                    for mode, f in files.items():
                        f.flush()  # force-flush buffered results to disk
                # if batch_idx >= 1000 and is_truncation:
                #     utils_file.logging_info('using the truncate-to-3000 strategy')
                #     break
    for mode, f in files.items():
        f.flush()  # force-flush buffered results to disk
        f.close()


if __name__ == '__main__':
    main()
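
# Example invocation (a sketch, not part of the original script: the script
# name and every path below are placeholders; only conf/prompt_stage4.yaml is
# actually referenced by the code above, so the script must be launched from a
# directory that contains it):
#
#   python recognize_llmasr.py \
#       --config conf/train_llmasr.yaml \
#       --test_data data/test/data.list \
#       --data_type raw \
#       --checkpoint exp/llmasr/final.pt \
#       --result_dir exp/llmasr/test \
#       --task asr --lang zh \
#       --device cuda --gpu 0 --dtype bf16
#
# Hypotheses are written to <result_dir>/llmasr_decode/text, one
# "<utt_key>\t<hypothesis>" line per utterance.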