|
from typing import * |
|
import logging |
|
import time |
|
import logging |
|
import sherpa_onnx |
|
import os |
|
import asyncio |
|
import numpy as np |
|
|
|
# Name the logger after the module rather than its file path, so it takes
# part in the dotted logger hierarchy and can be configured per package.
logger = logging.getLogger(__name__)

# Process-wide cache of loaded engines: recognizers are stored under their
# ``args.asr_model`` name, and the voice-activity detector under 'vad'.
_asr_engines = {}
|
|
|
|
|
class ASRResult:
    """A piece of recognized text produced by an ASR stream.

    Attributes:
        text: The recognized text.
        finished: Flag supplied by the producer; the streams in this module
            pass ``False`` for partial results and ``True`` for final ones.
        idx: Index of the segment the text belongs to.
    """

    def __init__(self, text: str, finished: bool, idx: int):
        self.text, self.finished, self.idx = text, finished, idx

    def to_dict(self):
        """Return the result as a dict with keys ``text``, ``finished``, ``idx``."""
        return dict(text=self.text, finished=self.finished, idx=self.idx)
|
|
|
|
|
class ASRStream:
    """Asynchronous speech-recognition stream on top of a sherpa-onnx recognizer.

    Audio is pushed in with :meth:`write`, a background task started by
    :meth:`start` runs the recognizer, and results are pulled out with
    :meth:`read`.  A ``None`` returned by :meth:`read` marks the end of the
    stream (it was closed, or the recognizer task failed).
    """

    def __init__(
        self,
        recognizer: "Union[sherpa_onnx.OnlineRecognizer, sherpa_onnx.OfflineRecognizer]",
        sample_rate: int,
    ) -> None:
        """Bind the stream to a recognizer.

        Args:
            recognizer: Online (streaming) or offline recognizer to drive.
            sample_rate: Sample rate of the audio passed to :meth:`write`;
                forwarded to the recognizer streams as-is.
        """
        self.recognizer = recognizer
        self.inbuf = asyncio.Queue()   # float32 sample chunks, None = stop
        self.outbuf = asyncio.Queue()  # ASRResult objects, None = end of stream
        self.sample_rate = sample_rate
        self.is_closed = False
        self.online = isinstance(recognizer, sherpa_onnx.OnlineRecognizer)
        # Strong reference to the worker task: the event loop only keeps weak
        # references, so an unreferenced task may be garbage-collected mid-run.
        self._task = None

    async def start(self):
        """Launch the background recognition task matching the recognizer kind."""
        worker = self.run_online() if self.online else self.run_offline()
        self._task = asyncio.create_task(worker)
        self._task.add_done_callback(self._on_worker_done)

    def _on_worker_done(self, task) -> None:
        """Log a worker failure and unblock any reader waiting in read()."""
        if task.cancelled():
            return
        error = task.exception()
        if error is None:
            return
        logger.error('asr: recognizer task failed', exc_info=error)
        # Without this a consumer awaiting read() would hang forever.
        self.outbuf.put_nowait(None)

    async def run_online(self):
        """Worker for streaming recognizers.

        Emits a non-finished ASRResult whenever the running text changes and a
        finished one when the recognizer reports an endpoint.
        """
        stream = self.recognizer.create_stream()
        last_result = ""
        segment_id = 0
        logger.info('asr: start real-time recognizer')
        while not self.is_closed:
            samples = await self.inbuf.get()
            if samples is None:
                # Sentinel queued by close(): stop instead of waiting forever.
                break
            stream.accept_waveform(self.sample_rate, samples)
            while self.recognizer.is_ready(stream):
                self.recognizer.decode_stream(stream)

            is_endpoint = self.recognizer.is_endpoint(stream)
            result = self.recognizer.get_result(stream)

            if result and (last_result != result):
                last_result = result
                logger.info(f' > {segment_id}:{result}')
                self.outbuf.put_nowait(
                    ASRResult(result, False, segment_id))

            if is_endpoint:
                if result:
                    logger.info(f'{segment_id}: {result}')
                    self.outbuf.put_nowait(
                        ASRResult(result, True, segment_id))
                    segment_id += 1
                self.recognizer.reset(stream)
                # Start the next segment with a clean slate; otherwise a first
                # partial identical to the previous text would be suppressed.
                last_result = ""

    async def run_offline(self):
        """Worker for offline recognizers.

        Feeds incoming audio to the VAD and decodes every speech segment the
        VAD yields, emitting one finished ASRResult per non-empty transcript.
        """
        # NOTE(review): this VAD object comes from the module-level engine
        # cache, so it appears to be shared by every offline stream; confirm
        # that concurrent streams are not expected.
        vad = _asr_engines['vad']
        segment_id = 0
        batch_start = None
        while not self.is_closed:
            samples = await self.inbuf.get()
            if samples is None:
                # Sentinel queued by close(): stop instead of waiting forever.
                break
            vad.accept_waveform(samples)
            while not vad.empty():
                if batch_start is None:
                    batch_start = time.time()
                stream = self.recognizer.create_stream()
                stream.accept_waveform(self.sample_rate, vad.front.samples)

                vad.pop()
                self.recognizer.decode_stream(stream)

                result = stream.result.text.strip()
                if result:
                    # Measured from the first segment of this audio chunk.
                    duration = time.time() - batch_start
                    logger.info(f'{segment_id}:{result} ({duration:.2f}s)')
                    self.outbuf.put_nowait(ASRResult(result, True, segment_id))
                    segment_id += 1
            batch_start = None

    async def close(self):
        """Mark the stream closed and release both the worker and the reader."""
        self.is_closed = True
        # Wake the worker, which may be blocked waiting for more audio.
        self.inbuf.put_nowait(None)
        # Tell the consumer that no further results will arrive.
        self.outbuf.put_nowait(None)

    async def write(self, pcm_bytes: bytes):
        """Queue a chunk of audio for recognition.

        Args:
            pcm_bytes: Raw samples read as int16 (native byte order); they are
                scaled by 1/32768 to float32 before being queued.
        """
        pcm_data = np.frombuffer(pcm_bytes, dtype=np.int16)
        samples = pcm_data.astype(np.float32) / 32768.0
        self.inbuf.put_nowait(samples)

    async def read(self) -> "Optional[ASRResult]":
        """Wait for and return the next result, or ``None`` at end of stream."""
        return await self.outbuf.get()
|
|
|
|
|
def create_zipformer(samplerate: int, args) -> sherpa_onnx.OnlineRecognizer:
    """Build the streaming bilingual (zh/en) zipformer transducer recognizer.

    Args:
        samplerate: Sample rate handed to the recognizer.
        args: Options object; ``models_root``, ``asr_provider`` and
            ``threads`` are read from it.

    Returns:
        An online recognizer with endpoint detection enabled.

    Raises:
        ValueError: If the model directory does not exist under
            ``args.models_root``.
    """
    model_dir = os.path.join(
        args.models_root,
        'sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20',
    )
    if not os.path.exists(model_dir):
        raise ValueError(f"asr: model not found {model_dir}")

    def model_file(name):
        # All model artifacts live directly inside the model directory.
        return os.path.join(model_dir, name)

    return sherpa_onnx.OnlineRecognizer.from_transducer(
        tokens=model_file("tokens.txt"),
        encoder=model_file("encoder-epoch-99-avg-1.onnx"),
        decoder=model_file("decoder-epoch-99-avg-1.onnx"),
        joiner=model_file("joiner-epoch-99-avg-1.onnx"),
        provider=args.asr_provider,
        num_threads=args.threads,
        sample_rate=samplerate,
        feature_dim=80,
        # Endpointing thresholds passed through to sherpa-onnx unchanged.
        enable_endpoint_detection=True,
        rule1_min_trailing_silence=2.4,
        rule2_min_trailing_silence=1.2,
        rule3_min_utterance_length=20,
    )
|
|
|
|
|
def create_sensevoice(samplerate: int, args) -> sherpa_onnx.OfflineRecognizer:
    """Build the offline SenseVoice (zh/en/ja/ko/yue) recognizer.

    Args:
        samplerate: Sample rate handed to the recognizer.
        args: Options object; ``models_root``, ``threads``, ``asr_lang`` and
            ``asr_provider`` are read from it.

    Returns:
        An offline recognizer with inverse text normalization enabled.

    Raises:
        ValueError: If the model directory does not exist under
            ``args.models_root``.
    """
    d = os.path.join(args.models_root,
                     'sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17')

    if not os.path.exists(d):
        raise ValueError(f"asr: model not found {d}")

    recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
        model=os.path.join(d, 'model.onnx'),
        tokens=os.path.join(d, 'tokens.txt'),
        num_threads=args.threads,
        sample_rate=samplerate,
        use_itn=True,
        debug=False,
        language=args.asr_lang,
        # Honour the configured execution provider, as the other engine
        # factories and the VAD loader in this module do.
        provider=args.asr_provider,
    )
    return recognizer
|
|
|
|
|
def create_paraformer_trilingual(samplerate: int, args) -> sherpa_onnx.OfflineRecognizer:
    """Build the offline trilingual (zh/cantonese/en) paraformer recognizer.

    Args:
        samplerate: Sample rate handed to the recognizer.
        args: Options object; ``models_root``, ``threads`` and
            ``asr_provider`` are read from it.

    Returns:
        An offline recognizer (the original annotation wrongly said online).

    Raises:
        ValueError: If the model directory does not exist under
            ``args.models_root``.
    """
    d = os.path.join(
        args.models_root, 'sherpa-onnx-paraformer-trilingual-zh-cantonese-en')
    if not os.path.exists(d):
        raise ValueError(f"asr: model not found {d}")

    recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
        paraformer=os.path.join(d, 'model.onnx'),
        tokens=os.path.join(d, 'tokens.txt'),
        num_threads=args.threads,
        sample_rate=samplerate,
        debug=False,
        provider=args.asr_provider,
    )
    return recognizer
|
|
|
|
|
def create_paraformer_en(samplerate: int, args) -> sherpa_onnx.OfflineRecognizer:
    """Build the offline English paraformer recognizer.

    Args:
        samplerate: Sample rate handed to the recognizer.
        args: Options object; ``models_root``, ``threads`` and
            ``asr_provider`` are read from it.

    Returns:
        An offline recognizer (the original annotation wrongly said online).

    Raises:
        ValueError: If the model directory does not exist under
            ``args.models_root``.
    """
    d = os.path.join(
        args.models_root, 'sherpa-onnx-paraformer-en')
    if not os.path.exists(d):
        raise ValueError(f"asr: model not found {d}")

    # NOTE(review): the previous ``use_itn=True`` argument was dropped.
    # ``use_itn`` is a SenseVoice option; ``from_paraformer`` does not appear
    # to accept it, so the call would fail with TypeError. Confirm against
    # the installed sherpa-onnx version.
    recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
        paraformer=os.path.join(d, 'model.onnx'),
        tokens=os.path.join(d, 'tokens.txt'),
        num_threads=args.threads,
        sample_rate=samplerate,
        debug=False,
        provider=args.asr_provider,
    )
    return recognizer
|
|
|
|
|
def load_asr_engine(samplerate: int, args) -> Union[sherpa_onnx.OnlineRecognizer, sherpa_onnx.OfflineRecognizer]:
    """Return the recognizer for ``args.asr_model``, loading it on first use.

    Loaded recognizers are cached in ``_asr_engines`` under the model name.
    Offline models additionally (re)load the voice-activity detector into
    ``_asr_engines['vad']``.

    Args:
        samplerate: Sample rate used when the engine is created.
        args: Options object; ``asr_model`` selects the engine and the whole
            object is forwarded to the engine factory.

    Returns:
        The online or offline recognizer for the requested model.

    Raises:
        ValueError: If ``args.asr_model`` is not a known model name, or the
            model files are missing.
    """
    # model name -> (factory, whether the model needs the VAD for segmentation)
    factories = {
        'zipformer-bilingual': (create_zipformer, False),
        'sensevoice': (create_sensevoice, True),
        'paraformer-trilingual': (create_paraformer_trilingual, True),
        'paraformer-en': (create_paraformer_en, True),
    }
    # Validate before touching the cache: the cache also holds the 'vad'
    # entry, which must never be handed out as a recognizer.
    if args.asr_model not in factories:
        raise ValueError(f"asr: unknown model {args.asr_model}")

    # NOTE(review): the cache is keyed by model name only, so a later call
    # with a different samplerate gets the engine built for the first one.
    engine = _asr_engines.get(args.asr_model)
    if engine is not None:
        return engine

    factory, needs_vad = factories[args.asr_model]
    st = time.time()
    engine = factory(samplerate, args)
    if needs_vad:
        _asr_engines['vad'] = load_vad_engine(samplerate, args)
    _asr_engines[args.asr_model] = engine
    logger.info(f"asr: engine loaded in {time.time() - st:.2f}s")
    return engine
|
|
|
|
|
def load_vad_engine(
    samplerate: int,
    args,
    min_silence_duration: float = 0.25,
    buffer_size_in_seconds: int = 100,
) -> sherpa_onnx.VoiceActivityDetector:
    """Create a silero-VAD based voice-activity detector.

    Args:
        samplerate: Sample rate written into the VAD configuration.
        args: Options object; ``models_root``, ``asr_provider`` and
            ``threads`` are read from it.
        min_silence_duration: Value for the silero ``min_silence_duration``
            setting.
        buffer_size_in_seconds: Buffer size passed to the detector.

    Returns:
        The configured ``VoiceActivityDetector``.

    Raises:
        ValueError: If the ``silero_vad`` directory does not exist under
            ``args.models_root``.
    """
    vad_config = sherpa_onnx.VadModelConfig()

    model_dir = os.path.join(args.models_root, 'silero_vad')
    if not os.path.exists(model_dir):
        raise ValueError(f"vad: model not found {model_dir}")

    # Silero-specific settings.
    silero = vad_config.silero_vad
    silero.model = os.path.join(model_dir, 'silero_vad.onnx')
    silero.min_silence_duration = min_silence_duration

    # Settings shared by the whole VAD model config.
    vad_config.sample_rate = samplerate
    vad_config.provider = args.asr_provider
    vad_config.num_threads = args.threads

    return sherpa_onnx.VoiceActivityDetector(
        vad_config, buffer_size_in_seconds=buffer_size_in_seconds)
|
|
|
|
|
async def start_asr_stream(samplerate: int, args) -> ASRStream:
    """Create an ASR stream for the configured model and start it.

    Args:
        samplerate: Sample rate used both to load the engine and for the
            audio later written to the stream.
        args: Options object forwarded to ``load_asr_engine``.

    Returns:
        The started ``ASRStream``.
    """
    engine = load_asr_engine(samplerate, args)
    asr_stream = ASRStream(engine, samplerate)
    await asr_stream.start()
    return asr_stream
|
|