#!/usr/bin/env python3
"""
Real-time ASR from the microphone, plus offline TTS, using sherpa-onnx.
"""
import argparse
import asyncio
import logging
import os
import struct
import time

import sherpa_onnx
import soundfile

try:
    import pyaudio
except ImportError:
    raise ImportError('Please install pyaudio with `pip install pyaudio`')

logger = logging.getLogger(__name__)

SAMPLE_RATE = 16000

pactx = pyaudio.PyAudio()
models_root = None  # resolved at startup by probing for a models/ directory
num_threads: int = 1


def create_zipformer(args) -> sherpa_onnx.OnlineRecognizer:
    """Create a streaming (online) zipformer transducer recognizer."""
    d = os.path.join(
        models_root,
        'sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20')
    encoder = os.path.join(d, "encoder-epoch-99-avg-1.onnx")
    decoder = os.path.join(d, "decoder-epoch-99-avg-1.onnx")
    joiner = os.path.join(d, "joiner-epoch-99-avg-1.onnx")
    tokens = os.path.join(d, "tokens.txt")

    recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
        tokens=tokens,
        encoder=encoder,
        decoder=decoder,
        joiner=joiner,
        provider=args.provider,
        num_threads=num_threads,
        sample_rate=SAMPLE_RATE,
        feature_dim=80,
        enable_endpoint_detection=True,
        rule1_min_trailing_silence=2.4,
        rule2_min_trailing_silence=1.2,
        rule3_min_utterance_length=20,  # it essentially disables this rule
    )
    return recognizer


def create_sensevoice(args) -> sherpa_onnx.OfflineRecognizer:
    """Create a non-streaming (offline) SenseVoice recognizer."""
    model = os.path.join(
        models_root,
        'sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17', 'model.onnx')
    tokens = os.path.join(
        models_root,
        'sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17', 'tokens.txt')
    recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
        model=model,
        tokens=tokens,
        num_threads=num_threads,
        use_itn=True,
        debug=0,
        language=args.lang,
    )
    return recognizer


async def run_online(buf, recognizer):
    """Feed microphone samples to the streaming recognizer as they arrive."""
    stream = recognizer.create_stream()
    last_result = ""
    segment_id = 0
    logger.info('Start real-time recognizer')
    while True:
        samples = await buf.get()
        stream.accept_waveform(SAMPLE_RATE, samples)

        while recognizer.is_ready(stream):
            recognizer.decode_stream(stream)

        is_endpoint = recognizer.is_endpoint(stream)
        result = recognizer.get_result(stream)

        # Log partial results as they change; finalize a segment on endpoint.
        if result and (last_result != result):
            last_result = result
            logger.info(f' > {segment_id}:{result}')

        if is_endpoint:
            if result:
                logger.info(f'{segment_id}: {result}')
                segment_id += 1
            recognizer.reset(stream)


async def run_offline(buf, recognizer):
    """Segment microphone samples with silero VAD, then decode each segment."""
    config = sherpa_onnx.VadModelConfig()
    config.silero_vad.model = os.path.join(
        models_root, 'silero_vad', 'silero_vad.onnx')
    config.silero_vad.min_silence_duration = 0.25
    config.sample_rate = SAMPLE_RATE
    vad = sherpa_onnx.VoiceActivityDetector(
        config, buffer_size_in_seconds=100)

    logger.info('Start offline recognizer with VAD')
    texts = []
    while True:
        samples = await buf.get()
        vad.accept_waveform(samples)
        # Each VAD segment is decoded as an independent utterance.
        while not vad.empty():
            stream = recognizer.create_stream()
            stream.accept_waveform(SAMPLE_RATE, vad.front.samples)
            vad.pop()
            recognizer.decode_stream(stream)

            text = stream.result.text.strip().lower()
            if len(text):
                idx = len(texts)
                texts.append(text)
                logger.info(f"{idx}: {text}")


async def handle_asr(args):
    if args.model == 'zipformer':
        recognizer = create_zipformer(args)
        action_func = run_online
    elif args.model == 'sensevoice':
        recognizer = create_sensevoice(args)
        action_func = run_offline
    else:
        raise ValueError(f'Unknown model: {args.model}')

    # The recorder task produces audio chunks; the ASR task consumes them.
    buf = asyncio.Queue()
    recorder_task = asyncio.create_task(run_record(buf))
    asr_task = asyncio.create_task(action_func(buf, recognizer))
    await asyncio.gather(asr_task, recorder_task)


async def handle_tts(args):
    model = os.path.join(
        models_root, 'vits-melo-tts-zh_en', 'model.onnx')
    lexicon = os.path.join(
        models_root, 'vits-melo-tts-zh_en', 'lexicon.txt')
    dict_dir = os.path.join(
        models_root, 'vits-melo-tts-zh_en', 'dict')
    tokens = os.path.join(
        models_root, 'vits-melo-tts-zh_en', 'tokens.txt')
    tts_config = sherpa_onnx.OfflineTtsConfig(
        model=sherpa_onnx.OfflineTtsModelConfig(
            vits=sherpa_onnx.OfflineTtsVitsModelConfig(
                model=model,
                lexicon=lexicon,
                dict_dir=dict_dir,
                tokens=tokens,
            ),
            provider=args.provider,
            debug=0,
            num_threads=num_threads,
        ),
        max_num_sentences=args.max_num_sentences,
    )
    if not tts_config.validate():
        raise ValueError("Please check your config")

    tts = sherpa_onnx.OfflineTts(tts_config)

    start = time.time()
    audio = tts.generate(args.text, sid=args.sid, speed=args.speed)
    elapsed_seconds = time.time() - start
    audio_duration = len(audio.samples) / audio.sample_rate
    real_time_factor = elapsed_seconds / audio_duration

    if args.output:
        soundfile.write(
            args.output,
            audio.samples,
            samplerate=audio.sample_rate,
            subtype="PCM_16",
        )
        logger.info(f"Saved to {args.output}")

    logger.info(f"The text is '{args.text}'")
    logger.info(f"Elapsed seconds: {elapsed_seconds:.3f}")
    logger.info(f"Audio duration in seconds: {audio_duration:.3f}")
    logger.info(
        f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}")


async def run_record(buf: asyncio.Queue[list[float]]):
    loop = asyncio.get_running_loop()

    def on_input(in_data, frame_count, time_info, status):
        # Convert 16-bit little-endian PCM to floats in [-1.0, 1.0).
        samples = [
            v / 32768.0
            for v in struct.unpack('<' + 'h' * frame_count, in_data)]
        # PyAudio invokes this callback on its own thread, so hand the
        # samples to the event loop thread-safely; loop.create_task()
        # must not be called from another thread.
        loop.call_soon_threadsafe(buf.put_nowait, samples)
        return (None, pyaudio.paContinue)

    frame_size = 320  # 20 ms per chunk at 16 kHz
    recorder = pactx.open(format=pyaudio.paInt16,
                          channels=1,
                          rate=SAMPLE_RATE,
                          input=True,
                          frames_per_buffer=frame_size,
                          stream_callback=on_input)
    recorder.start_stream()
    logger.info('Start recording')

    while recorder.is_active():
        await asyncio.sleep(0.1)


async def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--provider', default='cpu',
                        help='onnxruntime provider, default is cpu, use cuda for GPU')
    subparsers = parser.add_subparsers(help='commands help')

    asr_parser = subparsers.add_parser('asr', help='run asr mode')
    asr_parser.add_argument('--model', default='zipformer',
                            help='model name: zipformer (streaming) or '
                                 'sensevoice (offline with VAD), default is zipformer')
    asr_parser.add_argument('--lang', default='zh',
                            help='language, default is zh')
    asr_parser.set_defaults(func=handle_asr)

    tts_parser = subparsers.add_parser('tts', help='run tts mode')
    tts_parser.add_argument('--sid', type=int, default=0,
                            help="""Speaker ID. Used only for multi-speaker models, e.g.
models trained using the VCTK dataset. Not used for single-speaker
models, e.g., models trained using the LJ speech dataset.
""")
    tts_parser.add_argument('--output', type=str, default='output.wav',
                            help='output file name, default is output.wav')
    tts_parser.add_argument(
        "--speed",
        type=float,
        default=1.0,
        help="Speech speed. Larger->faster; smaller->slower",
    )
    tts_parser.add_argument(
        "--max-num-sentences",
        type=int,
        default=2,
        help="""Max number of sentences in a batch to avoid OOM if the input
text is very long. Set it to -1 to process all the sentences in a
single batch. A smaller value does not mean it is slower compared
to a larger one on CPU.
""",
    )
    tts_parser.add_argument(
        "text",
        type=str,
        help="The input text to generate audio for",
    )
    tts_parser.set_defaults(func=handle_tts)

    args = parser.parse_args()
    if hasattr(args, 'func'):
        await args.func(args)
    else:
        parser.print_help()


if __name__ == '__main__':
    logging.basicConfig(
        format='%(levelname)s: %(asctime)s %(name)s:%(lineno)s %(message)s')
    logging.getLogger().setLevel(logging.INFO)

    # Make sure there is at least one usable input device before starting.
    painfo = pactx.get_default_input_device_info()
    assert painfo['maxInputChannels'] >= 1, 'No input device'
    logger.info('Default input device: %s', painfo['name'])

    # Look for the models directory next to, or above, this script.
    for d in ['.', '..', '../..']:
        if os.path.isdir(f'{d}/models'):
            models_root = f'{d}/models'
            break
    assert models_root, 'Could not find models directory'

    asyncio.run(main())