|
from typing import * |
|
import os |
|
import time |
|
import sherpa_onnx |
|
import logging |
|
import numpy as np |
|
import asyncio |
|
import time |
|
import soundfile |
|
from scipy.signal import resample |
|
import io |
|
import re |
|
|
|
# NOTE(review): getLogger(__file__) names the logger after the file path;
# getLogger(__name__) is the usual convention — confirm before changing.
logger = logging.getLogger(__file__)

# Sentence-boundary punctuation (CJK full-width + ASCII + newline) used by
# TTSStream.write() to split long text into per-sentence synthesis chunks.
splitter = re.compile(r'[,,。.!?!?;;、\n]')

# Cache of loaded sherpa-onnx OfflineTts engines, keyed by model name.
_tts_engines = {}

# Per-model file layout and native sample rate.  All paths are relative to
# <models_root>/<model-name>/ (joined in load_tts_model).
tts_configs = {
    'vits-zh-hf-theresa': {
        'model': 'theresa.onnx',
        'lexicon': 'lexicon.txt',
        'dict_dir': 'dict',
        'tokens': 'tokens.txt',
        'sample_rate': 22050,
    },
    'vits-melo-tts-zh_en': {
        'model': 'model.onnx',
        'lexicon': 'lexicon.txt',
        'dict_dir': 'dict',
        'tokens': 'tokens.txt',
        'sample_rate': 44100,
        # Optional text-normalization rule FSTs (phone/date/number rewriting).
        'rule_fsts': ['phone.fst', 'date.fst', 'number.fst'],
    },
}
|
|
|
|
|
def load_tts_model(name: str, model_root: str, provider: str, num_threads: int = 1, max_num_sentences: int = 20) -> sherpa_onnx.OfflineTtsConfig:
    """Build and validate a sherpa-onnx offline TTS config for a known model.

    Args:
        name: key into the module-level `tts_configs` table.
        model_root: directory containing one sub-directory per model.
        provider: onnxruntime execution provider (e.g. 'cpu').
        num_threads: intra-op inference threads.
        max_num_sentences: max sentences synthesized per generate() call.

    Returns:
        A validated sherpa_onnx.OfflineTtsConfig.

    Raises:
        KeyError: if `name` is not in `tts_configs`.
        ValueError: if the assembled config fails validation
            (e.g. a model file is missing on disk).
    """
    cfg = tts_configs[name]
    model_dir = os.path.join(model_root, name)

    # Optional rule FSTs; sherpa expects a comma-separated string.
    # Default to [] — the original default of '' only worked by accident
    # because iterating an empty string yields nothing.
    fsts = [os.path.join(model_dir, f) for f in cfg.get('rule_fsts', [])]
    tts_rule_fsts = ','.join(fsts)  # '' when there are no FSTs

    model_config = sherpa_onnx.OfflineTtsModelConfig(
        vits=sherpa_onnx.OfflineTtsVitsModelConfig(
            model=os.path.join(model_dir, cfg['model']),
            lexicon=os.path.join(model_dir, cfg['lexicon']),
            dict_dir=os.path.join(model_dir, cfg['dict_dir']),
            tokens=os.path.join(model_dir, cfg['tokens']),
        ),
        provider=provider,
        debug=0,
        num_threads=num_threads,
    )
    tts_config = sherpa_onnx.OfflineTtsConfig(
        model=model_config,
        rule_fsts=tts_rule_fsts,
        max_num_sentences=max_num_sentences)

    if not tts_config.validate():
        raise ValueError("tts: invalid config")

    return tts_config
|
|
|
|
|
def get_tts_engine(args) -> Tuple[sherpa_onnx.OfflineTts, int]:
    """Return an OfflineTts engine for args.tts_model plus its native sample rate.

    Engines are cached in the module-level `_tts_engines` dict, so the
    (slow) model load happens at most once per model name.
    """
    native_rate = tts_configs[args.tts_model]['sample_rate']

    engine = _tts_engines.get(args.tts_model)
    if engine:
        return engine, native_rate

    # Cache miss: build the config and load the model, timing the load.
    begin = time.time()
    config = load_tts_model(
        args.tts_model, args.models_root, args.tts_provider)
    engine = sherpa_onnx.OfflineTts(config)
    elapsed = time.time() - begin
    logger.info(f"tts: loaded {args.tts_model} in {elapsed:.2f}s")

    _tts_engines[args.tts_model] = engine
    return engine, native_rate
|
|
|
|
|
class TTSResult:
    """One unit of streamed TTS output: a PCM chunk plus progress metadata.

    The final result of a stream carries `finished=True`, a `pcm_bytes` of
    None, and the aggregate timing/size statistics.
    """

    def __init__(self, pcm_bytes: bytes, finished: bool):
        self.pcm_bytes = pcm_bytes   # raw PCM payload (None for the final marker)
        self.finished = finished     # True on the terminal result
        self.progress: float = 0.0       # synthesis progress in [0, 1]
        self.elapsed: float = 0.0        # wall-clock generation time, seconds
        self.audio_duration: float = 0.0  # total audio length, seconds
        self.audio_size: int = 0         # total sample count

    def to_dict(self):
        """Summarize progress and timing as a JSON-friendly dict."""
        elapsed_ms = int(self.elapsed * 1000)
        return {
            "progress": self.progress,
            "elapsed": f'{elapsed_ms}ms',
            "duration": f'{self.audio_duration:.2f}s',
            "size": self.audio_size,
        }
|
|
|
|
|
class TTSStream:
    """Streams synthesized speech from a sherpa-onnx offline TTS engine.

    PCM chunks are pushed into an asyncio queue by the engine callback
    (`on_process`) as they are produced; a consumer pulls them with
    `read()`. `write()` drives incremental synthesis; `generate()` is the
    one-shot, whole-WAV alternative.
    """

    def __init__(self, engine, sid: int, speed: float = 1.0, sample_rate: int = 16000, original_sample_rate: int = 16000):
        self.engine = engine
        self.sid = sid        # speaker id forwarded to engine.generate
        self.speed = speed    # speech-rate multiplier
        # Queue of streamed chunks; None is the close() sentinel.
        self.outbuf: asyncio.Queue[TTSResult | None] = asyncio.Queue()
        self.is_closed = False
        self.target_sample_rate = sample_rate              # rate the consumer wants
        self.original_sample_rate = original_sample_rate   # engine's native rate

    def on_process(self, chunk: np.ndarray, progress: float):
        """Engine callback: resample, convert to int16 PCM, and enqueue.

        Returns 0 to tell the engine to stop (stream closed), 1 to continue.
        """
        if self.is_closed:
            return 0

        # Resample from the engine's native rate to the consumer's rate.
        if self.target_sample_rate != self.original_sample_rate:
            num_samples = int(
                len(chunk) * self.target_sample_rate / self.original_sample_rate)
            resampled_chunk = resample(chunk, num_samples)
            chunk = resampled_chunk.astype(np.float32)

        # float [-1, 1] -> clipped int16 little-endian PCM bytes.
        scaled_chunk = chunk * 32768.0
        clipped_chunk = np.clip(scaled_chunk, -32768, 32767)
        int16_chunk = clipped_chunk.astype(np.int16)
        self.outbuf.put_nowait(TTSResult(int16_chunk.tobytes(), False))

        # BUG FIX: the original `self.is_closed and 0 or 1` ALWAYS returned 1,
        # because the middle operand 0 is falsy (classic and/or pitfall).
        # Use a real conditional so closing mid-generation stops the engine.
        return 0 if self.is_closed else 1

    async def write(self, text: str, split: bool, pause: float = 0.2):
        """Synthesize `text` into the queue, optionally sentence by sentence.

        Args:
            text: text to speak.
            split: if True, split on punctuation (module-level `splitter`)
                and insert `pause` seconds of silence between sentences.
            pause: inter-sentence silence, in seconds.

        A terminal TTSResult (finished=True) with aggregate stats is
        enqueued when all sentences are done.
        """
        start = time.time()
        if split:
            texts = re.split(splitter, text)
        else:
            texts = [text]

        audio_duration = 0.0
        audio_size = 0

        for idx, text in enumerate(texts):
            text = text.strip()
            if not text:
                continue
            sub_start = time.time()

            # generate() blocks, so run it off the event loop; it invokes
            # self.on_process for each chunk as it is produced.
            audio = await asyncio.to_thread(self.engine.generate,
                                            text, self.sid, self.speed,
                                            self.on_process)

            # BUG FIX: `not audio.samples` raises "truth value is ambiguous"
            # when samples is a multi-element ndarray; len() is safe for
            # both lists and arrays.
            if not audio or not audio.sample_rate or len(audio.samples) == 0:
                logger.error(f"tts: failed to generate audio for "
                             f"'{text}' (audio={audio})")
                continue

            # Inter-sentence pause: enqueue silence and include it in the
            # duration/size accounting below.
            if split and idx < len(texts) - 1:
                noise = np.zeros(int(audio.sample_rate * pause))
                self.on_process(noise, 1.0)
                audio.samples = np.concatenate([audio.samples, noise])

            audio_duration += len(audio.samples) / audio.sample_rate
            audio_size += len(audio.samples)
            elapsed_seconds = time.time() - sub_start
            logger.info(f"tts: generated audio for '{text}', "
                        f"audio duration: {audio_duration:.2f}s, "
                        f"elapsed: {elapsed_seconds:.2f}s")

        elapsed_seconds = time.time() - start
        logger.info(f"tts: generated audio in {elapsed_seconds:.2f}s, "
                    f"audio duration: {audio_duration:.2f}s")

        # Terminal marker carrying the overall stats.
        r = TTSResult(None, True)
        r.elapsed = elapsed_seconds
        r.audio_duration = audio_duration
        # BUG FIX: audio_size was computed but never stored, so to_dict()
        # always reported size 0 on the final result.
        r.audio_size = audio_size
        r.progress = 1.0
        await self.outbuf.put(r)

    async def close(self):
        """Stop producing chunks and wake the reader with a None sentinel."""
        self.is_closed = True
        self.outbuf.put_nowait(None)
        logger.info("tts: stream closed")

    async def read(self) -> TTSResult:
        """Await the next chunk, terminal result, or None (closed)."""
        return await self.outbuf.get()

    async def generate(self, text: str) -> io.BytesIO:
        """One-shot synthesis: return `text` as an in-memory PCM_16 WAV."""
        start = time.time()
        audio = await asyncio.to_thread(self.engine.generate,
                                        text, self.sid, self.speed)
        elapsed_seconds = time.time() - start
        audio_duration = len(audio.samples) / audio.sample_rate

        logger.info(f"tts: generated audio in {elapsed_seconds:.2f}s, "
                    f"audio duration: {audio_duration:.2f}s, "
                    f"sample rate: {audio.sample_rate}")

        # Resample to the consumer's rate if the engine's native rate differs.
        if self.target_sample_rate != audio.sample_rate:
            audio.samples = resample(audio.samples,
                                     int(len(audio.samples) * self.target_sample_rate / audio.sample_rate))
            audio.sample_rate = self.target_sample_rate

        output = io.BytesIO()
        soundfile.write(output,
                        audio.samples,
                        samplerate=audio.sample_rate,
                        subtype="PCM_16",
                        format="WAV")
        output.seek(0)
        return output
|
|
|
|
|
async def start_tts_stream(sid: int, sample_rate: int, speed: float, args) -> TTSStream:
    """Create a TTSStream bound to the (cached) engine selected by `args`."""
    engine, native_rate = get_tts_engine(args)
    return TTSStream(engine, sid, speed, sample_rate, native_rate)
|
|