# dh1 / asr.py
# cuio's picture
# Upload 11 files
# da8e0c5 verified
# raw
# history blame
# 8.11 kB
from typing import *
import logging
import time
import logging
import sherpa_onnx
import os
import asyncio
import numpy as np
# Module logger; named after the module (not the file path) so it fits into
# the standard dotted logger hierarchy and can be configured by name.
logger = logging.getLogger(__name__)
# Process-wide cache of loaded engines, keyed by ASR model name; offline
# models also store their shared voice-activity detector under 'vad'.
_asr_engines = {}
class ASRResult:
    """A single recognition hypothesis produced by an ASRStream.

    ``text`` holds the recognized text, ``finished`` is True for a final
    segment result and False for a partial one, and ``idx`` is the segment
    counter the result belongs to.
    """

    def __init__(self, text: str, finished: bool, idx: int):
        self.text, self.finished, self.idx = text, finished, idx

    def to_dict(self):
        """Return the result as a plain dict with keys text, finished, idx."""
        return dict(text=self.text, finished=self.finished, idx=self.idx)
class ASRStream:
    """Asynchronous bridge between raw PCM audio and a sherpa-onnx recognizer.

    Audio is pushed with :meth:`write` and results are pulled with
    :meth:`read`. A background task launched by :meth:`start` moves data
    from the input queue through the recognizer into the output queue.
    :meth:`read` returns ``None`` once the stream has been closed or the
    background task has stopped unexpectedly.
    """

    def __init__(self, recognizer: Union[sherpa_onnx.OnlineRecognizer, sherpa_onnx.OfflineRecognizer], sample_rate: int) -> None:
        self.recognizer = recognizer
        # float32 sample chunks awaiting recognition; ``None`` is a wake-up
        # sentinel pushed by close() so a worker blocked on get() can exit.
        self.inbuf = asyncio.Queue()
        # ASRResult objects for the consumer; ``None`` marks end of stream.
        self.outbuf = asyncio.Queue()
        self.sample_rate = sample_rate
        self.is_closed = False
        # Streaming recognizers decode incrementally; offline ones rely on a
        # VAD to cut audio into utterances first (see run_offline).
        self.online = isinstance(recognizer, sherpa_onnx.OnlineRecognizer)
        # Strong reference to the worker task: the event loop only keeps weak
        # references to tasks, so an unreferenced task may be collected.
        self._task: Optional[asyncio.Task] = None

    async def start(self):
        """Launch the background recognition task (online or offline)."""
        worker = self.run_online() if self.online else self.run_offline()
        self._task = asyncio.create_task(worker)
        self._task.add_done_callback(self._on_task_done)

    def _on_task_done(self, task: asyncio.Task) -> None:
        """Log a failed worker and unblock readers if it stopped before close().

        Without this, an exception inside run_online/run_offline would be
        swallowed and read() would wait forever.
        """
        if not task.cancelled() and task.exception() is not None:
            logger.error('asr: recognition task failed', exc_info=task.exception())
        if not self.is_closed:
            # close() already queues the end-of-stream marker; only add one
            # when the worker ended on its own (error or cancellation).
            self.outbuf.put_nowait(None)

    async def run_online(self):
        """Decode audio incrementally with a streaming recognizer.

        Emits a non-final ASRResult whenever the partial text changes and a
        final ASRResult when the recognizer reports an endpoint.
        """
        stream = self.recognizer.create_stream()
        last_result = ""
        segment_id = 0
        logger.info('asr: start real-time recognizer')
        while not self.is_closed:
            samples = await self.inbuf.get()
            if samples is None:  # sentinel from close()
                break
            stream.accept_waveform(self.sample_rate, samples)
            while self.recognizer.is_ready(stream):
                # NOTE(review): decoding runs synchronously on the event loop.
                self.recognizer.decode_stream(stream)
            is_endpoint = self.recognizer.is_endpoint(stream)
            result = self.recognizer.get_result(stream)
            if result and (last_result != result):
                last_result = result
                logger.info(f' > {segment_id}:{result}')
                self.outbuf.put_nowait(
                    ASRResult(result, False, segment_id))
            if is_endpoint:
                if result:
                    logger.info(f'{segment_id}: {result}')
                    self.outbuf.put_nowait(
                        ASRResult(result, True, segment_id))
                    segment_id += 1
                self.recognizer.reset(stream)
                # Start the next segment fresh so its first partial is not
                # suppressed just because it equals the previous final text.
                last_result = ""

    async def run_offline(self):
        """Segment audio with the shared VAD and decode each speech segment.

        Every non-empty decoded segment is emitted as a final ASRResult.
        """
        # NOTE(review): this VAD is a single module-level instance shared by
        # every offline ASRStream; concurrent streams would feed audio into
        # the same detector. Consider one VAD per stream.
        vad = _asr_engines['vad']
        segment_id = 0
        st = None  # time the first segment since the last emitted result was seen
        while not self.is_closed:
            samples = await self.inbuf.get()
            if samples is None:  # sentinel from close()
                break
            vad.accept_waveform(samples)
            while not vad.empty():
                if st is None:
                    st = time.time()
                stream = self.recognizer.create_stream()
                stream.accept_waveform(self.sample_rate, vad.front.samples)
                vad.pop()
                self.recognizer.decode_stream(stream)
                result = stream.result.text.strip()
                if result:
                    duration = time.time() - st
                    logger.info(f'{segment_id}:{result} ({duration:.2f}s)')
                    self.outbuf.put_nowait(ASRResult(result, True, segment_id))
                    segment_id += 1
                    st = None

    async def close(self):
        """Stop recognition and signal end-of-stream to readers.

        Wakes the worker (it may be blocked on ``inbuf.get()``) with a
        ``None`` sentinel, and queues ``None`` so :meth:`read` returns it.
        """
        self.is_closed = True
        self.inbuf.put_nowait(None)
        self.outbuf.put_nowait(None)

    async def write(self, pcm_bytes: bytes):
        """Queue a chunk of 16-bit PCM audio for recognition.

        ``pcm_bytes`` is interpreted as native-endian int16 samples and
        scaled to float32 in [-1, 1). Presumably mono audio at
        ``self.sample_rate`` — confirm with callers. An odd byte count makes
        ``np.frombuffer`` raise ValueError. Data written after close() is
        dropped because no worker remains to consume it.
        """
        if self.is_closed:
            return
        pcm_data = np.frombuffer(pcm_bytes, dtype=np.int16)
        samples = pcm_data.astype(np.float32) / 32768.0
        self.inbuf.put_nowait(samples)

    async def read(self) -> Optional[ASRResult]:
        """Wait for the next result; ``None`` means the stream has ended."""
        return await self.outbuf.get()
def create_zipformer(samplerate: int, args) -> sherpa_onnx.OnlineRecognizer:
    """Build the streaming bilingual (zh/en) Zipformer transducer recognizer.

    Model files are read from
    ``<args.models_root>/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20``.
    Endpoint detection is enabled so the online loop can split segments.

    Args:
        samplerate: Sample rate handed to the recognizer.
        args: Config object; ``models_root``, ``asr_provider`` and
            ``threads`` are read from it.

    Raises:
        ValueError: If the model directory does not exist.
    """
    model_dir = os.path.join(
        args.models_root, 'sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20')
    if not os.path.exists(model_dir):
        raise ValueError(f"asr: model not found {model_dir}")

    def model_file(name: str) -> str:
        return os.path.join(model_dir, name)

    return sherpa_onnx.OnlineRecognizer.from_transducer(
        tokens=model_file("tokens.txt"),
        encoder=model_file("encoder-epoch-99-avg-1.onnx"),
        decoder=model_file("decoder-epoch-99-avg-1.onnx"),
        joiner=model_file("joiner-epoch-99-avg-1.onnx"),
        provider=args.asr_provider,
        num_threads=args.threads,
        sample_rate=samplerate,
        feature_dim=80,
        enable_endpoint_detection=True,
        rule1_min_trailing_silence=2.4,
        rule2_min_trailing_silence=1.2,
        # A 20-second minimum utterance length essentially disables rule 3.
        rule3_min_utterance_length=20,
    )
def create_sensevoice(samplerate: int, args) -> sherpa_onnx.OfflineRecognizer:
    """Build the offline SenseVoice recognizer (zh/en/ja/ko/yue).

    Model files are read from
    ``<args.models_root>/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17``.
    Inverse text normalization is enabled.

    Args:
        samplerate: Sample rate handed to the recognizer.
        args: Config object; ``models_root``, ``threads``, ``asr_lang`` and
            ``asr_provider`` are read from it.

    Raises:
        ValueError: If the model directory does not exist.
    """
    d = os.path.join(args.models_root,
                     'sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17')
    if not os.path.exists(d):
        raise ValueError(f"asr: model not found {d}")
    recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
        model=os.path.join(d, 'model.onnx'),
        tokens=os.path.join(d, 'tokens.txt'),
        num_threads=args.threads,
        sample_rate=samplerate,
        use_itn=True,
        debug=0,
        language=args.asr_lang,
        # Honour the configured execution provider like the other factories;
        # previously this silently fell back to the library default.
        provider=args.asr_provider,
    )
    return recognizer
def create_paraformer_trilingual(samplerate: int, args) -> sherpa_onnx.OfflineRecognizer:
    """Build the offline Paraformer recognizer for Mandarin, Cantonese and English.

    Model files are read from
    ``<args.models_root>/sherpa-onnx-paraformer-trilingual-zh-cantonese-en``.
    The returned recognizer is an OfflineRecognizer, so callers pair it with
    a VAD.

    Args:
        samplerate: Sample rate handed to the recognizer.
        args: Config object; ``models_root``, ``threads`` and
            ``asr_provider`` are read from it.

    Raises:
        ValueError: If the model directory does not exist.
    """
    d = os.path.join(
        args.models_root, 'sherpa-onnx-paraformer-trilingual-zh-cantonese-en')
    if not os.path.exists(d):
        raise ValueError(f"asr: model not found {d}")
    recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
        paraformer=os.path.join(d, 'model.onnx'),
        tokens=os.path.join(d, 'tokens.txt'),
        num_threads=args.threads,
        sample_rate=samplerate,
        debug=0,
        provider=args.asr_provider,
    )
    return recognizer
def create_paraformer_en(samplerate: int, args) -> sherpa_onnx.OfflineRecognizer:
    """Build the offline English Paraformer recognizer.

    Model files are read from ``<args.models_root>/sherpa-onnx-paraformer-en``.
    The returned recognizer is an OfflineRecognizer, so callers pair it with
    a VAD.

    Args:
        samplerate: Sample rate handed to the recognizer.
        args: Config object; ``models_root``, ``threads`` and
            ``asr_provider`` are read from it.

    Raises:
        ValueError: If the model directory does not exist.
    """
    d = os.path.join(
        args.models_root, 'sherpa-onnx-paraformer-en')
    if not os.path.exists(d):
        raise ValueError(f"asr: model not found {d}")
    # ``use_itn`` is a SenseVoice-only option; ``from_paraformer`` has no
    # such parameter, so it is not passed here.
    recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
        paraformer=os.path.join(d, 'model.onnx'),
        tokens=os.path.join(d, 'tokens.txt'),
        num_threads=args.threads,
        sample_rate=samplerate,
        debug=0,
        provider=args.asr_provider,
    )
    return recognizer
def load_asr_engine(samplerate: int, args) -> Union[sherpa_onnx.OnlineRecognizer, sherpa_onnx.OfflineRecognizer]:
    """Return the recognizer for ``args.asr_model``, creating and caching it once.

    Offline models (everything except 'zipformer-bilingual') also load a
    voice-activity detector into ``_asr_engines['vad']``, which the offline
    stream loop uses to cut audio into segments.

    Args:
        samplerate: Sample rate for the recognizer (and VAD, if loaded).
        args: Config object; ``asr_model`` selects the engine and is passed
            on to the model factories.

    Raises:
        ValueError: If ``args.asr_model`` is unknown or model files are missing.
    """
    cache_engine = _asr_engines.get(args.asr_model)
    if cache_engine is not None:
        return cache_engine
    st = time.time()
    # model name -> (factory, whether the model needs a VAD front-end)
    factories = {
        'zipformer-bilingual': (create_zipformer, False),
        'sensevoice': (create_sensevoice, True),
        'paraformer-trilingual': (create_paraformer_trilingual, True),
        'paraformer-en': (create_paraformer_en, True),
    }
    try:
        factory, needs_vad = factories[args.asr_model]
    except KeyError:
        raise ValueError(f"asr: unknown model {args.asr_model}") from None
    cache_engine = factory(samplerate, args)
    if needs_vad:
        _asr_engines['vad'] = load_vad_engine(samplerate, args)
    _asr_engines[args.asr_model] = cache_engine
    logger.info(f"asr: engine loaded in {time.time() - st:.2f}s")
    return cache_engine
def load_vad_engine(samplerate: int, args, min_silence_duration: float = 0.25, buffer_size_in_seconds: int = 100) -> sherpa_onnx.VoiceActivityDetector:
    """Create a Silero voice-activity detector.

    The model is read from ``<args.models_root>/silero_vad/silero_vad.onnx``;
    ``args.asr_provider`` and ``args.threads`` configure the runtime.

    Args:
        samplerate: Sample rate the detector is configured for.
        args: Config object providing ``models_root``, ``asr_provider`` and
            ``threads``.
        min_silence_duration: Value for ``silero_vad.min_silence_duration``.
        buffer_size_in_seconds: Passed through to VoiceActivityDetector.

    Raises:
        ValueError: If the ``silero_vad`` directory does not exist.
    """
    config = sherpa_onnx.VadModelConfig()
    vad_dir = os.path.join(args.models_root, 'silero_vad')
    if not os.path.exists(vad_dir):
        raise ValueError(f"vad: model not found {vad_dir}")

    silero = config.silero_vad
    silero.model = os.path.join(vad_dir, 'silero_vad.onnx')
    silero.min_silence_duration = min_silence_duration

    config.sample_rate = samplerate
    config.provider = args.asr_provider
    config.num_threads = args.threads

    return sherpa_onnx.VoiceActivityDetector(
        config, buffer_size_in_seconds=buffer_size_in_seconds)
async def start_asr_stream(samplerate: int, args) -> ASRStream:
    """Create an ASR stream on the (cached) engine for ``args.asr_model`` and start it.

    Returns the started ASRStream; feed it with ``write`` and consume
    results with ``read``.
    """
    engine = load_asr_engine(samplerate, args)
    asr_stream = ASRStream(engine, samplerate)
    await asr_stream.start()
    return asr_stream