Upload 11 files
- app.js +99 -0
- app.py +194 -0
- asr.py +233 -0
- audio_process.js +45 -0
- index.html +179 -0
- record.svg +1 -0
- requirements.txt +7 -0
- sherpa_examples.py +274 -0
- speaking.svg +1 -0
- tts.py +216 -0
- voice.png +0 -0
app.js
ADDED
@@ -0,0 +1,99 @@
const demoapp = {
    text: '讲个冷笑话吧,要很好笑的那种。',
    recording: false,
    asrWS: null,
    currentText: null,
    disabled: false,
    elapsedTime: null,
    logs: [{ idx: 0, text: 'Happily here at ruzhila.cn.' }],
    async init() {
    },
    async dotts() {
        let audioContext = new AudioContext({ sampleRate: 16000 })
        await audioContext.audioWorklet.addModule('./audio_process.js')

        const ws = new WebSocket('/tts');
        ws.onopen = () => {
            ws.send(this.text);
        };
        const playNode = new AudioWorkletNode(audioContext, 'play-audio-processor');
        playNode.connect(audioContext.destination);

        this.disabled = true;
        ws.onmessage = async (e) => {
            if (e.data instanceof Blob) {
                e.data.arrayBuffer().then((arrayBuffer) => {
                    const int16Array = new Int16Array(arrayBuffer);
                    let float32Array = new Float32Array(int16Array.length);
                    for (let i = 0; i < int16Array.length; i++) {
                        float32Array[i] = int16Array[i] / 32768.;
                    }
                    playNode.port.postMessage({ message: 'audioData', audioData: float32Array });
                });
            } else {
                this.elapsedTime = JSON.parse(e.data)?.elapsed;
                this.disabled = false;
            }
        }
    },

    async stopasr() {
        if (!this.asrWS) {
            return;
        }
        this.asrWS.close();
        this.asrWS = null;
        this.recording = false;
        if (this.currentText) {
            this.logs.push({ idx: this.logs.length + 1, text: this.currentText });
        }
        this.currentText = null;

    },

    async doasr() {
        const audioConstraints = {
            video: false,
            audio: true,
        };

        const mediaStream = await navigator.mediaDevices.getUserMedia(audioConstraints);

        const ws = new WebSocket('/asr');
        let currentMessage = '';

        ws.onopen = () => {
            this.logs = [];
        };

        ws.onmessage = (e) => {
            const data = JSON.parse(e.data);
            const { text, finished, idx } = data;

            currentMessage = text;
            this.currentText = text

            if (finished) {
                this.logs.push({ text: currentMessage, idx: idx });
                currentMessage = '';
                this.currentText = null
            }
        };

        let audioContext = new AudioContext({ sampleRate: 16000 })
        await audioContext.audioWorklet.addModule('./audio_process.js')

        const recordNode = new AudioWorkletNode(audioContext, 'record-audio-processor');
        recordNode.connect(audioContext.destination);
        recordNode.port.onmessage = (event) => {
            if (ws && ws.readyState === WebSocket.OPEN) {
                const int16Array = event.data.data;
                ws.send(int16Array.buffer);
            }
        }
        const source = audioContext.createMediaStreamSource(mediaStream);
        source.connect(recordNode);
        this.asrWS = ws;
        this.recording = true;
    }
}
app.py
ADDED
@@ -0,0 +1,194 @@
from typing import *
from fastapi import FastAPI, HTTPException, Request, WebSocket, WebSocketDisconnect, Query
from fastapi.responses import HTMLResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
import asyncio
import logging
from pydantic import BaseModel, Field
import uvicorn
from voiceapi.tts import TTSResult, start_tts_stream, TTSStream
from voiceapi.asr import start_asr_stream, ASRStream, ASRResult
import argparse
import os

app = FastAPI()
logger = logging.getLogger(__file__)


@app.websocket("/asr")
async def websocket_asr(websocket: WebSocket,
                        samplerate: int = Query(16000, title="Sample Rate",
                                                description="The sample rate of the audio."),):
    await websocket.accept()

    asr_stream: ASRStream = await start_asr_stream(samplerate, args)
    if not asr_stream:
        logger.error("failed to start ASR stream")
        await websocket.close()
        return

    async def task_recv_pcm():
        while True:
            pcm_bytes = await websocket.receive_bytes()
            if not pcm_bytes:
                return
            await asr_stream.write(pcm_bytes)

    async def task_send_result():
        while True:
            result: ASRResult = await asr_stream.read()
            if not result:
                return
            await websocket.send_json(result.to_dict())
    try:
        await asyncio.gather(task_recv_pcm(), task_send_result())
    except WebSocketDisconnect:
        logger.info("asr: disconnected")
    finally:
        await asr_stream.close()


@app.websocket("/tts")
async def websocket_tts(websocket: WebSocket,
                        samplerate: int = Query(16000,
                                                title="Sample Rate",
                                                description="The sample rate of the generated audio."),
                        interrupt: bool = Query(True,
                                                title="Interrupt",
                                                description="Interrupt the current TTS stream when a new text is received."),
                        sid: int = Query(0,
                                         title="Speaker ID",
                                         description="The ID of the speaker to use for TTS."),
                        chunk_size: int = Query(1024,
                                                title="Chunk Size",
                                                description="The size of the chunk to send to the client."),
                        speed: float = Query(1.0,
                                             title="Speed",
                                             description="The speed of the generated audio."),
                        split: bool = Query(True,
                                            title="Split",
                                            description="Split the text into sentences.")):

    await websocket.accept()
    tts_stream: TTSStream = None

    async def task_recv_text():
        nonlocal tts_stream
        while True:
            text = await websocket.receive_text()
            if not text:
                return

            if interrupt or not tts_stream:
                if tts_stream:
                    await tts_stream.close()
                    logger.info("tts: stream interrupt")

                tts_stream = await start_tts_stream(sid, samplerate, speed, args)
                if not tts_stream:
                    logger.error("tts: failed to allocate tts stream")
                    await websocket.close()
                    return
            logger.info(f"tts: received: {text} (split={split})")
            await tts_stream.write(text, split)

    async def task_send_pcm():
        nonlocal tts_stream
        while not tts_stream:
            # wait for tts stream to be created
            await asyncio.sleep(0.1)

        while True:
            result: TTSResult = await tts_stream.read()
            if not result:
                return

            if result.finished:
                await websocket.send_json(result.to_dict())
            else:
                for i in range(0, len(result.pcm_bytes), chunk_size):
                    await websocket.send_bytes(result.pcm_bytes[i:i+chunk_size])

    try:
        await asyncio.gather(task_recv_text(), task_send_pcm())
    except WebSocketDisconnect:
        logger.info("tts: disconnected")
    finally:
        if tts_stream:
            await tts_stream.close()


class TTSRequest(BaseModel):
    text: str = Field(..., title="Text",
                      description="The text to be converted to speech.",
                      examples=["Hello, world!"])
    sid: int = Field(0, title="Speaker ID",
                     description="The ID of the speaker to use for TTS.")
    samplerate: int = Field(16000, title="Sample Rate",
                            description="The sample rate of the generated audio.")
    speed: float = Field(1.0, title="Speed",
                         description="The speed of the generated audio.")


@app.post("/tts",
          description="Generate speech audio from text.",
          response_class=StreamingResponse, responses={200: {"content": {"audio/wav": {}}}})
async def tts_generate(req: TTSRequest):
    if not req.text:
        raise HTTPException(status_code=400, detail="text is required")

    tts_stream = await start_tts_stream(req.sid, req.samplerate, req.speed, args)
    if not tts_stream:
        raise HTTPException(
            status_code=500, detail="failed to start TTS stream")

    r = await tts_stream.generate(req.text)
    return StreamingResponse(r, media_type="audio/wav")


if __name__ == "__main__":
    models_root = './models'

    for d in ['.', '..', '../..']:
        if os.path.isdir(f'{d}/models'):
            models_root = f'{d}/models'
            break

    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=8000, help="port number")
    parser.add_argument("--addr", type=str,
                        default="0.0.0.0", help="serve address")

    parser.add_argument("--asr-provider", type=str,
                        default="cpu", help="asr provider, cpu or cuda")
    parser.add_argument("--tts-provider", type=str,
                        default="cpu", help="tts provider, cpu or cuda")

    parser.add_argument("--threads", type=int, default=2,
                        help="number of threads")

    parser.add_argument("--models-root", type=str, default=models_root,
                        help="model root directory")

    parser.add_argument("--asr-model", type=str, default='sensevoice',
                        help="ASR model name: zipformer-bilingual, sensevoice, paraformer-trilingual, paraformer-en")

    parser.add_argument("--asr-lang", type=str, default='zh',
                        help="ASR language, zh, en, ja, ko, yue")

    parser.add_argument("--tts-model", type=str, default='vits-zh-hf-theresa',
                        help="TTS model name: vits-zh-hf-theresa, vits-melo-tts-zh_en")

    args = parser.parse_args()

    if args.tts_model == 'vits-melo-tts-zh_en' and args.tts_provider == 'cuda':
        logger.warning(
            "vits-melo-tts-zh_en does not support CUDA; falling back to CPU")
        args.tts_provider = 'cpu'

    app.mount("/", app=StaticFiles(directory="./assets", html=True), name="assets")

    logging.basicConfig(format='%(levelname)s: %(asctime)s %(name)s:%(lineno)s %(message)s',
                        level=logging.INFO)
    uvicorn.run(app, host=args.addr, port=args.port)
asr.py
ADDED
@@ -0,0 +1,233 @@
from typing import *
import logging
import time
import sherpa_onnx
import os
import asyncio
import numpy as np

logger = logging.getLogger(__file__)
_asr_engines = {}


class ASRResult:
    def __init__(self, text: str, finished: bool, idx: int):
        self.text = text
        self.finished = finished
        self.idx = idx

    def to_dict(self):
        return {"text": self.text, "finished": self.finished, "idx": self.idx}


class ASRStream:
    def __init__(self, recognizer: Union[sherpa_onnx.OnlineRecognizer, sherpa_onnx.OfflineRecognizer], sample_rate: int) -> None:
        self.recognizer = recognizer
        self.inbuf = asyncio.Queue()
        self.outbuf = asyncio.Queue()
        self.sample_rate = sample_rate
        self.is_closed = False
        self.online = isinstance(recognizer, sherpa_onnx.OnlineRecognizer)

    async def start(self):
        if self.online:
            asyncio.create_task(self.run_online())
        else:
            asyncio.create_task(self.run_offline())

    async def run_online(self):
        stream = self.recognizer.create_stream()
        last_result = ""
        segment_id = 0
        logger.info('asr: start real-time recognizer')
        while not self.is_closed:
            samples = await self.inbuf.get()
            stream.accept_waveform(self.sample_rate, samples)
            while self.recognizer.is_ready(stream):
                self.recognizer.decode_stream(stream)

            is_endpoint = self.recognizer.is_endpoint(stream)
            result = self.recognizer.get_result(stream)

            if result and (last_result != result):
                last_result = result
                logger.info(f' > {segment_id}:{result}')
                self.outbuf.put_nowait(
                    ASRResult(result, False, segment_id))

            if is_endpoint:
                if result:
                    logger.info(f'{segment_id}: {result}')
                    self.outbuf.put_nowait(
                        ASRResult(result, True, segment_id))
                    segment_id += 1
                self.recognizer.reset(stream)

    async def run_offline(self):
        vad = _asr_engines['vad']
        segment_id = 0
        st = None
        while not self.is_closed:
            samples = await self.inbuf.get()
            vad.accept_waveform(samples)
            while not vad.empty():
                if not st:
                    st = time.time()
                stream = self.recognizer.create_stream()
                stream.accept_waveform(self.sample_rate, vad.front.samples)

                vad.pop()
                self.recognizer.decode_stream(stream)

                result = stream.result.text.strip()
                if result:
                    duration = time.time() - st
                    logger.info(f'{segment_id}:{result} ({duration:.2f}s)')
                    self.outbuf.put_nowait(ASRResult(result, True, segment_id))
                    segment_id += 1
                st = None

    async def close(self):
        self.is_closed = True
        self.outbuf.put_nowait(None)

    async def write(self, pcm_bytes: bytes):
        pcm_data = np.frombuffer(pcm_bytes, dtype=np.int16)
        samples = pcm_data.astype(np.float32) / 32768.0
        self.inbuf.put_nowait(samples)

    async def read(self) -> ASRResult:
        return await self.outbuf.get()


def create_zipformer(samplerate: int, args) -> sherpa_onnx.OnlineRecognizer:
    d = os.path.join(
        args.models_root, 'sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20')
    if not os.path.exists(d):
        raise ValueError(f"asr: model not found {d}")

    encoder = os.path.join(d, "encoder-epoch-99-avg-1.onnx")
    decoder = os.path.join(d, "decoder-epoch-99-avg-1.onnx")
    joiner = os.path.join(d, "joiner-epoch-99-avg-1.onnx")
    tokens = os.path.join(d, "tokens.txt")

    recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
        tokens=tokens,
        encoder=encoder,
        decoder=decoder,
        joiner=joiner,
        provider=args.asr_provider,
        num_threads=args.threads,
        sample_rate=samplerate,
        feature_dim=80,
        enable_endpoint_detection=True,
        rule1_min_trailing_silence=2.4,
        rule2_min_trailing_silence=1.2,
        rule3_min_utterance_length=20,  # it essentially disables this rule
    )
    return recognizer


def create_sensevoice(samplerate: int, args) -> sherpa_onnx.OfflineRecognizer:
    d = os.path.join(args.models_root,
                     'sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17')

    if not os.path.exists(d):
        raise ValueError(f"asr: model not found {d}")

    recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
        model=os.path.join(d, 'model.onnx'),
        tokens=os.path.join(d, 'tokens.txt'),
        num_threads=args.threads,
        sample_rate=samplerate,
        use_itn=True,
        debug=0,
        language=args.asr_lang,
    )
    return recognizer


def create_paraformer_trilingual(samplerate: int, args) -> sherpa_onnx.OfflineRecognizer:
    d = os.path.join(
        args.models_root, 'sherpa-onnx-paraformer-trilingual-zh-cantonese-en')
    if not os.path.exists(d):
        raise ValueError(f"asr: model not found {d}")

    recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
        paraformer=os.path.join(d, 'model.onnx'),
        tokens=os.path.join(d, 'tokens.txt'),
        num_threads=args.threads,
        sample_rate=samplerate,
        debug=0,
        provider=args.asr_provider,
    )
    return recognizer


def create_paraformer_en(samplerate: int, args) -> sherpa_onnx.OfflineRecognizer:
    d = os.path.join(
        args.models_root, 'sherpa-onnx-paraformer-en')
    if not os.path.exists(d):
        raise ValueError(f"asr: model not found {d}")

    recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
        paraformer=os.path.join(d, 'model.onnx'),
        tokens=os.path.join(d, 'tokens.txt'),
        num_threads=args.threads,
        sample_rate=samplerate,
        use_itn=True,
        debug=0,
        provider=args.asr_provider,
    )
    return recognizer


def load_asr_engine(samplerate: int, args) -> Union[sherpa_onnx.OnlineRecognizer, sherpa_onnx.OfflineRecognizer]:
    cache_engine = _asr_engines.get(args.asr_model)
    if cache_engine:
        return cache_engine
    st = time.time()
    if args.asr_model == 'zipformer-bilingual':
        cache_engine = create_zipformer(samplerate, args)
    elif args.asr_model == 'sensevoice':
        cache_engine = create_sensevoice(samplerate, args)
        _asr_engines['vad'] = load_vad_engine(samplerate, args)
    elif args.asr_model == 'paraformer-trilingual':
        cache_engine = create_paraformer_trilingual(samplerate, args)
        _asr_engines['vad'] = load_vad_engine(samplerate, args)
    elif args.asr_model == 'paraformer-en':
        cache_engine = create_paraformer_en(samplerate, args)
        _asr_engines['vad'] = load_vad_engine(samplerate, args)
    else:
        raise ValueError(f"asr: unknown model {args.asr_model}")
    _asr_engines[args.asr_model] = cache_engine
    logger.info(f"asr: engine loaded in {time.time() - st:.2f}s")
    return cache_engine


def load_vad_engine(samplerate: int, args, min_silence_duration: float = 0.25, buffer_size_in_seconds: int = 100) -> sherpa_onnx.VoiceActivityDetector:
    config = sherpa_onnx.VadModelConfig()
    d = os.path.join(args.models_root, 'silero_vad')
    if not os.path.exists(d):
        raise ValueError(f"vad: model not found {d}")

    config.silero_vad.model = os.path.join(d, 'silero_vad.onnx')
    config.silero_vad.min_silence_duration = min_silence_duration
    config.sample_rate = samplerate
    config.provider = args.asr_provider
    config.num_threads = args.threads

    vad = sherpa_onnx.VoiceActivityDetector(
        config,
        buffer_size_in_seconds=buffer_size_in_seconds)
    return vad


async def start_asr_stream(samplerate: int, args) -> ASRStream:
    """
    Start an ASR stream
    """
    stream = ASRStream(load_asr_engine(samplerate, args), samplerate)
    await stream.start()
    return stream
audio_process.js
ADDED
@@ -0,0 +1,45 @@
class PlayerAudioProcessor extends AudioWorkletProcessor {
    constructor() {
        super();
        this.buffer = new Float32Array();
        this.port.onmessage = (event) => {
            let newFetchedData = new Float32Array(this.buffer.length + event.data.audioData.length);
            newFetchedData.set(this.buffer, 0);
            newFetchedData.set(event.data.audioData, this.buffer.length);
            this.buffer = newFetchedData;
        };
    }

    process(inputs, outputs, parameters) {
        const output = outputs[0];
        const channel = output[0];
        const bufferLength = this.buffer.length;
        for (let i = 0; i < channel.length; i++) {
            channel[i] = (i < bufferLength) ? this.buffer[i] : 0;
        }
        this.buffer = this.buffer.slice(channel.length);
        return true;
    }
}

class RecordAudioProcessor extends AudioWorkletProcessor {
    constructor() {
        super();
    }

    process(inputs, outputs, parameters) {
        const channel = inputs[0][0];
        if (!channel || channel.length === 0) {
            return true;
        }
        const int16Array = new Int16Array(channel.length);
        for (let i = 0; i < channel.length; i++) {
            int16Array[i] = channel[i] * 32767;
        }
        this.port.postMessage({ data: int16Array });
        return true
    }
}

registerProcessor('play-audio-processor', PlayerAudioProcessor);
registerProcessor('record-audio-processor', RecordAudioProcessor);
index.html
ADDED
@@ -0,0 +1,179 @@
<!DOCTYPE html>
<html lang="en">

<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <link rel="icon" type="image/png" href="./voice.png" />
    <script src="//cdn.tailwindcss.com?plugins=forms"></script>
    <link href="https://cdn.jsdelivr.net/npm/tailwindcss@latest/dist/tailwind.min.css" rel="stylesheet">
    <script src="//cdn.jsdelivr.net/npm/[email protected]/dist/cdn.min.js" defer></script>
    <script src="./app.js"></script>
    <title>voiceapi demo</title>
    <style>
        * {
            margin: 0;
            padding: 0;
        }
    </style>

    <style type="text/tailwindcss">
        .label { @apply text-gray-900 w-[50px] lg:w-20 }
        .title{
            @apply text-[16px] text-zinc-500 mx-2;
        }

        .select { @apply w-full rounded-md h-10 }

        .round { @apply rounded border px-3 p-2 border-slate-300 placeholder-gray-400 placeholder:text-sm
            focus:bg-white focus:text-gray-900 focus:placeholder-gray-500 focus:outline-none
            focus:border-zinc-950 focus:border ring-0 focus:ring-0 text-gray-900 }

        .checkbox { @apply ml-2 lg:ml-4 border focus:outline-none ring-0 focus:ring-gray-800 text-gray-900 }
        .dash{ @apply border border-dashed border-zinc-200 flex flex-grow }

        .button { @apply hover:bg-opacity-90 text-white font-bold py-1.5 px-6 rounded-full cursor-pointer }
        .card { @apply bg-white shadow-sm rounded-xl border p-4 }

        .animate-ping {
            animation: ping 2s cubic-bezier(0.5, 0.4, 0.2, 1) infinite;
        }

        @keyframes ping {
            0% {
                transform: scale(1);
                opacity: 1;
            }
            50% {
                transform: scale(1.2);
                opacity: 0.7;
            }
            100% {
                transform: scale(1);
                opacity: 1;
            }
        }
    </style>
</head>

<body>
    <script>
        async function initAudioWorklet() {
            try {
                // Check for browser support
                if (!('AudioContext' in window) || !('audioWorklet' in AudioContext.prototype)) {
                    console.error('Audio Worklet API is not supported in this browser.');
                    return;
                }

                // Initialize AudioContext
                const audioContext = new AudioContext();

                // Add Audio Worklet module
                await audioContext.audioWorklet.addModule('./audio_process.js');

                console.log('Audio Worklet module added successfully.');
                // Your code to use the Audio Worklet goes here

            } catch (error) {
                console.error('Error initializing Audio Worklet:', error);
            }
        }

        // Initialize Audio Worklet when the page is loaded
        window.addEventListener('load', initAudioWorklet);
    </script>
    <div x-data="demoapp">
        <header class="bg-gray-900 py-4 px-5 lg:p-4 lg:px-10 text-white sticky top-0 z-20">
            <div class="flex w-full justify-between items-center">
                <p class="gap-x-3">
                    <span>VoiceAPI Demo</span> /
                    <a href="https://ruzhila.cn/?from=voiceapi_demo">ruzhila.cn</a>
                </p>
                <a target="_blank" href="https://github.com/ruzhila/voiceapi" class="hover:cursor-pointer">
                    <svg t="1724996252746" class="icon" viewBox="0 0 1024 1024" version="1.1"
                        xmlns="http://www.w3.org/2000/svg" p-id="" width="25" height="25">
                        <path
                            d="M512 12.64c-282.752 0-512 229.216-512 512 0 226.208 146.72 418.144 350.144 485.824 25.6 4.736 35.008-11.104 35.008-24.64 0-12.192-0.48-52.544-0.704-95.328-142.464 30.976-172.512-60.416-172.512-60.416-23.296-59.168-56.832-74.912-56.832-74.912-46.464-31.776 3.52-31.136 3.52-31.136 51.392 3.616 78.464 52.768 78.464 52.768 45.664 78.272 119.776 55.648 148.992 42.56 4.576-33.088 17.856-55.68 32.512-68.48-113.728-12.928-233.28-56.864-233.28-253.024 0-55.904 20-101.568 52.768-137.44-5.312-12.896-22.848-64.96 4.96-135.488 0 0 43.008-13.76 140.832 52.48 40.832-11.36 84.64-17.024 128.16-17.248 43.488 0.192 87.328 5.888 128.256 17.248 97.728-66.24 140.64-52.48 140.64-52.48 27.872 70.528 10.336 122.592 5.024 135.488 32.832 35.84 52.704 81.536 52.704 137.44 0 196.64-119.776 239.936-233.792 252.64 18.368 15.904 34.72 47.04 34.72 94.816 0 68.512-0.608 123.648-0.608 140.512 0 13.632 9.216 29.6 35.168 24.576 203.328-67.776 349.856-259.616 349.856-485.76 0-282.784-229.248-512-512-512z"
                            fill="#ffffff"></path>
                    </svg>
                </a>
            </div>
        </header>

        <div class="flex px-6 gap-x-10 w-full max-w-7xl mx-auto">
            <div class="relative flex flex-col items-center w-1/3 py-10">
                <div class="w-full">
                    <textarea x-model="text" class="round p-4 w-full h-[36rem] text-sm"
                        placeholder="Enter text here"></textarea>
                </div>

                <div>
                    <button @click="dotts" :disabled="disabled"
                        class="button bg-gray-900 flex items-center gap-x-2 mt-6">
                        <span>Speak</span>
                        <svg t="1726215464577" class="icon" viewBox="0 0 1024 1024" version="1.1"
                            xmlns="http://www.w3.org/2000/svg" p-id="4263" width="20" height="20">
                            <path
                                d="M830.450526 853.759999q-11.722105 8.791579-27.351579 8.791579-19.536842 0-33.701053-14.164211t-14.164211-33.701053q0-21.490526
                                16.606316-36.143158 0.976842-0.976842 1.953684-1.465263t1.953684-1.465263l0.976842-0.976842q27.351579-18.56 50.795789-43.957895t41.027368-55.191579 27.351579-63.494737 9.768421-69.84421q0-73.263158-37.12-133.827368t-92.8-99.637895q-20.513684-14.652632-20.513684-39.073684 0-19.536842 14.164211-33.701053t33.701053-14.164211q16.606316 0 29.305263 10.745263 36.143158 25.397895 67.402105 59.098947t53.726316 73.263158 35.166316 84.496842 12.698947 92.8q0 48.842105-12.698947 93.776842t-35.654737 84.985263-54.214737 73.751579-68.378947 59.098947zM775.747368 415.157894q20.513684 28.328421 32.72421 57.145263t12.210526 69.84421q0 39.073684-12.698947 70.332632t-32.235789 56.656842q-7.814737 10.745263-16.606316 19.048421t-22.467368 8.303158q-17.583158 0-29.793684-12.698947t-12.210526-30.282105q0-7.814737 2.930526-15.629474l-0.976842 0q4.884211-10.745263 11.722105-20.513684t13.187368-20.025263 10.745263-23.444211 4.395789-31.747368q0-17.583158-4.395789-30.770526t-10.745263-23.932632-13.187368-20.513684-10.745263-20.513684q-2.930526-6.837895-2.930526-15.629474 0-17.583158 12.210526-30.282105t29.793684-12.698947q13.675789 0 22.467368 8.303158t16.606316 19.048421zM460.227368 995.402104q-49.818947-44.934737-105.498947-93.776842t-103.545263-89.869474q-55.68-46.888421-111.36-92.8-10.745263 0.976842-21.490526 0.976842-8.791579 0.976842-18.56 0.976842l-16.606316 0q-26.374737 0-42.981053-16.117895t-16.606316-38.585263l0-246.16421 0.976842 0-0.976842-0.976842q0-27.351579 17.094737-44.934737t42.492632-17.583158l55.68 0q89.869474-76.193684 163.132631-136.757895 31.258947-26.374737 61.541053-51.28421t54.703158-45.423158 41.027368-34.189474 20.513684-16.606316q29.305263-21.490526 47.376842-19.536842t28.328421 17.583158 14.164211 38.096842 3.907368 41.027368l0 788.311578 0 2.930526q0 18.56-6.837895 39.562105t-21.002105 33.212632-35.654737 10.256842-49.818947-28.328421z"
                                p-id="4264" fill="#ffffff"></path>
                        </svg>
                    </button>
                </div>
                <template x-if="elapsedTime">
                    <p x-text="`elapsedTime: ${elapsedTime}`" class="mt-4 text-sm text-gray-600"></p>
                </template>
            </div>

            <!-- recording -->
            <div class="w-full flex-grow h-[calc(100vh-10rem)] xl:pl-10 py-10">

                <div
                    class="rounded border border-gray-500 p-3 w-full flex flex-col items-end h-[36rem] overflow-y-auto">
                    <template x-for="item in logs">
                        <div class="mt-3 mb-2">
                            <span
                                class="text-white px-4 py-1.5 text-[13px] display-inline-block border border-gray-900 rounded-t-full rounded-l-full bg-gray-900 justify-end w-auto"
                                x-text="item?.text">
                            </span>
                        </div>
                    </template>
                </div>

                <template x-if="currentText">
                    <p x-text="`${currentText} …`" class="text-gray-800 mt-4 text-sm text-center"></p>
                </template>

                <template x-if="!recording">
                    <div class="flex flex-col gap-y-4 items-center justify-center mt-4">
                        <p @click="doasr"
                            class="mt-2 border border-gray-100 rounded-full duration-300 hover:scale-105 hover:border-gray-400">
                            <img src="./images/record.svg" alt="" class="w-14 h-14 mx-auto cursor-pointer">
                        </p>
                        <p class="text-gray-600">Click to record!</p>
                    </div>
                </template>

                <template x-if="recording">
                    <div class="flex flex-col items-center justify-center gap-y-4 mt-4">

                        <p @click="stopasr"
                            class="mt-2 border border-red-100 rounded-full duration-300 hover:scale-105 hover:border-red-400">
                            <img src="./images/speaking.svg" alt=""
                                class="w-14 h-14 mx-auto cursor-pointer animate-ping">
                        </p>
                        <div class="flex items-center text-gray-600 gap-x-4">
                            <p>Click to stop recording!</p>
                        </div>
                    </div>
                </template>
            </div>
        </div>
    </div>
</body>

</html>
record.svg
ADDED
requirements.txt
ADDED
@@ -0,0 +1,7 @@
sherpa-onnx == 1.10.24
soundfile == 0.12.1
fastapi == 0.114.1
uvicorn == 0.30.6
scipy == 1.13.1
numpy == 1.26.4
websockets == 13.0.1
sherpa_examples.py
ADDED
@@ -0,0 +1,274 @@
#!/bin/env python3
"""
Real-time ASR using microphone
"""

import argparse
import logging
import sherpa_onnx
import os
import time
import struct
import asyncio
import soundfile

try:
    import pyaudio
except ImportError:
    raise ImportError('Please install pyaudio with `pip install pyaudio`')

logger = logging.getLogger(__name__)
SAMPLE_RATE = 16000

pactx = pyaudio.PyAudio()
models_root: str = None
num_threads: int = 1


def create_zipformer(args) -> sherpa_onnx.OnlineRecognizer:
    d = os.path.join(
        models_root, 'sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20')
    encoder = os.path.join(d, "encoder-epoch-99-avg-1.onnx")
    decoder = os.path.join(d, "decoder-epoch-99-avg-1.onnx")
    joiner = os.path.join(d, "joiner-epoch-99-avg-1.onnx")
    tokens = os.path.join(d, "tokens.txt")

    recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
        tokens=tokens,
        encoder=encoder,
        decoder=decoder,
        joiner=joiner,
        provider=args.provider,
        num_threads=num_threads,
        sample_rate=SAMPLE_RATE,
        feature_dim=80,
        enable_endpoint_detection=True,
        rule1_min_trailing_silence=2.4,
        rule2_min_trailing_silence=1.2,
        rule3_min_utterance_length=20,  # it essentially disables this rule
    )
    return recognizer


def create_sensevoice(args) -> sherpa_onnx.OfflineRecognizer:
    model = os.path.join(
        models_root, 'sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17', 'model.onnx')
    tokens = os.path.join(
        models_root, 'sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17', 'tokens.txt')
    recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
        model=model,
        tokens=tokens,
        num_threads=num_threads,
        use_itn=True,
        debug=0,
        language=args.lang,
    )
    return recognizer


async def run_online(buf, recognizer):
    stream = recognizer.create_stream()
    last_result = ""
    segment_id = 0
    logger.info('Start real-time recognizer')
    while True:
        samples = await buf.get()
        stream.accept_waveform(SAMPLE_RATE, samples)
        while recognizer.is_ready(stream):
            recognizer.decode_stream(stream)

        is_endpoint = recognizer.is_endpoint(stream)
        result = recognizer.get_result(stream)

        if result and (last_result != result):
            last_result = result
            logger.info(f' > {segment_id}:{result}')

        if is_endpoint:
            if result:
                logger.info(f'{segment_id}: {result}')
                segment_id += 1
            recognizer.reset(stream)


async def run_offline(buf, recognizer):
    config = sherpa_onnx.VadModelConfig()
    config.silero_vad.model = os.path.join(
        models_root, 'silero_vad', 'silero_vad.onnx')
    config.silero_vad.min_silence_duration = 0.25
    config.sample_rate = SAMPLE_RATE
    vad = sherpa_onnx.VoiceActivityDetector(
        config, buffer_size_in_seconds=100)

    logger.info('Start offline recognizer with VAD')
    texts = []
    while True:
        samples = await buf.get()
        vad.accept_waveform(samples)
        while not vad.empty():
            stream = recognizer.create_stream()
            stream.accept_waveform(SAMPLE_RATE, vad.front.samples)

            vad.pop()
            recognizer.decode_stream(stream)

            text = stream.result.text.strip().lower()
            if len(text):
                idx = len(texts)
                texts.append(text)
                logger.info(f"{idx}: {text}")


async def handle_asr(args):
    action_func = None
    if args.model == 'zipformer':
        recognizer = create_zipformer(args)
        action_func = run_online
    elif args.model == 'sensevoice':
        recognizer = create_sensevoice(args)
        action_func = run_offline
    else:
        raise ValueError(f'Unknown model: {args.model}')
    buf = asyncio.Queue()
    recorder_task = asyncio.create_task(run_record(buf))
    asr_task = asyncio.create_task(action_func(buf, recognizer))
    await asyncio.gather(asr_task, recorder_task)


async def handle_tts(args):
    model = os.path.join(
        models_root, 'vits-melo-tts-zh_en', 'model.onnx')
    lexicon = os.path.join(
        models_root, 'vits-melo-tts-zh_en', 'lexicon.txt')
    dict_dir = os.path.join(
        models_root, 'vits-melo-tts-zh_en', 'dict')
    tokens = os.path.join(
        models_root, 'vits-melo-tts-zh_en', 'tokens.txt')
    tts_config = sherpa_onnx.OfflineTtsConfig(
        model=sherpa_onnx.OfflineTtsModelConfig(
            vits=sherpa_onnx.OfflineTtsVitsModelConfig(
                model=model,
                lexicon=lexicon,
                dict_dir=dict_dir,
                tokens=tokens,
            ),
            provider=args.provider,
            debug=0,
            num_threads=num_threads,
        ),
        max_num_sentences=args.max_num_sentences,
    )
    if not tts_config.validate():
        raise ValueError("Please check your config")

    tts = sherpa_onnx.OfflineTts(tts_config)

    start = time.time()
    audio = tts.generate(args.text, sid=args.sid,
                         speed=args.speed)
    elapsed_seconds = time.time() - start
    audio_duration = len(audio.samples) / audio.sample_rate
    real_time_factor = elapsed_seconds / audio_duration

    if args.output:
        logger.info(f"Saved to {args.output}")
        soundfile.write(
            args.output,
            audio.samples,
            samplerate=audio.sample_rate,
            subtype="PCM_16",
        )

    logger.info(f"The text is '{args.text}'")
    logger.info(f"Elapsed seconds: {elapsed_seconds:.3f}")
    logger.info(f"Audio duration in seconds: {audio_duration:.3f}")
    logger.info(
        f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}")


async def run_record(buf: asyncio.Queue[list[float]]):
    loop = asyncio.get_event_loop()

    def on_input(in_data, frame_count, time_info, status):
        samples = [
            v/32768.0 for v in list(struct.unpack('<' + 'h' * frame_count, in_data))]
        loop.create_task(buf.put(samples))
        return (None, pyaudio.paContinue)

    frame_size = 320
    recorder = pactx.open(format=pyaudio.paInt16, channels=1,
                          rate=SAMPLE_RATE, input=True,
                          frames_per_buffer=frame_size,
                          stream_callback=on_input)
    recorder.start_stream()
    logger.info('Start recording')

    while recorder.is_active():
        await asyncio.sleep(0.1)


async def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--provider', default='cpu',
                        help='onnxruntime provider, default is cpu, use cuda for GPU')

    subparsers = parser.add_subparsers(help='commands help')

    asr_parser = subparsers.add_parser('asr', help='run asr mode')
    asr_parser.add_argument('--model', default='zipformer',
                            help='model name, default is zipformer')
    asr_parser.add_argument('--lang', default='zh',
                            help='language, default is zh')
    asr_parser.set_defaults(func=handle_asr)

    tts_parser = subparsers.add_parser('tts', help='run tts mode')
    tts_parser.add_argument('--sid', type=int, default=0, help="""Speaker ID. Used only for multi-speaker models, e.g.
        models trained using the VCTK dataset. Not used for single-speaker
        models, e.g., models trained using the LJ speech dataset.
        """)
    tts_parser.add_argument('--output', type=str, default='output.wav',
                            help='output file name, default is output.wav')
    tts_parser.add_argument(
        "--speed",
        type=float,
        default=1.0,
        help="Speech speed. Larger->faster; smaller->slower",
    )
    tts_parser.add_argument(
        "--max-num-sentences",
        type=int,
        default=2,
        help="""Max number of sentences in a batch to avoid OOM if the input
        text is very long. Set it to -1 to process all the sentences in a
        single batch. A smaller value does not mean it is slower compared
        to a larger one on CPU.
        """,
    )
    tts_parser.add_argument(
        "text",
        type=str,
        help="The input text to generate audio for",
    )
    tts_parser.set_defaults(func=handle_tts)

    args = parser.parse_args()

    if hasattr(args, 'func'):
        await args.func(args)
    else:
        parser.print_help()

if __name__ == '__main__':
    logging.basicConfig(
        format='%(levelname)s: %(asctime)s %(name)s:%(lineno)s %(message)s')
    logging.getLogger().setLevel(logging.INFO)
    painfo = pactx.get_default_input_device_info()
    assert painfo['maxInputChannels'] >= 1, 'No input device'
    logger.info('Default input device: %s', painfo['name'])

    for d in ['.', '..', '../..']:
        if os.path.isdir(f'{d}/models'):
            models_root = f'{d}/models'
            break
    assert models_root, 'Could not find models directory'
    asyncio.run(main())
speaking.svg
ADDED
tts.py
ADDED
@@ -0,0 +1,216 @@
from typing import *
import os
import time
import sherpa_onnx
import logging
import numpy as np
import asyncio
import soundfile
from scipy.signal import resample
import io
import re

logger = logging.getLogger(__file__)

splitter = re.compile(r'[,,。.!?!?;;、\n]')
_tts_engines = {}

tts_configs = {
    'vits-zh-hf-theresa': {
        'model': 'theresa.onnx',
        'lexicon': 'lexicon.txt',
        'dict_dir': 'dict',
        'tokens': 'tokens.txt',
        'sample_rate': 22050,
        # 'rule_fsts': ['phone.fst', 'date.fst', 'number.fst'],
    },
    'vits-melo-tts-zh_en': {
        'model': 'model.onnx',
        'lexicon': 'lexicon.txt',
        'dict_dir': 'dict',
        'tokens': 'tokens.txt',
        'sample_rate': 44100,
        'rule_fsts': ['phone.fst', 'date.fst', 'number.fst'],
    },
}


def load_tts_model(name: str, model_root: str, provider: str, num_threads: int = 1, max_num_sentences: int = 20) -> sherpa_onnx.OfflineTtsConfig:
    cfg = tts_configs[name]
    fsts = []
    model_dir = os.path.join(model_root, name)
    for f in cfg.get('rule_fsts', ''):
        fsts.append(os.path.join(model_dir, f))
    tts_rule_fsts = ','.join(fsts) if fsts else ''

    model_config = sherpa_onnx.OfflineTtsModelConfig(
        vits=sherpa_onnx.OfflineTtsVitsModelConfig(
            model=os.path.join(model_dir, cfg['model']),
            lexicon=os.path.join(model_dir, cfg['lexicon']),
            dict_dir=os.path.join(model_dir, cfg['dict_dir']),
            tokens=os.path.join(model_dir, cfg['tokens']),
        ),
        provider=provider,
        debug=0,
        num_threads=num_threads,
    )
    tts_config = sherpa_onnx.OfflineTtsConfig(
        model=model_config,
        rule_fsts=tts_rule_fsts,
        max_num_sentences=max_num_sentences)

    if not tts_config.validate():
        raise ValueError("tts: invalid config")

    return tts_config


def get_tts_engine(args) -> Tuple[sherpa_onnx.OfflineTts, int]:
    sample_rate = tts_configs[args.tts_model]['sample_rate']
    cache_engine = _tts_engines.get(args.tts_model)
    if cache_engine:
        return cache_engine, sample_rate
    st = time.time()
    tts_config = load_tts_model(
        args.tts_model, args.models_root, args.tts_provider)

    cache_engine = sherpa_onnx.OfflineTts(tts_config)
    elapsed = time.time() - st
    logger.info(f"tts: loaded {args.tts_model} in {elapsed:.2f}s")
    _tts_engines[args.tts_model] = cache_engine

    return cache_engine, sample_rate


class TTSResult:
    def __init__(self, pcm_bytes: bytes, finished: bool):
        self.pcm_bytes = pcm_bytes
        self.finished = finished
        self.progress: float = 0.0
        self.elapsed: float = 0.0
        self.audio_duration: float = 0.0
        self.audio_size: int = 0

    def to_dict(self):
        return {
            "progress": self.progress,
            "elapsed": f'{int(self.elapsed * 1000)}ms',
            "duration": f'{self.audio_duration:.2f}s',
            "size": self.audio_size
        }


class TTSStream:
    def __init__(self, engine, sid: int, speed: float = 1.0, sample_rate: int = 16000, original_sample_rate: int = 16000):
        self.engine = engine
        self.sid = sid
        self.speed = speed
        self.outbuf: asyncio.Queue[TTSResult | None] = asyncio.Queue()
        self.is_closed = False
        self.target_sample_rate = sample_rate
        self.original_sample_rate = original_sample_rate

    def on_process(self, chunk: np.ndarray, progress: float):
        if self.is_closed:
            return 0

        # resample to target sample rate
        if self.target_sample_rate != self.original_sample_rate:
            num_samples = int(
                len(chunk) * self.target_sample_rate / self.original_sample_rate)
            resampled_chunk = resample(chunk, num_samples)
            chunk = resampled_chunk.astype(np.float32)

        scaled_chunk = chunk * 32768.0
        clipped_chunk = np.clip(scaled_chunk, -32768, 32767)
        int16_chunk = clipped_chunk.astype(np.int16)
        samples = int16_chunk.tobytes()
        self.outbuf.put_nowait(TTSResult(samples, False))
        return self.is_closed and 0 or 1

    async def write(self, text: str, split: bool, pause: float = 0.2):
        start = time.time()
        if split:
            texts = re.split(splitter, text)
        else:
            texts = [text]

        audio_duration = 0.0
        audio_size = 0

        for idx, text in enumerate(texts):
            text = text.strip()
            if not text:
                continue
            sub_start = time.time()

            audio = await asyncio.to_thread(self.engine.generate,
                                            text, self.sid, self.speed,
                                            self.on_process)

            if not audio or not audio.sample_rate or not audio.samples:
                logger.error(f"tts: failed to generate audio for "
                             f"'{text}' (audio={audio})")
                continue

            if split and idx < len(texts) - 1:  # add a pause between sentences
                noise = np.zeros(int(audio.sample_rate * pause))
                self.on_process(noise, 1.0)
                audio.samples = np.concatenate([audio.samples, noise])

            audio_duration += len(audio.samples) / audio.sample_rate
            audio_size += len(audio.samples)
            elapsed_seconds = time.time() - sub_start
            logger.info(f"tts: generated audio for '{text}', "
                        f"audio duration: {audio_duration:.2f}s, "
                        f"elapsed: {elapsed_seconds:.2f}s")

        elapsed_seconds = time.time() - start
        logger.info(f"tts: generated audio in {elapsed_seconds:.2f}s, "
                    f"audio duration: {audio_duration:.2f}s")

        r = TTSResult(None, True)
        r.elapsed = elapsed_seconds
        r.audio_duration = audio_duration
        r.progress = 1.0
        r.finished = True
        await self.outbuf.put(r)

    async def close(self):
        self.is_closed = True
        self.outbuf.put_nowait(None)
        logger.info("tts: stream closed")

    async def read(self) -> TTSResult:
        return await self.outbuf.get()

    async def generate(self, text: str) -> io.BytesIO:
        start = time.time()
        audio = await asyncio.to_thread(self.engine.generate,
                                        text, self.sid, self.speed)
        elapsed_seconds = time.time() - start
        audio_duration = len(audio.samples) / audio.sample_rate

        logger.info(f"tts: generated audio in {elapsed_seconds:.2f}s, "
                    f"audio duration: {audio_duration:.2f}s, "
                    f"sample rate: {audio.sample_rate}")

        if self.target_sample_rate != audio.sample_rate:
            audio.samples = resample(audio.samples,
                                     int(len(audio.samples) * self.target_sample_rate / audio.sample_rate))
            audio.sample_rate = self.target_sample_rate

        output = io.BytesIO()
        soundfile.write(output,
                        audio.samples,
                        samplerate=audio.sample_rate,
                        subtype="PCM_16",
                        format="WAV")
        output.seek(0)
        return output


async def start_tts_stream(sid: int, sample_rate: int, speed: float, args) -> TTSStream:
    engine, original_sample_rate = get_tts_engine(args)
    return TTSStream(engine, sid, speed, sample_rate, original_sample_rate)
voice.png
ADDED