dh1 / tts.py
cuio's picture
Upload 11 files
da8e0c5 verified
raw
history blame
7.54 kB
from typing import *
import os
import time
import sherpa_onnx
import logging
import numpy as np
import asyncio
import time
import soundfile
from scipy.signal import resample
import io
import re
# Module-level logger used by every function and class in this file.
# NOTE(review): the logger is named after __file__ (the source path) rather
# than the conventional __name__ - confirm this is intentional before changing.
logger = logging.getLogger(__file__)
# Character class of punctuation marks and newline; TTSStream.write passes it
# to re.split to cut input text into sentence-sized pieces.
# NOTE(review): ',', '!', '?' and ';' each appear twice in the class. The
# duplicates are harmless, but they may originally have been full-width
# variants lost in transcoding - verify against the upstream file.
splitter = re.compile(r'[,,。.!?!?;;、\n]')
# Cache of already constructed engines, keyed by model name; read and filled
# by get_tts_engine so each model is loaded at most once per process.
_tts_engines = {}
# Per-model settings, keyed by model name (which is also the name of the
# model's sub-directory under the models root).  File entries are relative to
# that sub-directory; 'sample_rate' is the rate the model produces audio at.
tts_configs = {
    "vits-zh-hf-theresa": {
        "model": "theresa.onnx",
        "lexicon": "lexicon.txt",
        "dict_dir": "dict",
        "tokens": "tokens.txt",
        "sample_rate": 22050,
        # Rule FSTs are intentionally left out for this model; the original
        # entry was disabled:
        # 'rule_fsts': ['phone.fst', 'date.fst', 'number.fst'],
    },
    "vits-melo-tts-zh_en": {
        "model": "model.onnx",
        "lexicon": "lexicon.txt",
        "dict_dir": "dict",
        "tokens": "tokens.txt",
        "sample_rate": 44100,
        "rule_fsts": [
            "phone.fst",
            "date.fst",
            "number.fst",
        ],
    },
}
def load_tts_model(name: str, model_root: str, provider: str, num_threads: int = 1, max_num_sentences: int = 20) -> sherpa_onnx.OfflineTtsConfig:
    """Build and validate the sherpa-onnx TTS configuration for model *name*.

    Args:
        name: Key into ``tts_configs``; also the model's sub-directory name
            under *model_root*.
        model_root: Directory that contains one sub-directory per model.
        provider: Passed through as the model config's ``provider``.
        num_threads: Passed through as the model config's ``num_threads``.
        max_num_sentences: Passed through to ``OfflineTtsConfig``.

    Returns:
        The validated ``sherpa_onnx.OfflineTtsConfig``.

    Raises:
        KeyError: If *name* is not present in ``tts_configs``.
        ValueError: If the assembled configuration fails ``validate()``.
    """
    settings = tts_configs[name]
    base_dir = os.path.join(model_root, name)

    def in_model_dir(relative: str) -> str:
        # Every file named in the settings lives inside the model directory.
        return os.path.join(base_dir, relative)

    # Comma-separated list of rule FST paths; an empty string when the model
    # declares none.
    rule_fst_paths = [in_model_dir(fst) for fst in settings.get('rule_fsts', '')]
    joined_rule_fsts = ','.join(rule_fst_paths)

    vits_config = sherpa_onnx.OfflineTtsVitsModelConfig(
        model=in_model_dir(settings['model']),
        lexicon=in_model_dir(settings['lexicon']),
        dict_dir=in_model_dir(settings['dict_dir']),
        tokens=in_model_dir(settings['tokens']),
    )
    model_config = sherpa_onnx.OfflineTtsModelConfig(
        vits=vits_config,
        provider=provider,
        debug=0,
        num_threads=num_threads,
    )
    tts_config = sherpa_onnx.OfflineTtsConfig(
        model=model_config,
        rule_fsts=joined_rule_fsts,
        max_num_sentences=max_num_sentences,
    )

    if tts_config.validate():
        return tts_config
    raise ValueError("tts: invalid config")
def get_tts_engine(args) -> Tuple[sherpa_onnx.OfflineTts, int]:
    """Return the (cached) TTS engine for ``args.tts_model`` and its sample rate.

    Args:
        args: Object exposing ``tts_model``, ``models_root`` and
            ``tts_provider`` attributes.

    Returns:
        A ``(engine, sample_rate)`` tuple, where ``sample_rate`` is the value
        configured for the model in ``tts_configs``.

    The engine is built on first use and stored in ``_tts_engines`` so later
    calls for the same model reuse it.
    """
    model_name = args.tts_model
    sample_rate = tts_configs[model_name]['sample_rate']

    engine = _tts_engines.get(model_name)
    if not engine:
        # First request for this model: load it and remember how long it took.
        load_started = time.time()
        config = load_tts_model(model_name, args.models_root, args.tts_provider)
        engine = sherpa_onnx.OfflineTts(config)
        elapsed = time.time() - load_started
        logger.info(f"tts: loaded {args.tts_model} in {elapsed:.2f}s")
        _tts_engines[model_name] = engine

    return engine, sample_rate
class TTSResult:
    """One item delivered on a TTS stream.

    Holds a block of PCM bytes together with a ``finished`` flag and a few
    statistics fields.  The statistics start at zero and are filled in by
    whoever produces the result.
    """

    def __init__(self, pcm_bytes: bytes, finished: bool):
        """Store the payload and flag; zero-initialise every statistic."""
        self.pcm_bytes, self.finished = pcm_bytes, finished
        self.progress: float = 0.0
        self.elapsed: float = 0.0  # seconds; rendered as milliseconds by to_dict
        self.audio_duration: float = 0.0  # seconds
        self.audio_size: int = 0

    def to_dict(self):
        """Return the statistics as a dict with display-formatted values.

        ``elapsed`` becomes whole milliseconds with an ``ms`` suffix and
        ``duration`` seconds with two decimals and an ``s`` suffix.  The PCM
        payload and the ``finished`` flag are not included.
        """
        elapsed_ms = int(self.elapsed * 1000)
        return dict(
            progress=self.progress,
            elapsed=f'{elapsed_ms}ms',
            duration=f'{self.audio_duration:.2f}s',
            size=self.audio_size,
        )
class TTSStream:
    """A single text-to-speech session bound to one engine and speaker id.

    ``write`` synthesises text in a worker thread; the engine hands audio back
    chunk by chunk through ``on_process``, which converts each chunk to 16-bit
    PCM at the target sample rate and queues it on ``outbuf`` as a
    ``TTSResult``.  Consumers pull items with ``read``.  ``close`` marks the
    stream closed and queues ``None``.  ``generate`` is the non-streaming
    alternative that returns a complete WAV file.
    """

    def __init__(self, engine, sid: int, speed: float = 1.0, sample_rate: int = 16000, original_sample_rate: int = 16000):
        """Create a stream.

        Args:
            engine: Object whose ``generate(text, sid, speed[, callback])``
                performs the synthesis.
            sid: Speaker id forwarded to the engine.
            speed: Speed factor forwarded to the engine.
            sample_rate: Sample rate of the audio handed to consumers.
            original_sample_rate: Sample rate of the audio the engine emits.
        """
        self.engine = engine
        self.sid = sid
        self.speed = speed
        self.outbuf: asyncio.Queue[TTSResult | None] = asyncio.Queue()
        self.is_closed = False
        self.target_sample_rate = sample_rate
        self.original_sample_rate = original_sample_rate
        # Event loop that owns ``outbuf``.  Recorded by write() so that
        # on_process(), which runs in a worker thread, can hand results over
        # to the loop safely.
        self._loop = None

    def _enqueue(self, result) -> None:
        """Put *result* on ``outbuf`` from whichever thread is calling.

        asyncio queues are not thread-safe: a ``put_nowait`` issued from a
        worker thread may never wake a reader blocked in ``get``.  When the
        caller is not running on the owning loop, the put is therefore
        scheduled onto it with ``call_soon_threadsafe``.
        """
        try:
            running = asyncio.get_running_loop()
        except RuntimeError:  # no loop in this thread, i.e. a worker thread
            running = None
        owner = self._loop
        if owner is None or running is owner:
            self.outbuf.put_nowait(result)
        else:
            owner.call_soon_threadsafe(self.outbuf.put_nowait, result)

    def on_process(self, chunk: np.ndarray, progress: float):
        """Engine callback: convert one audio chunk and queue it.

        Args:
            chunk: Float samples at ``original_sample_rate``; they are scaled
                by 32768 and clipped, so values are expected in [-1, 1].
            progress: Progress value reported by the engine (currently unused).

        Returns:
            ``0`` once the stream has been closed, ``1`` otherwise.
            Presumably the engine stops generating on ``0`` - see the
            sherpa-onnx callback contract.
        """
        if self.is_closed:
            return 0
        # Resample to the rate the consumer asked for.
        if self.target_sample_rate != self.original_sample_rate:
            num_samples = int(
                len(chunk) * self.target_sample_rate / self.original_sample_rate)
            chunk = resample(chunk, num_samples).astype(np.float32)
        # Float samples -> signed 16-bit PCM, clipping anything out of range.
        pcm = np.clip(chunk * 32768.0, -32768, 32767).astype(np.int16)
        self._enqueue(TTSResult(pcm.tobytes(), False))
        # The stream may have been closed while this chunk was processed.
        # (`cond and 0 or 1` would always yield 1 because 0 is falsy.)
        return 0 if self.is_closed else 1

    async def write(self, text: str, split: bool, pause: float = 0.2):
        """Synthesise *text* and stream the audio into ``outbuf``.

        Args:
            text: Text to speak.
            split: When true, cut *text* at the ``splitter`` punctuation and
                synthesise piece by piece, inserting silence between pieces.
            pause: Length in seconds of the silence inserted between pieces.

        PCM chunks are queued while synthesis runs.  Afterwards a final
        ``TTSResult`` with ``finished=True`` and no PCM payload is queued,
        carrying the total elapsed time, audio duration and sample count.
        Pieces for which the engine returns no audio are logged and skipped.
        """
        self._loop = asyncio.get_running_loop()
        start = time.time()
        texts = re.split(splitter, text) if split else [text]

        audio_duration = 0.0
        audio_size = 0
        for idx, sentence in enumerate(texts):
            if self.is_closed:
                # Nobody is listening any more; do not synthesise the rest.
                break
            sentence = sentence.strip()
            if not sentence:
                continue
            sub_start = time.time()

            audio = await asyncio.to_thread(self.engine.generate,
                                            sentence, self.sid, self.speed,
                                            self.on_process)

            # len() rather than truthiness: works for lists and ndarrays alike.
            if audio is None or not audio.sample_rate or len(audio.samples) == 0:
                logger.error(f"tts: failed to generate audio for "
                             f"'{sentence}' (audio={audio})")
                continue

            num_samples = len(audio.samples)
            if split and idx < len(texts) - 1:  # add a pause between sentences
                pause_samples = int(audio.sample_rate * pause)
                if pause_samples > 0:
                    silence = np.zeros(pause_samples, dtype=np.float32)
                    self.on_process(silence, 1.0)
                    num_samples += pause_samples

            audio_duration += num_samples / audio.sample_rate
            audio_size += num_samples
            elapsed_seconds = time.time() - sub_start
            logger.info(f"tts: generated audio for '{sentence}', "
                        f"audio duration: {audio_duration:.2f}s, "
                        f"elapsed: {elapsed_seconds:.2f}s")

        elapsed_seconds = time.time() - start
        logger.info(f"tts: generated audio in {elapsed_seconds:.2f}s, "
                    f"audio duration: {audio_duration:.2f}s")

        r = TTSResult(None, True)
        r.elapsed = elapsed_seconds
        r.audio_duration = audio_duration
        r.audio_size = audio_size
        r.progress = 1.0
        await self.outbuf.put(r)

    async def close(self):
        """Mark the stream closed and queue ``None`` for the reader."""
        self.is_closed = True
        self.outbuf.put_nowait(None)
        logger.info("tts: stream closed")

    async def read(self) -> "TTSResult | None":
        """Wait for and return the next queued item.

        Returns a ``TTSResult``, or ``None`` once ``close`` has been called.
        """
        return await self.outbuf.get()

    async def generate(self, text: str) -> io.BytesIO:
        """Synthesise *text* in one go and return it as an in-memory WAV file.

        The audio is resampled to the target sample rate when the engine's
        rate differs, and written as 16-bit PCM.  The returned buffer is
        positioned at its start.
        """
        start = time.time()
        audio = await asyncio.to_thread(self.engine.generate,
                                        text, self.sid, self.speed)
        elapsed_seconds = time.time() - start

        # Work on local copies instead of writing arrays back onto the
        # engine's audio object.
        samples = audio.samples
        sample_rate = audio.sample_rate
        audio_duration = len(samples) / sample_rate
        logger.info(f"tts: generated audio in {elapsed_seconds:.2f}s, "
                    f"audio duration: {audio_duration:.2f}s, "
                    f"sample rate: {audio.sample_rate}")

        if self.target_sample_rate != sample_rate:
            samples = resample(
                samples,
                int(len(samples) * self.target_sample_rate / sample_rate))
            sample_rate = self.target_sample_rate

        output = io.BytesIO()
        soundfile.write(output,
                        samples,
                        samplerate=sample_rate,
                        subtype="PCM_16",
                        format="WAV")
        output.seek(0)
        return output
async def start_tts_stream(sid: int, sample_rate: int, speed: float, args) -> TTSStream:
    """Create a ``TTSStream`` for the model selected by *args*.

    Args:
        sid: Speaker id for the stream.
        sample_rate: Sample rate the stream should deliver audio at.
        speed: Speed factor for the stream.
        args: Forwarded to ``get_tts_engine`` to pick (and if necessary load)
            the engine.

    Returns:
        A new stream wired to the engine, with the engine's own sample rate
        recorded as the stream's original sample rate.
    """
    engine, native_rate = get_tts_engine(args)
    stream = TTSStream(
        engine,
        sid,
        speed=speed,
        sample_rate=sample_rate,
        original_sample_rate=native_rate,
    )
    return stream