#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
Offline ASR demo: load a wav file, resample it, run it through a sherpa
offline recognizer (wenet Chinese model), and print the decoded text.
"""
import argparse
import os
from pathlib import Path
import sys
import tempfile

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import librosa
import numpy as np
import sherpa
from scipy.io import wavfile
import torch
import torchaudio

from project_settings import project_path, temp_directory
from toolbox.k2_sherpa.utils import audio_convert
from toolbox.k2_sherpa import decode, models


def get_args():
    """Parse command-line arguments: model directory, input wav, sample rate."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_dir",
        default=(project_path / "pretrained_models/huggingface/csukuangfj/wenet-chinese-model").as_posix(),
        type=str
    )
    parser.add_argument(
        "--in_filename",
        default=(project_path / "data/test_wavs/paraformer-zh/si_chuan_hua.wav").as_posix(),
        type=str
    )
    parser.add_argument("--sample_rate", default=16000, type=int)
    args = parser.parse_args()
    return args


def main():
    args = get_args()

    # audio convert: load (and resample) to the target rate, then write a
    # 16-bit PCM temp wav for the recognizer.
    signal, sample_rate = librosa.load(args.in_filename, sr=args.sample_rate)

    # librosa returns float32 in [-1.0, 1.0]. Scale to int16 range and clip:
    # a sample at exactly 1.0 would become 32768, which wraps to -32768 when
    # cast to int16 — clipping to [-32768, 32767] prevents that.
    signal = np.clip(signal * 32768.0, -32768, 32767).astype(np.int16)

    temp_file = temp_directory / "temp.wav"
    wavfile.write(
        temp_file.as_posix(),
        rate=args.sample_rate,
        data=signal
    )

    # load recognizer: resolve model and token files from the local model dir.
    m_dict = models.model_map["Chinese"][0]

    local_model_dir = Path(args.model_dir)
    nn_model_file = local_model_dir / m_dict["nn_model_file"]
    tokens_file = local_model_dir / m_dict["tokens_file"]

    # fbank feature extraction config matching the pretrained model
    # (80 mel bins, no dithering, raw un-normalized samples).
    feat_config = sherpa.FeatureConfig(normalize_samples=False)
    feat_config.fbank_opts.frame_opts.samp_freq = args.sample_rate
    feat_config.fbank_opts.mel_opts.num_bins = 80
    feat_config.fbank_opts.frame_opts.dither = 0

    config = sherpa.OfflineRecognizerConfig(
        nn_model=nn_model_file.as_posix(),
        tokens=tokens_file.as_posix(),
        use_gpu=False,
        feat_config=feat_config,
        decoding_method="greedy_search",
        num_active_paths=2,
    )

    recognizer = sherpa.OfflineRecognizer(config)

    text = decode.decode_by_recognizer(recognizer=recognizer,
                                       filename=temp_file.as_posix(),
                                       )
    print("text: {}".format(text))
    return


if __name__ == "__main__":
    main()