"""ASR helpers built on sherpa-onnx.

The module keeps a process-wide cache of loaded recognizers and exposes
``ASRStream``, an asyncio wrapper that accepts raw PCM bytes and produces
``ASRResult`` objects.
"""
from typing import *
import asyncio
import logging
import os
import time

import numpy as np
import sherpa_onnx

logger = logging.getLogger(__file__)

# Cache of loaded engines keyed by model name; the key 'vad' holds the voice
# activity detector used by the offline (non-streaming) recognizers.
# NOTE(review): the cache key does not include the sample rate, so a second
# call with a different ``samplerate`` gets the first engine back -- confirm
# that callers always use one sample rate.
_asr_engines: Dict[str, Any] = {}


class ASRResult:
    """One recognition result.

    Attributes:
        text: recognized text.
        finished: ``True`` for a final segment result, ``False`` for a partial
            (still changing) hypothesis.
        idx: index of the segment the text belongs to.
    """

    def __init__(self, text: str, finished: bool, idx: int):
        self.text = text
        self.finished = finished
        self.idx = idx

    def __repr__(self) -> str:
        return (f"{type(self).__name__}(text={self.text!r}, "
                f"finished={self.finished!r}, idx={self.idx!r})")

    def to_dict(self) -> Dict[str, Any]:
        """Return the result as a plain dict with keys text/finished/idx."""
        return {"text": self.text, "finished": self.finished, "idx": self.idx}


class ASRStream:
    """Asyncio bridge between a PCM byte source and a sherpa-onnx recognizer.

    Audio goes in through :meth:`write`, results come out of :meth:`read`.
    A background task started by :meth:`start` moves data between the two
    queues; :meth:`close` stops it.
    """

    def __init__(self,
                 recognizer: Union[sherpa_onnx.OnlineRecognizer,
                                   sherpa_onnx.OfflineRecognizer],
                 sample_rate: int) -> None:
        self.recognizer = recognizer
        self.inbuf = asyncio.Queue()
        self.outbuf = asyncio.Queue()
        self.sample_rate = sample_rate
        self.is_closed = False
        self.online = isinstance(recognizer, sherpa_onnx.OnlineRecognizer)
        # Strong reference to the worker task: the event loop only keeps weak
        # references, so an unreferenced task may be garbage-collected while
        # it is still running.
        self._task: Optional[asyncio.Task] = None

    async def start(self):
        """Spawn the background recognition task matching the recognizer."""
        worker = self.run_online if self.online else self.run_offline
        self._task = asyncio.create_task(self._supervise(worker))

    async def _supervise(self, worker: Callable[[], Awaitable[None]]) -> None:
        """Run *worker*; on failure log it and unblock any pending reader."""
        try:
            await worker()
        except Exception:
            logger.exception('asr: recognizer task failed')
            # Without this a caller awaiting read() would wait forever.
            self.outbuf.put_nowait(None)

    async def run_online(self):
        """Feed queued samples to the streaming recognizer.

        Emits a non-finished ``ASRResult`` whenever the hypothesis changes and
        a finished one when the recognizer reports an endpoint, after which
        the recognizer stream is reset and the segment index advances.
        """
        stream = self.recognizer.create_stream()
        last_result = ""
        segment_id = 0
        logger.info('asr: start real-time recognizer')
        while not self.is_closed:
            samples = await self.inbuf.get()
            if samples is None:
                # Sentinel pushed by close() to wake a worker blocked on get().
                break
            stream.accept_waveform(self.sample_rate, samples)
            while self.recognizer.is_ready(stream):
                self.recognizer.decode_stream(stream)

            is_endpoint = self.recognizer.is_endpoint(stream)
            result = self.recognizer.get_result(stream)

            if result and (last_result != result):
                last_result = result
                logger.info(f' > {segment_id}:{result}')
                self.outbuf.put_nowait(
                    ASRResult(result, False, segment_id))

            if is_endpoint:
                if result:
                    logger.info(f'{segment_id}: {result}')
                    self.outbuf.put_nowait(
                        ASRResult(result, True, segment_id))
                    segment_id += 1
                self.recognizer.reset(stream)

    async def run_offline(self):
        """Segment queued samples with the VAD and decode each segment.

        Every non-empty decoded segment is emitted as a finished
        ``ASRResult``.
        """
        # NOTE(review): this is the single detector stored in the module-level
        # cache, so concurrent offline streams would feed the same VAD
        # instance -- confirm only one offline stream is active at a time.
        vad = _asr_engines['vad']
        segment_id = 0
        st = None
        while not self.is_closed:
            samples = await self.inbuf.get()
            if samples is None:
                # Sentinel pushed by close() to wake a worker blocked on get().
                break
            vad.accept_waveform(samples)
            while not vad.empty():
                if not st:
                    # Timing starts at the first segment of this batch.
                    st = time.time()
                stream = self.recognizer.create_stream()
                stream.accept_waveform(self.sample_rate, vad.front.samples)

                vad.pop()
                self.recognizer.decode_stream(stream)

                result = stream.result.text.strip()
                if result:
                    duration = time.time() - st
                    logger.info(f'{segment_id}:{result} ({duration:.2f}s)')
                    self.outbuf.put_nowait(ASRResult(result, True, segment_id))
                    segment_id += 1
            st = None

    async def close(self):
        """Stop the stream: wake the worker and signal end-of-results.

        A ``None`` is queued on the output side so a pending :meth:`read`
        returns, and on the input side so a worker blocked waiting for audio
        wakes up and exits instead of lingering forever.
        """
        self.is_closed = True
        self.inbuf.put_nowait(None)
        self.outbuf.put_nowait(None)

    async def write(self, pcm_bytes: bytes):
        """Queue a chunk of 16-bit signed PCM audio for recognition.

        The bytes are interpreted as ``np.int16`` samples and scaled by
        1/32768 into ``float32`` before being queued.
        """
        pcm_data = np.frombuffer(pcm_bytes, dtype=np.int16)
        samples = pcm_data.astype(np.float32) / 32768.0
        self.inbuf.put_nowait(samples)

    async def read(self) -> Optional[ASRResult]:
        """Wait for the next result; ``None`` means the stream has ended."""
        return await self.outbuf.get()


def _require_model_dir(models_root: str, name: str, tag: str = 'asr') -> str:
    """Return ``models_root/name`` or raise ``ValueError`` if it is missing."""
    d = os.path.join(models_root, name)
    if not os.path.exists(d):
        raise ValueError(f"{tag}: model not found {d}")
    return d


def create_zipformer(samplerate: int, args) -> sherpa_onnx.OnlineRecognizer:
    """Create the streaming bilingual (zh/en) zipformer transducer.

    Reads ``args.models_root``, ``args.asr_provider`` and ``args.threads``.

    Raises:
        ValueError: if the model directory does not exist.
    """
    d = _require_model_dir(
        args.models_root,
        'sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20')

    encoder = os.path.join(d, "encoder-epoch-99-avg-1.onnx")
    decoder = os.path.join(d, "decoder-epoch-99-avg-1.onnx")
    joiner = os.path.join(d, "joiner-epoch-99-avg-1.onnx")
    tokens = os.path.join(d, "tokens.txt")

    recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
        tokens=tokens,
        encoder=encoder,
        decoder=decoder,
        joiner=joiner,
        provider=args.asr_provider,
        num_threads=args.threads,
        sample_rate=samplerate,
        feature_dim=80,
        enable_endpoint_detection=True,
        rule1_min_trailing_silence=2.4,
        rule2_min_trailing_silence=1.2,
        rule3_min_utterance_length=20,  # it essentially disables this rule
    )
    return recognizer


def create_sensevoice(samplerate: int, args) -> sherpa_onnx.OfflineRecognizer:
    """Create the offline SenseVoice recognizer.

    Reads ``args.models_root``, ``args.threads`` and ``args.asr_lang``.

    Raises:
        ValueError: if the model directory does not exist.
    """
    d = _require_model_dir(
        args.models_root,
        'sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17')

    recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
        model=os.path.join(d, 'model.onnx'),
        tokens=os.path.join(d, 'tokens.txt'),
        num_threads=args.threads,
        sample_rate=samplerate,
        use_itn=True,
        debug=0,
        language=args.asr_lang,
    )
    return recognizer


def create_paraformer_trilingual(samplerate: int,
                                 args) -> sherpa_onnx.OfflineRecognizer:
    """Create the offline trilingual (zh/cantonese/en) paraformer recognizer.

    Reads ``args.models_root``, ``args.threads`` and ``args.asr_provider``.

    Raises:
        ValueError: if the model directory does not exist.
    """
    d = _require_model_dir(
        args.models_root,
        'sherpa-onnx-paraformer-trilingual-zh-cantonese-en')

    recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
        paraformer=os.path.join(d, 'model.onnx'),
        tokens=os.path.join(d, 'tokens.txt'),
        num_threads=args.threads,
        sample_rate=samplerate,
        debug=0,
        provider=args.asr_provider,
    )
    return recognizer


def create_paraformer_en(samplerate: int,
                         args) -> sherpa_onnx.OfflineRecognizer:
    """Create the offline English paraformer recognizer.

    Reads ``args.models_root``, ``args.threads`` and ``args.asr_provider``.

    Raises:
        ValueError: if the model directory does not exist.
    """
    d = _require_model_dir(args.models_root, 'sherpa-onnx-paraformer-en')

    # NOTE(review): ``use_itn`` is passed here but not in the trilingual
    # paraformer factory -- confirm ``from_paraformer`` accepts it in the
    # installed sherpa-onnx version.
    recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
        paraformer=os.path.join(d, 'model.onnx'),
        tokens=os.path.join(d, 'tokens.txt'),
        num_threads=args.threads,
        sample_rate=samplerate,
        use_itn=True,
        debug=0,
        provider=args.asr_provider,
    )
    return recognizer


def load_asr_engine(
        samplerate: int, args
) -> Union[sherpa_onnx.OnlineRecognizer, sherpa_onnx.OfflineRecognizer]:
    """Return the recognizer for ``args.asr_model``, loading it on first use.

    Offline models additionally (re)load the VAD into the cache under the
    key ``'vad'``.

    Raises:
        ValueError: if ``args.asr_model`` is unknown or a model directory is
            missing.
    """
    cache_engine = _asr_engines.get(args.asr_model)
    if cache_engine is not None:
        return cache_engine

    # model name -> (factory, whether the model needs the VAD front end)
    factories = {
        'zipformer-bilingual': (create_zipformer, False),
        'sensevoice': (create_sensevoice, True),
        'paraformer-trilingual': (create_paraformer_trilingual, True),
        'paraformer-en': (create_paraformer_en, True),
    }

    st = time.time()
    try:
        factory, needs_vad = factories[args.asr_model]
    except KeyError:
        raise ValueError(f"asr: unknown model {args.asr_model}") from None

    cache_engine = factory(samplerate, args)
    if needs_vad:
        _asr_engines['vad'] = load_vad_engine(samplerate, args)

    _asr_engines[args.asr_model] = cache_engine
    logger.info(f"asr: engine loaded in {time.time() - st:.2f}s")
    return cache_engine


def load_vad_engine(
        samplerate: int,
        args,
        min_silence_duration: float = 0.25,
        buffer_size_in_seconds: int = 100
) -> sherpa_onnx.VoiceActivityDetector:
    """Create a silero-VAD based voice activity detector.

    Args:
        samplerate: sample rate stored in the VAD config.
        args: provides ``models_root``, ``asr_provider`` and ``threads``.
        min_silence_duration: value assigned to the silero
            ``min_silence_duration`` setting.
        buffer_size_in_seconds: forwarded to ``VoiceActivityDetector``.

    Raises:
        ValueError: if the ``silero_vad`` model directory does not exist.
    """
    config = sherpa_onnx.VadModelConfig()
    d = _require_model_dir(args.models_root, 'silero_vad', tag='vad')

    config.silero_vad.model = os.path.join(d, 'silero_vad.onnx')
    config.silero_vad.min_silence_duration = min_silence_duration
    config.sample_rate = samplerate
    config.provider = args.asr_provider
    config.num_threads = args.threads

    vad = sherpa_onnx.VoiceActivityDetector(
        config,
        buffer_size_in_seconds=buffer_size_in_seconds)
    return vad


async def start_asr_stream(samplerate: int, args) -> ASRStream:
    """Create an ``ASRStream`` for ``args.asr_model`` and start its worker."""
    stream = ASRStream(load_asr_engine(samplerate, args), samplerate)
    await stream.start()
    return stream