# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Xiaoyu Chen, Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
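
# Offline recognition entry point for the LLM-ASR model: it loads a WeNet
# dataset, samples a task-specific prompt from conf/prompt_stage4.yaml,
# calls model.generate() on each utterance, and writes "key<TAB>hypothesis"
# lines under --result_dir.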

from __future__ import print_function

import argparse
import copy
import logging
import os
import random

import torch
import yaml
from gxl_ai_utils.utils.utils_model import set_random_seed
from torch.utils.data import DataLoader

from wenet.dataset.dataset import Dataset
from wenet.llm_asr.llmasr_model import LLMASR_Model
from wenet.utils.config import override_config
from wenet.utils.init_model import init_model
from wenet.utils.init_tokenizer import init_tokenizer
from wenet.utils.context_graph import ContextGraph
from wenet.utils.ctc_utils import get_blank_id
from wenet.utils.common import TORCH_NPU_AVAILABLE  # noqa: just ensures torch-npu is checked


def get_args():
    parser = argparse.ArgumentParser(description='recognize with your model')
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--test_data', required=True, help='test data file')
    parser.add_argument('--data_type',
                        default='raw',
                        # choices=['raw', 'shard'],
                        help='train and cv data type')
    parser.add_argument('--gpu',
                        type=int,
                        default=-1,
                        help='gpu id for this rank, -1 for cpu')
    parser.add_argument('--device',
                        type=str,
                        default="cpu",
                        choices=["cpu", "npu", "cuda"],
                        help='accelerator to use')
    parser.add_argument('--dtype',
                        type=str,
                        default='fp32',
                        choices=['fp16', 'fp32', 'bf16'],
                        help='model\'s dtype')
    parser.add_argument('--num_workers',
                        default=0,
                        type=int,
                        help='num of subprocess workers for reading')
    parser.add_argument('--checkpoint', required=True, help='checkpoint model')
    parser.add_argument('--beam_size',
                        type=int,
                        default=10,
                        help='beam size for search')
    parser.add_argument('--length_penalty',
                        type=float,
                        default=0.0,
                        help='length penalty')
    parser.add_argument('--blank_penalty',
                        type=float,
                        default=0.0,
                        help='blank penalty')
    parser.add_argument('--result_dir', required=True, help='asr result dir')
    parser.add_argument('--batch_size',
                        type=int,
                        default=16,
                        help='batch size for decoding')
    parser.add_argument('--modes',
                        nargs='+',
                        help="""decoding mode, support the following:
                        attention
                        ctc_greedy_search
                        ctc_prefix_beam_search
                        attention_rescoring
                        rnnt_greedy_search
                        rnnt_beam_search
                        rnnt_beam_attn_rescoring
                        ctc_beam_td_attn_rescoring
                        hlg_onebest
                        hlg_rescore
                        paraformer_greedy_search
                        paraformer_beam_search""")
    parser.add_argument('--search_ctc_weight',
                        type=float,
                        default=1.0,
                        help='ctc weight for nbest generation')
    parser.add_argument('--search_transducer_weight',
                        type=float,
                        default=0.0,
                        help='transducer weight for nbest generation')
    parser.add_argument('--ctc_weight',
                        type=float,
                        default=0.0,
                        help='ctc weight for rescoring weight in \
                              attention rescoring decode mode \
                              ctc weight for rescoring weight in \
                              transducer attention rescore decode mode')
    parser.add_argument('--transducer_weight',
                        type=float,
                        default=0.0,
                        help='transducer weight for rescoring weight in '
                        'transducer attention rescore mode')
    parser.add_argument('--attn_weight',
                        type=float,
                        default=0.0,
                        help='attention weight for rescoring weight in '
                        'transducer attention rescore mode')
    parser.add_argument('--decoding_chunk_size',
                        type=int,
                        default=-1,
                        help='''decoding chunk size,
                        <0: for decoding, use full chunk.
                        >0: for decoding, use fixed chunk size as set.
                        0: used for training, it's prohibited here''')
    parser.add_argument('--num_decoding_left_chunks',
                        type=int,
                        default=-1,
                        help='number of left chunks for decoding')
    parser.add_argument('--simulate_streaming',
                        action='store_true',
                        help='simulate streaming inference')
    parser.add_argument('--reverse_weight',
                        type=float,
                        default=0.0,
                        help='''right to left weight for attention rescoring
                        decode mode''')
    parser.add_argument('--override_config',
                        action='append',
                        default=[],
                        help="override yaml config")
    parser.add_argument('--word',
                        default='',
                        type=str,
                        help='word file, only used for hlg decode')
    parser.add_argument('--hlg',
                        default='',
                        type=str,
                        help='hlg file, only used for hlg decode')
    parser.add_argument('--lm_scale',
                        type=float,
                        default=0.0,
                        help='lm scale for hlg attention rescore decode')
    parser.add_argument('--decoder_scale',
                        type=float,
                        default=0.0,
                        help='decoder scale for hlg attention rescore decode')
    parser.add_argument('--r_decoder_scale',
                        type=float,
                        default=0.0,
                        help='right-to-left decoder scale for hlg attention '
                        'rescore decode')
    parser.add_argument(
        '--context_bias_mode',
        type=str,
        default='',
        help='''Context bias mode, selectable from the following
        options: decoding-graph, deep-biasing''')
    parser.add_argument('--context_list_path',
                        type=str,
                        default='',
                        help='Context list path')
    parser.add_argument('--context_graph_score',
                        type=float,
                        default=0.0,
                        help='''The higher the score, the greater the degree of
                        bias using decoding-graph for biasing''')
    parser.add_argument('--use_lora',
                        action='store_true',
                        help='whether to use lora')
parser.add_argument("--lora_ckpt_path", | |
default=None, | |
type=str, | |
help="lora checkpoint path.") | |
    parser.add_argument('--task',
                        type=str,
                        default='asr',
                        help='task tag, e.g. <TRANSCRIBE>')
    parser.add_argument('--lang',
                        type=str,
                        default='zh',
                        help='language of the test data')
    args = parser.parse_args()
    print(args)
    return args
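

# Example invocation (script name, paths, and task tag are illustrative):
#   python recognize_llmasr.py --config conf/train.yaml \
#       --test_data data/test/data.list --checkpoint exp/llmasr/final.pt \
#       --result_dir exp/llmasr/decode --device cuda --gpu 0 --task '<TRANSCRIBE>'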


def main():
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
    set_random_seed(777)
    if args.gpu != -1:
        # keep backward compatibility with the original --gpu usage
        args.device = "cuda"
    if "cuda" in args.device:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

    with open(args.config, 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    if len(args.override_config) > 0:
        configs = override_config(configs, args.override_config)
    configs['dataset_conf']['filter_conf']['filter_no_extra_info'] = False
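
    # Decoding-time dataset config: copy the training dataset_conf, cap the
    # input length at 3000 frames (~30 s, Whisper's limit), relax all other
    # filters, disable every augmentation, and force static batch_size=1 so
    # utterances are decoded one by one.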
    test_conf = copy.deepcopy(configs['dataset_conf'])
    test_conf['filter_conf']['max_length'] = 3000  # Whisper handles at most 30 s; was 102400
    test_conf['filter_conf']['min_length'] = 0
    test_conf['filter_conf']['token_max_length'] = 102400
    test_conf['filter_conf']['token_min_length'] = 0
    test_conf['filter_conf']['max_output_input_ratio'] = 102400
    test_conf['filter_conf']['min_output_input_ratio'] = 0
    test_conf['speed_perturb'] = False
    test_conf['spec_aug'] = False
    test_conf['spec_sub'] = False
    test_conf['spec_trim'] = False
    test_conf['shuffle'] = True
    test_conf['sort'] = False
    test_conf['cycle'] = 1
    test_conf['list_shuffle'] = True
    if 'fbank_conf' in test_conf:
        test_conf['fbank_conf']['dither'] = 0.0
    elif 'mfcc_conf' in test_conf:
        test_conf['mfcc_conf']['dither'] = 0.0
    test_conf['batch_conf']['batch_type'] = "static"
    test_conf['batch_conf']['batch_size'] = 1
    test_conf['split_num'] = 1

    tokenizer = init_tokenizer(configs)
    test_dataset = Dataset(args.data_type,
                           args.test_data,
                           tokenizer,
                           test_conf,
                           partition=False)
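
    # WeNet's Dataset is an IterableDataset that yields fully batched tensors
    # itself, so the DataLoader is built with batch_size=None to avoid
    # re-batching.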
    test_data_loader = DataLoader(test_dataset,
                                  batch_size=None,
                                  num_workers=args.num_workers)

    # Init asr model from configs
    args.jit = False
    model, configs = init_model(args, configs)
    device = torch.device(args.device)
    model: LLMASR_Model = model.to(device)
    model.eval()
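
    # Resolve the autocast compute dtype requested via --dtype (fp32 default).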
    dtype = torch.float32
    if args.dtype == 'fp16':
        dtype = torch.float16
    elif args.dtype == 'bf16':
        dtype = torch.bfloat16
    logging.info("compute dtype is {}".format(dtype))

    context_graph = None
    if 'decoding-graph' in args.context_bias_mode:
        context_graph = ContextGraph(args.context_list_path,
                                     tokenizer.symbol_table,
                                     configs['tokenizer_conf']['bpe_path'],
                                     args.context_graph_score)

    _, blank_id = get_blank_id(configs, tokenizer.symbol_table)
    logging.info("blank_id is {}".format(blank_id))

    # TODO(Dinghao Zhou): Support RNN-T related decoding
    # TODO(Lv Xiang): Support k2 related decoding
    # TODO(Kaixun Huang): Support context graph
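    # One result file per decoding mode. This LLM-ASR script always decodes
    # with the single hard-coded 'llmasr_decode' mode, so every hypothesis
    # goes to <result_dir>/llmasr_decode/text.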
    files = {}
    modes = ['llmasr_decode']
    for mode in modes:
        dir_name = os.path.join(args.result_dir, mode)
        os.makedirs(dir_name, exist_ok=True)
        file_name = os.path.join(dir_name, 'text')
        files[mode] = open(file_name, 'w', encoding='utf-8')
    max_format_len = max(len(mode) for mode in modes)

    # Get prompt config
    from gxl_ai_utils.utils import utils_file
    global_prompt_dict = utils_file.load_dict_from_yaml('conf/prompt_stage4.yaml')
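
    # global_prompt_dict maps a task tag (e.g. '<TRANSCRIBE>') to a list of
    # candidate prompt strings in conf/prompt_stage4.yaml; one prompt is
    # sampled at random for every batch in the loop below.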
    with torch.cuda.amp.autocast(enabled=True,
                                 dtype=dtype,
                                 cache_enabled=False):
        with torch.no_grad():
            # logging.info(f'utt_num: {utt_num}')
            for batch_idx, batch in enumerate(test_data_loader):
                keys = batch["keys"]
                feats = batch["feats"].to(device)
                target = batch["target"].to(device)
                feats_lengths = batch["feats_lengths"].to(device)
                target_lengths = batch["target_lengths"].to(device)
                batch_size = feats.size(0)
                if '><' in args.task:
                    args.task = args.task.replace('><', '> <')
                if args.task == "<TRANSCRIBE>" or args.task == "<transcribe>":
                    is_truncation = False
                else:
                    is_truncation = True
                random_index = random.randint(0, len(global_prompt_dict[args.task]) - 1)
                prompt = global_prompt_dict[args.task][random_index]
                # print(args.task, prompt)
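                # model.generate consumes the raw features plus the sampled
                # prompt and returns one hypothesis string per utterance
                # (batch_size is forced to 1 above, hence keys[0]/res_text[0]).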
                res_text = model.generate(wavs=feats, wavs_len=feats_lengths, prompt=prompt)
                for mode in modes:
                    line = "{}\t{}".format(keys[0], res_text[0])
                    files[mode].write(line + '\n')
                utils_file.logging_print('{} {} {}'.format(batch_idx, keys[0], res_text[0]))
                if batch_idx % 100 == 0:
                    for mode, f in files.items():
                        f.flush()  # force-flush buffered results to disk
                # if batch_idx >= 1000 and is_truncation:
                #     utils_file.logging_info('apply the truncate-to-3000 strategy')
                #     break
    for mode, f in files.items():
        f.flush()  # force-flush buffered results to disk
        f.close()


if __name__ == '__main__':
    main()