cuio committed
Commit da8e0c5 · verified · 1 parent: 0cf4ed8

Upload 11 files

Files changed (11)
  1. app.js +99 -0
  2. app.py +194 -0
  3. asr.py +233 -0
  4. audio_process.js +45 -0
  5. index.html +179 -0
  6. record.svg +1 -0
  7. requirements.txt +7 -0
  8. sherpa_examples.py +274 -0
  9. speaking.svg +1 -0
  10. tts.py +216 -0
  11. voice.png +0 -0
app.js ADDED
@@ -0,0 +1,99 @@
1
+ const demoapp = {
2
+ text: '讲个冷笑话吧,要很好笑的那种。',
3
+ recording: false,
4
+ asrWS: null,
5
+ currentText: null,
6
+ disabled: false,
7
+ elapsedTime: null,
8
+ logs: [{ idx: 0, text: 'Happily here at ruzhila.cn.' }],
9
+ async init() {
10
+ },
11
+ async dotts() {
12
+ let audioContext = new AudioContext({ sampleRate: 16000 })
13
+ await audioContext.audioWorklet.addModule('./audio_process.js')
14
+
15
+ const ws = new WebSocket('/tts');
16
+ ws.onopen = () => {
17
+ ws.send(this.text);
18
+ };
19
+ const playNode = new AudioWorkletNode(audioContext, 'play-audio-processor');
20
+ playNode.connect(audioContext.destination);
21
+
22
+ this.disabled = true;
23
+ ws.onmessage = async (e) => {
24
+ if (e.data instanceof Blob) {
25
+ e.data.arrayBuffer().then((arrayBuffer) => {
26
+ const int16Array = new Int16Array(arrayBuffer);
27
+ let float32Array = new Float32Array(int16Array.length);
28
+ for (let i = 0; i < int16Array.length; i++) {
29
+ float32Array[i] = int16Array[i] / 32768.;
30
+ }
31
+ playNode.port.postMessage({ message: 'audioData', audioData: float32Array });
32
+ });
33
+ } else {
34
+ this.elapsedTime = JSON.parse(e.data)?.elapsed;
35
+ this.disabled = false;
36
+ }
37
+ }
38
+ },
39
+
40
+ async stopasr() {
41
+ if (!this.asrWS) {
42
+ return;
43
+ }
44
+ this.asrWS.close();
45
+ this.asrWS = null;
46
+ this.recording = false;
47
+ if (this.currentText) {
48
+ this.logs.push({ idx: this.logs.length + 1, text: this.currentText });
49
+ }
50
+ this.currentText = null;
51
+
52
+ },
53
+
54
+ async doasr() {
55
+ const audioConstraints = {
56
+ video: false,
57
+ audio: true,
58
+ };
59
+
60
+ const mediaStream = await navigator.mediaDevices.getUserMedia(audioConstraints);
61
+
62
+ const ws = new WebSocket('/asr');
63
+ let currentMessage = '';
64
+
65
+ ws.onopen = () => {
66
+ this.logs = [];
67
+ };
68
+
69
+ ws.onmessage = (e) => {
70
+ const data = JSON.parse(e.data);
71
+ const { text, finished, idx } = data;
72
+
73
+ currentMessage = text;
74
+ this.currentText = text;
75
+
76
+ if (finished) {
77
+ this.logs.push({ text: currentMessage, idx: idx });
78
+ currentMessage = '';
79
+ this.currentText = null;
80
+ }
81
+ };
82
+
83
+ let audioContext = new AudioContext({ sampleRate: 16000 })
84
+ await audioContext.audioWorklet.addModule('./audio_process.js')
85
+
86
+ const recordNode = new AudioWorkletNode(audioContext, 'record-audio-processor');
87
+ recordNode.connect(audioContext.destination);
88
+ recordNode.port.onmessage = (event) => {
89
+ if (ws && ws.readyState === WebSocket.OPEN) {
90
+ const int16Array = event.data.data;
91
+ ws.send(int16Array.buffer);
92
+ }
93
+ }
94
+ const source = audioContext.createMediaStreamSource(mediaStream);
95
+ source.connect(recordNode);
96
+ this.asrWS = ws;
97
+ this.recording = true;
98
+ }
99
+ }
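Note: app.js streams raw 16 kHz little-endian int16 PCM to the /asr WebSocket and plays int16 chunks received from /tts. As a rough sketch (not part of this commit), the same /asr endpoint can be exercised from a non-browser client using the websockets package pinned in requirements.txt; it assumes the app.py server below is running on localhost:8000 and that speech_16k_mono.wav is a hypothetical 16 kHz mono int16 recording.

# Hypothetical /asr WebSocket client; host, port and file path are assumptions.
import asyncio
import wave
import websockets

async def transcribe(path: str = "speech_16k_mono.wav"):
    async with websockets.connect("ws://localhost:8000/asr?samplerate=16000") as ws:
        async def feed():
            with wave.open(path, "rb") as wav:
                while True:
                    pcm = wav.readframes(3200)   # ~200 ms of 16 kHz mono int16 audio
                    if not pcm:
                        break
                    await ws.send(pcm)           # raw int16 bytes, same format app.js sends
                    await asyncio.sleep(0.2)     # pace roughly like a live microphone

        async def collect():
            # runs until the connection closes (Ctrl-C to stop)
            async for message in ws:
                print(message)                   # JSON: {"text": ..., "finished": ..., "idx": ...}

        await asyncio.gather(feed(), collect())

asyncio.run(transcribe())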
app.py ADDED
@@ -0,0 +1,194 @@
1
+ from typing import *
2
+ from fastapi import FastAPI, HTTPException, Request, WebSocket, WebSocketDisconnect, Query
3
+ from fastapi.responses import HTMLResponse, StreamingResponse
4
+ from fastapi.staticfiles import StaticFiles
5
+ import asyncio
6
+ import logging
7
+ from pydantic import BaseModel, Field
8
+ import uvicorn
9
+ from voiceapi.tts import TTSResult, start_tts_stream, TTSStream
10
+ from voiceapi.asr import start_asr_stream, ASRStream, ASRResult
11
+ import logging
12
+ import argparse
13
+ import os
14
+
15
+ app = FastAPI()
16
+ logger = logging.getLogger(__file__)
17
+
18
+
19
+ @app.websocket("/asr")
20
+ async def websocket_asr(websocket: WebSocket,
21
+ samplerate: int = Query(16000, title="Sample Rate",
22
+ description="The sample rate of the audio."),):
23
+ await websocket.accept()
24
+
25
+ asr_stream: ASRStream = await start_asr_stream(samplerate, args)
26
+ if not asr_stream:
27
+ logger.error("failed to start ASR stream")
28
+ await websocket.close()
29
+ return
30
+
31
+ async def task_recv_pcm():
32
+ while True:
33
+ pcm_bytes = await websocket.receive_bytes()
34
+ if not pcm_bytes:
35
+ return
36
+ await asr_stream.write(pcm_bytes)
37
+
38
+ async def task_send_result():
39
+ while True:
40
+ result: ASRResult = await asr_stream.read()
41
+ if not result:
42
+ return
43
+ await websocket.send_json(result.to_dict())
44
+ try:
45
+ await asyncio.gather(task_recv_pcm(), task_send_result())
46
+ except WebSocketDisconnect:
47
+ logger.info("asr: disconnected")
48
+ finally:
49
+ await asr_stream.close()
50
+
51
+
52
+ @app.websocket("/tts")
53
+ async def websocket_tts(websocket: WebSocket,
54
+ samplerate: int = Query(16000,
55
+ title="Sample Rate",
56
+ description="The sample rate of the generated audio."),
57
+ interrupt: bool = Query(True,
58
+ title="Interrupt",
59
+ description="Interrupt the current TTS stream when a new text is received."),
60
+ sid: int = Query(0,
61
+ title="Speaker ID",
62
+ description="The ID of the speaker to use for TTS."),
63
+ chunk_size: int = Query(1024,
64
+ title="Chunk Size",
65
+ description="The size of the chunk to send to the client."),
66
+ speed: float = Query(1.0,
67
+ title="Speed",
68
+ description="The speed of the generated audio."),
69
+ split: bool = Query(True,
70
+ title="Split",
71
+ description="Split the text into sentences.")):
72
+
73
+ await websocket.accept()
74
+ tts_stream: Optional[TTSStream] = None
75
+
76
+ async def task_recv_text():
77
+ nonlocal tts_stream
78
+ while True:
79
+ text = await websocket.receive_text()
80
+ if not text:
81
+ return
82
+
83
+ if interrupt or not tts_stream:
84
+ if tts_stream:
85
+ await tts_stream.close()
86
+ logger.info("tts: stream interrupt")
87
+
88
+ tts_stream = await start_tts_stream(sid, samplerate, speed, args)
89
+ if not tts_stream:
90
+ logger.error("tts: failed to allocate tts stream")
91
+ await websocket.close()
92
+ return
93
+ logger.info(f"tts: received: {text} (split={split})")
94
+ await tts_stream.write(text, split)
95
+
96
+ async def task_send_pcm():
97
+ nonlocal tts_stream
98
+ while not tts_stream:
99
+ # wait for tts stream to be created
100
+ await asyncio.sleep(0.1)
101
+
102
+ while True:
103
+ result: TTSResult = await tts_stream.read()
104
+ if not result:
105
+ return
106
+
107
+ if result.finished:
108
+ await websocket.send_json(result.to_dict())
109
+ else:
110
+ for i in range(0, len(result.pcm_bytes), chunk_size):
111
+ await websocket.send_bytes(result.pcm_bytes[i:i+chunk_size])
112
+
113
+ try:
114
+ await asyncio.gather(task_recv_text(), task_send_pcm())
115
+ except WebSocketDisconnect:
116
+ logger.info("tts: disconnected")
117
+ finally:
118
+ if tts_stream:
119
+ await tts_stream.close()
120
+
121
+
122
+ class TTSRequest(BaseModel):
123
+ text: str = Field(..., title="Text",
124
+ description="The text to be converted to speech.",
125
+ examples=["Hello, world!"])
126
+ sid: int = Field(0, title="Speaker ID",
127
+ description="The ID of the speaker to use for TTS.")
128
+ samplerate: int = Field(16000, title="Sample Rate",
129
+ description="The sample rate of the generated audio.")
130
+ speed: float = Field(1.0, title="Speed",
131
+ description="The speed of the generated audio.")
132
+
133
+
134
+ @app.post("/tts",
135
+ description="Generate speech audio from text.",
136
+ response_class=StreamingResponse, responses={200: {"content": {"audio/wav": {}}}})
137
+ async def tts_generate(req: TTSRequest):
138
+ if not req.text:
139
+ raise HTTPException(status_code=400, detail="text is required")
140
+
141
+ tts_stream = await start_tts_stream(req.sid, req.samplerate, req.speed, args)
142
+ if not tts_stream:
143
+ raise HTTPException(
144
+ status_code=500, detail="failed to start TTS stream")
145
+
146
+ r = await tts_stream.generate(req.text)
147
+ return StreamingResponse(r, media_type="audio/wav")
148
+
149
+
150
+ if __name__ == "__main__":
151
+ models_root = './models'
152
+
153
+ for d in ['.', '..', '../..']:
154
+ if os.path.isdir(f'{d}/models'):
155
+ models_root = f'{d}/models'
156
+ break
157
+
158
+ parser = argparse.ArgumentParser()
159
+ parser.add_argument("--port", type=int, default=8000, help="port number")
160
+ parser.add_argument("--addr", type=str,
161
+ default="0.0.0.0", help="serve address")
162
+
163
+ parser.add_argument("--asr-provider", type=str,
164
+ default="cpu", help="asr provider, cpu or cuda")
165
+ parser.add_argument("--tts-provider", type=str,
166
+ default="cpu", help="tts provider, cpu or cuda")
167
+
168
+ parser.add_argument("--threads", type=int, default=2,
169
+ help="number of threads")
170
+
171
+ parser.add_argument("--models-root", type=str, default=models_root,
172
+ help="model root directory")
173
+
174
+ parser.add_argument("--asr-model", type=str, default='sensevoice',
175
+ help="ASR model name: zipformer-bilingual, sensevoice, paraformer-trilingual, paraformer-en")
176
+
177
+ parser.add_argument("--asr-lang", type=str, default='zh',
178
+ help="ASR language, zh, en, ja, ko, yue")
179
+
180
+ parser.add_argument("--tts-model", type=str, default='vits-zh-hf-theresa',
181
+ help="TTS model name: vits-zh-hf-theresa, vits-melo-tts-zh_en")
182
+
183
+ args = parser.parse_args()
184
+
185
+ if args.tts_model == 'vits-melo-tts-zh_en' and args.tts_provider == 'cuda':
186
+ logger.warning(
187
+ "vits-melo-tts-zh_en does not support CUDA fallback to CPU")
188
+ args.tts_provider = 'cpu'
189
+
190
+ app.mount("/", app=StaticFiles(directory="./assets", html=True), name="assets")
191
+
192
+ logging.basicConfig(format='%(levelname)s: %(asctime)s %(name)s:%(lineno)s %(message)s',
193
+ level=logging.INFO)
194
+ uvicorn.run(app, host=args.addr, port=args.port)
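In addition to the two WebSocket routes, app.py exposes a plain POST /tts endpoint that returns a WAV stream. A minimal standard-library sketch of calling it (not part of this commit), assuming the server was started locally, e.g. with `python app.py --port 8000`, and using a hypothetical output path:

# Hypothetical client for POST /tts; host, port and output path are assumptions.
import json
import urllib.request

payload = json.dumps({
    "text": "Hello, world!",   # matches the TTSRequest example above
    "sid": 0,
    "samplerate": 16000,
    "speed": 1.0,
}).encode("utf-8")

req = urllib.request.Request(
    "http://localhost:8000/tts",
    data=payload,
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp, open("out.wav", "wb") as f:
    f.write(resp.read())       # the endpoint streams back audio/wav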
asr.py ADDED
@@ -0,0 +1,233 @@
1
+ from typing import *
2
+ import logging
3
+ import time
4
+ import logging
5
+ import sherpa_onnx
6
+ import os
7
+ import asyncio
8
+ import numpy as np
9
+
10
+ logger = logging.getLogger(__file__)
11
+ _asr_engines = {}
12
+
13
+
14
+ class ASRResult:
15
+ def __init__(self, text: str, finished: bool, idx: int):
16
+ self.text = text
17
+ self.finished = finished
18
+ self.idx = idx
19
+
20
+ def to_dict(self):
21
+ return {"text": self.text, "finished": self.finished, "idx": self.idx}
22
+
23
+
24
+ class ASRStream:
25
+ def __init__(self, recognizer: Union[sherpa_onnx.OnlineRecognizer, sherpa_onnx.OfflineRecognizer], sample_rate: int) -> None:
26
+ self.recognizer = recognizer
27
+ self.inbuf = asyncio.Queue()
28
+ self.outbuf = asyncio.Queue()
29
+ self.sample_rate = sample_rate
30
+ self.is_closed = False
31
+ self.online = isinstance(recognizer, sherpa_onnx.OnlineRecognizer)
32
+
33
+ async def start(self):
34
+ if self.online:
35
+ asyncio.create_task(self.run_online())
36
+ else:
37
+ asyncio.create_task(self.run_offline())
38
+
39
+ async def run_online(self):
40
+ stream = self.recognizer.create_stream()
41
+ last_result = ""
42
+ segment_id = 0
43
+ logger.info('asr: start real-time recognizer')
44
+ while not self.is_closed:
45
+ samples = await self.inbuf.get()
46
+ stream.accept_waveform(self.sample_rate, samples)
47
+ while self.recognizer.is_ready(stream):
48
+ self.recognizer.decode_stream(stream)
49
+
50
+ is_endpoint = self.recognizer.is_endpoint(stream)
51
+ result = self.recognizer.get_result(stream)
52
+
53
+ if result and (last_result != result):
54
+ last_result = result
55
+ logger.info(f' > {segment_id}:{result}')
56
+ self.outbuf.put_nowait(
57
+ ASRResult(result, False, segment_id))
58
+
59
+ if is_endpoint:
60
+ if result:
61
+ logger.info(f'{segment_id}: {result}')
62
+ self.outbuf.put_nowait(
63
+ ASRResult(result, True, segment_id))
64
+ segment_id += 1
65
+ self.recognizer.reset(stream)
66
+
67
+ async def run_offline(self):
68
+ vad = _asr_engines['vad']
69
+ segment_id = 0
70
+ st = None
71
+ while not self.is_closed:
72
+ samples = await self.inbuf.get()
73
+ vad.accept_waveform(samples)
74
+ while not vad.empty():
75
+ if not st:
76
+ st = time.time()
77
+ stream = self.recognizer.create_stream()
78
+ stream.accept_waveform(self.sample_rate, vad.front.samples)
79
+
80
+ vad.pop()
81
+ self.recognizer.decode_stream(stream)
82
+
83
+ result = stream.result.text.strip()
84
+ if result:
85
+ duration = time.time() - st
86
+ logger.info(f'{segment_id}:{result} ({duration:.2f}s)')
87
+ self.outbuf.put_nowait(ASRResult(result, True, segment_id))
88
+ segment_id += 1
89
+ st = None
90
+
91
+ async def close(self):
92
+ self.is_closed = True
93
+ self.outbuf.put_nowait(None)
94
+
95
+ async def write(self, pcm_bytes: bytes):
96
+ pcm_data = np.frombuffer(pcm_bytes, dtype=np.int16)
97
+ samples = pcm_data.astype(np.float32) / 32768.0
98
+ self.inbuf.put_nowait(samples)
99
+
100
+ async def read(self) -> ASRResult:
101
+ return await self.outbuf.get()
102
+
103
+
104
+ def create_zipformer(samplerate: int, args) -> sherpa_onnx.OnlineRecognizer:
105
+ d = os.path.join(
106
+ args.models_root, 'sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20')
107
+ if not os.path.exists(d):
108
+ raise ValueError(f"asr: model not found {d}")
109
+
110
+ encoder = os.path.join(d, "encoder-epoch-99-avg-1.onnx")
111
+ decoder = os.path.join(d, "decoder-epoch-99-avg-1.onnx")
112
+ joiner = os.path.join(d, "joiner-epoch-99-avg-1.onnx")
113
+ tokens = os.path.join(d, "tokens.txt")
114
+
115
+ recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
116
+ tokens=tokens,
117
+ encoder=encoder,
118
+ decoder=decoder,
119
+ joiner=joiner,
120
+ provider=args.asr_provider,
121
+ num_threads=args.threads,
122
+ sample_rate=samplerate,
123
+ feature_dim=80,
124
+ enable_endpoint_detection=True,
125
+ rule1_min_trailing_silence=2.4,
126
+ rule2_min_trailing_silence=1.2,
127
+ rule3_min_utterance_length=20, # it essentially disables this rule
128
+ )
129
+ return recognizer
130
+
131
+
132
+ def create_sensevoice(samplerate: int, args) -> sherpa_onnx.OfflineRecognizer:
133
+ d = os.path.join(args.models_root,
134
+ 'sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17')
135
+
136
+ if not os.path.exists(d):
137
+ raise ValueError(f"asr: model not found {d}")
138
+
139
+ recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
140
+ model=os.path.join(d, 'model.onnx'),
141
+ tokens=os.path.join(d, 'tokens.txt'),
142
+ num_threads=args.threads,
143
+ sample_rate=samplerate,
144
+ use_itn=True,
145
+ debug=0,
146
+ language=args.asr_lang,
147
+ )
148
+ return recognizer
149
+
150
+
151
+ def create_paraformer_trilingual(samplerate: int, args) -> sherpa_onnx.OfflineRecognizer:
152
+ d = os.path.join(
153
+ args.models_root, 'sherpa-onnx-paraformer-trilingual-zh-cantonese-en')
154
+ if not os.path.exists(d):
155
+ raise ValueError(f"asr: model not found {d}")
156
+
157
+ recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
158
+ paraformer=os.path.join(d, 'model.onnx'),
159
+ tokens=os.path.join(d, 'tokens.txt'),
160
+ num_threads=args.threads,
161
+ sample_rate=samplerate,
162
+ debug=0,
163
+ provider=args.asr_provider,
164
+ )
165
+ return recognizer
166
+
167
+
168
+ def create_paraformer_en(samplerate: int, args) -> sherpa_onnx.OfflineRecognizer:
169
+ d = os.path.join(
170
+ args.models_root, 'sherpa-onnx-paraformer-en')
171
+ if not os.path.exists(d):
172
+ raise ValueError(f"asr: model not found {d}")
173
+
174
+ recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
175
+ paraformer=os.path.join(d, 'model.onnx'),
176
+ tokens=os.path.join(d, 'tokens.txt'),
177
+ num_threads=args.threads,
178
+ sample_rate=samplerate,
179
+ use_itn=True,
180
+ debug=0,
181
+ provider=args.asr_provider,
182
+ )
183
+ return recognizer
184
+
185
+
186
+ def load_asr_engine(samplerate: int, args) -> Union[sherpa_onnx.OnlineRecognizer, sherpa_onnx.OfflineRecognizer]:
187
+ cache_engine = _asr_engines.get(args.asr_model)
188
+ if cache_engine:
189
+ return cache_engine
190
+ st = time.time()
191
+ if args.asr_model == 'zipformer-bilingual':
192
+ cache_engine = create_zipformer(samplerate, args)
193
+ elif args.asr_model == 'sensevoice':
194
+ cache_engine = create_sensevoice(samplerate, args)
195
+ _asr_engines['vad'] = load_vad_engine(samplerate, args)
196
+ elif args.asr_model == 'paraformer-trilingual':
197
+ cache_engine = create_paraformer_trilingual(samplerate, args)
198
+ _asr_engines['vad'] = load_vad_engine(samplerate, args)
199
+ elif args.asr_model == 'paraformer-en':
200
+ cache_engine = create_paraformer_en(samplerate, args)
201
+ _asr_engines['vad'] = load_vad_engine(samplerate, args)
202
+ else:
203
+ raise ValueError(f"asr: unknown model {args.asr_model}")
204
+ _asr_engines[args.asr_model] = cache_engine
205
+ logger.info(f"asr: engine loaded in {time.time() - st:.2f}s")
206
+ return cache_engine
207
+
208
+
209
+ def load_vad_engine(samplerate: int, args, min_silence_duration: float = 0.25, buffer_size_in_seconds: int = 100) -> sherpa_onnx.VoiceActivityDetector:
210
+ config = sherpa_onnx.VadModelConfig()
211
+ d = os.path.join(args.models_root, 'silero_vad')
212
+ if not os.path.exists(d):
213
+ raise ValueError(f"vad: model not found {d}")
214
+
215
+ config.silero_vad.model = os.path.join(d, 'silero_vad.onnx')
216
+ config.silero_vad.min_silence_duration = min_silence_duration
217
+ config.sample_rate = samplerate
218
+ config.provider = args.asr_provider
219
+ config.num_threads = args.threads
220
+
221
+ vad = sherpa_onnx.VoiceActivityDetector(
222
+ config,
223
+ buffer_size_in_seconds=buffer_size_in_seconds)
224
+ return vad
225
+
226
+
227
+ async def start_asr_stream(samplerate: int, args) -> ASRStream:
228
+ """
229
+ Start an ASR stream
230
+ """
231
+ stream = ASRStream(load_asr_engine(samplerate, args), samplerate)
232
+ await stream.start()
233
+ return stream
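start_asr_stream only needs a sample rate plus an args object carrying the options app.py parses (models_root, asr_model, asr_provider, asr_lang, threads), so the stream can also be driven outside the web server. A rough sketch (not part of this commit), assuming the sensevoice and silero_vad model directories exist under ./models and that speech_16k_mono.wav is a hypothetical 16 kHz mono int16 recording ending in silence:

# Hypothetical standalone use of ASRStream; values and paths are assumptions.
import argparse
import asyncio
import wave

from voiceapi.asr import start_asr_stream  # import path as used by app.py

async def main():
    # mirror the argparse defaults from app.py
    args = argparse.Namespace(models_root="./models", asr_model="sensevoice",
                              asr_provider="cpu", asr_lang="zh", threads=2)
    stream = await start_asr_stream(16000, args)

    with wave.open("speech_16k_mono.wav", "rb") as wav:
        await stream.write(wav.readframes(wav.getnframes()))  # raw int16 PCM bytes

    result = await stream.read()   # ASRResult produced by the VAD + offline recognizer
    print(result.to_dict())
    await stream.close()

asyncio.run(main())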
audio_process.js ADDED
@@ -0,0 +1,45 @@
1
+ class PlayerAudioProcessor extends AudioWorkletProcessor {
2
+ constructor() {
3
+ super();
4
+ this.buffer = new Float32Array();
5
+ this.port.onmessage = (event) => {
6
+ let newFetchedData = new Float32Array(this.buffer.length + event.data.audioData.length);
7
+ newFetchedData.set(this.buffer, 0);
8
+ newFetchedData.set(event.data.audioData, this.buffer.length);
9
+ this.buffer = newFetchedData;
10
+ };
11
+ }
12
+
13
+ process(inputs, outputs, parameters) {
14
+ const output = outputs[0];
15
+ const channel = output[0];
16
+ const bufferLength = this.buffer.length;
17
+ for (let i = 0; i < channel.length; i++) {
18
+ channel[i] = (i < bufferLength) ? this.buffer[i] : 0;
19
+ }
20
+ this.buffer = this.buffer.slice(channel.length);
21
+ return true;
22
+ }
23
+ }
24
+
25
+ class RecordAudioProcessor extends AudioWorkletProcessor {
26
+ constructor() {
27
+ super();
28
+ }
29
+
30
+ process(inputs, outputs, parameters) {
31
+ const channel = inputs[0][0];
32
+ if (!channel || channel.length === 0) {
33
+ return true;
34
+ }
35
+ const int16Array = new Int16Array(channel.length);
36
+ for (let i = 0; i < channel.length; i++) {
37
+ int16Array[i] = channel[i] * 32767;
38
+ }
39
+ this.port.postMessage({ data: int16Array });
40
+ return true
41
+ }
42
+ }
43
+
44
+ registerProcessor('play-audio-processor', PlayerAudioProcessor);
45
+ registerProcessor('record-audio-processor', RecordAudioProcessor);
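The two worklets and the Python side share one PCM convention: little-endian int16 on the wire, float32 in roughly [-1, 1) inside the audio graph. A small numpy illustration of that round trip (numpy is pinned in requirements.txt), shown only to make the scaling explicit:

# Illustration of the int16 <-> float32 convention; not part of this commit.
import numpy as np

float_chunk = np.array([0.0, 0.5, -1.0], dtype=np.float32)

# encode: float32 -> int16 (tts.py's on_process scales by 32768 and clips;
# the RecordAudioProcessor worklet scales by 32767 without clipping)
int16_chunk = np.clip(float_chunk * 32768.0, -32768, 32767).astype(np.int16)

# decode: int16 -> float32 (asr.py's write and app.js's dotts divide by 32768)
restored = int16_chunk.astype(np.float32) / 32768.0

print(int16_chunk.tolist(), restored.tolist())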
index.html ADDED
@@ -0,0 +1,179 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+
4
+ <head>
5
+ <meta charset="UTF-8">
6
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
7
+ <link rel="icon" type="image/svg+xml" href="./voice.png" />
8
+ <script src="//cdn.tailwindcss.com?plugins=forms"></script>
9
+ <link href="https://cdn.jsdelivr.net/npm/tailwindcss@latest/dist/tailwind.min.css" rel="stylesheet">
10
+ <script src="//cdn.jsdelivr.net/npm/[email protected]/dist/cdn.min.js" defer></script>
11
+ <script src="./app.js"></script>
12
+ <title>voiceapi demo </title>
13
+ <style>
14
+ * {
15
+ margin: 0;
16
+ padding: 0;
17
+ }
18
+ </style>
19
+
20
+ <style type="text/tailwindcss">
21
+ .label { @apply text-gray-900 w-[50px] lg:w-20 }
22
+ .title{
23
+ @apply text-[16px] text-zinc-500 mx-2;
24
+ }
25
+
26
+ .select { @apply w-full rounded-md h-10 }
27
+
28
+ .round { @apply rounded border px-3 p-2 border-slate-300 placeholder-gray-400 placeholder:text-sm
29
+ focus:bg-white focus:text-gray-900 focus:placeholder-gray-500 focus:outline-none
30
+ focus:border-zinc-950 focus:border ring-0 focus:ring-0 text-gray-900 }
31
+
32
+ .checkbox { @apply ml-2 lg:ml-4 border focus:outline-none ring-0 focus:ring-gray-800 text-gray-900 }
33
+ .dash{ @apply border border-dashed border-zinc-200 flex flex-grow }
34
+
35
+ .button { @apply hover:bg-opacity-90 text-white font-bold py-1.5 px-6 rounded-full cursor-pointer }
36
+ .card { @apply bg-white shadow-sm rounded-xl border p-4 }
37
+
38
+
39
+ .animate-ping {
40
+ animation: ping 2s cubic-bezier(0.5, 0.4, 0.2, 1) infinite;
41
+ }
42
+
43
+ @keyframes ping {
44
+ 0% {
45
+ transform: scale(1);
46
+ opacity: 1;
47
+ }
48
+ 50% {
49
+ transform: scale(1.2);
50
+ opacity: 0.7;
51
+ }
52
+ 100% {
53
+ transform: scale(1);
54
+ opacity: 1;
55
+ }
56
+ }
57
+ </style>
58
+ </head>
59
+
60
+ <body>
61
+ <script>
62
+ async function initAudioWorklet() {
63
+ try {
64
+ // Check for browser support
65
+ if (!('AudioContext' in window) || !('audioWorklet' in AudioContext.prototype)) {
66
+ console.error('Audio Worklet API is not supported in this browser.');
67
+ return;
68
+ }
69
+
70
+ // Initialize AudioContext
71
+ const audioContext = new AudioContext();
72
+
73
+ // Add Audio Worklet module
74
+ await audioContext.audioWorklet.addModule('./audio_process.js');
75
+
76
+ console.log('Audio Worklet module added successfully.');
77
+ // Your code to use the Audio Worklet goes here
78
+
79
+ } catch (error) {
80
+ console.error('Error initializing Audio Worklet:', error);
81
+ }
82
+ }
83
+
84
+ // Initialize Audio Worklet when the page is loaded
85
+ window.addEventListener('load', initAudioWorklet);
86
+ </script>
87
+ <div x-data="demoapp">
88
+ <header class="bg-gray-900 py-4 px-5 lg:p-4 lg:px-10 text-white sticky top-0 z-20">
89
+ <div class="flex w-full justify-between items-center">
90
+ <p class="gap-x-3">
91
+ <span>VoiceAPI Demo</span> /
92
+ <a href="https://ruzhila.cn/?from=voiceapi_demo">ruzhila.cn</a>
93
+ </p>
94
+ <a target="_blank" href="https://github.com/ruzhila/voiceapi" class="hover:cursor-pointer">
95
+ <svg t="1724996252746" class="icon" viewBox="0 0 1024 1024" version="1.1"
96
+ xmlns="http://www.w3.org/2000/svg" p-id="" width="25" height="25">
97
+ <path
98
+ d="M512 12.64c-282.752 0-512 229.216-512 512 0 226.208 146.72 418.144 350.144 485.824 25.6 4.736 35.008-11.104 35.008-24.64 0-12.192-0.48-52.544-0.704-95.328-142.464 30.976-172.512-60.416-172.512-60.416-23.296-59.168-56.832-74.912-56.832-74.912-46.464-31.776 3.52-31.136 3.52-31.136 51.392 3.616 78.464 52.768 78.464 52.768 45.664 78.272 119.776 55.648 148.992 42.56 4.576-33.088 17.856-55.68 32.512-68.48-113.728-12.928-233.28-56.864-233.28-253.024 0-55.904 20-101.568 52.768-137.44-5.312-12.896-22.848-64.96 4.96-135.488 0 0 43.008-13.76 140.832 52.48 40.832-11.36 84.64-17.024 128.16-17.248 43.488 0.192 87.328 5.888 128.256 17.248 97.728-66.24 140.64-52.48 140.64-52.48 27.872 70.528 10.336 122.592 5.024 135.488 32.832 35.84 52.704 81.536 52.704 137.44 0 196.64-119.776 239.936-233.792 252.64 18.368 15.904 34.72 47.04 34.72 94.816 0 68.512-0.608 123.648-0.608 140.512 0 13.632 9.216 29.6 35.168 24.576 203.328-67.776 349.856-259.616 349.856-485.76 0-282.784-229.248-512-512-512z"
99
+ fill="#ffffff"></path>
100
+ </svg>
101
+ </a>
102
+ </div>
103
+ </header>
104
+
105
+ <div class="flex px-6 gap-x-10 w-full max-w-7xl mx-auto">
106
+ <div class="relative flex flex-col items-center w-1/3 py-10">
107
+ <div class="w-full">
108
+ <textarea x-model="text" class="round p-4 w-full h-[36rem] text-sm"
109
+ placeholder="Enter text here"></textarea>
110
+ </div>
111
+
112
+ <div>
113
+ <button @click="dotts" :disabled="disabled"
114
+ class="button bg-gray-900 flex items-center gap-x-2 mt-6">
115
+ <span>Speak</span>
116
+ <svg t="1726215464577" class="icon" viewBox="0 0 1024 1024" version="1.1"
117
+ xmlns="http://www.w3.org/2000/svg" p-id="4263" width="20" height="20">
118
+ <path
119
+ d="M830.450526 853.759999q-11.722105 8.791579-27.351579 8.791579-19.536842 0-33.701053-14.164211t-14.164211-33.701053q0-21.490526
120
+ 16.606316-36.143158 0.976842-0.976842 1.953684-1.465263t1.953684-1.465263l0.976842-0.976842q27.351579-18.56 50.795789-43.957895t41.027368-55.191579 27.351579-63.494737 9.768421-69.84421q0-73.263158-37.12-133.827368t-92.8-99.637895q-20.513684-14.652632-20.513684-39.073684 0-19.536842 14.164211-33.701053t33.701053-14.164211q16.606316 0 29.305263 10.745263 36.143158 25.397895 67.402105 59.098947t53.726316 73.263158 35.166316 84.496842 12.698947 92.8q0 48.842105-12.698947 93.776842t-35.654737 84.985263-54.214737 73.751579-68.378947 59.098947zM775.747368 415.157894q20.513684 28.328421 32.72421 57.145263t12.210526 69.84421q0 39.073684-12.698947 70.332632t-32.235789 56.656842q-7.814737 10.745263-16.606316 19.048421t-22.467368 8.303158q-17.583158 0-29.793684-12.698947t-12.210526-30.282105q0-7.814737 2.930526-15.629474l-0.976842 0q4.884211-10.745263 11.722105-20.513684t13.187368-20.025263 10.745263-23.444211 4.395789-31.747368q0-17.583158-4.395789-30.770526t-10.745263-23.932632-13.187368-20.513684-10.745263-20.513684q-2.930526-6.837895-2.930526-15.629474 0-17.583158 12.210526-30.282105t29.793684-12.698947q13.675789 0 22.467368 8.303158t16.606316 19.048421zM460.227368 995.402104q-49.818947-44.934737-105.498947-93.776842t-103.545263-89.869474q-55.68-46.888421-111.36-92.8-10.745263 0.976842-21.490526 0.976842-8.791579 0.976842-18.56 0.976842l-16.606316 0q-26.374737 0-42.981053-16.117895t-16.606316-38.585263l0-246.16421 0.976842 0-0.976842-0.976842q0-27.351579 17.094737-44.934737t42.492632-17.583158l55.68 0q89.869474-76.193684 163.132631-136.757895 31.258947-26.374737 61.541053-51.28421t54.703158-45.423158 41.027368-34.189474 20.513684-16.606316q29.305263-21.490526 47.376842-19.536842t28.328421 17.583158 14.164211 38.096842 3.907368 41.027368l0 788.311578 0 2.930526q0 18.56-6.837895 39.562105t-21.002105 33.212632-35.654737 10.256842-49.818947-28.328421z"
121
+ p-id="4264" fill="#ffffff"></path>
122
+ </svg>
123
+ </button>
124
+ </div>
125
+ <template x-if="elapsedTime">
126
+ <p x-text="`elapsedTime: ${elapsedTime}`" class="mt-4 text-sm text-gray-600 "></p>
127
+ </template>
128
+ </div>
129
+
130
+ <!-- recording -->
131
+ <div class="w-full flex-grow h-[calc(100vh-10rem)] xl:pl-10 py-10">
132
+
133
+ <div
134
+ class="rounded border border-gray-500 p-3 w-full flex flex-col items-end h-[36rem] overflow-y-auto">
135
+ <template x-for="item in logs">
136
+ <div class="mt-3 mb-2">
137
+ <span
138
+ class="text-white px-4 py-1.5 text-[13px] display-inline-block border border-gray-900 rounded-t-full rounded-l-full bg-gray-900 justify-end w-auto"
139
+ x-text="item?.text">
140
+ </span>
141
+ </div>
142
+ </template>
143
+ </div>
144
+
145
+
146
+ <template x-if="currentText">
147
+ <p x-text="`${currentText} …`" class="text-gray-800 mt-4 text-sm text-center"></p>
148
+ </template>
149
+
150
+ <template x-if="!recording">
151
+ <div class="flex flex-col gap-y-4 items-center justify-center mt-4">
152
+ <p @click="doasr"
153
+ class="mt-2 border border-gray-100 rounded-full duration-300 hover:scale-105 hover:border-gray-400">
154
+ <img src="./images/record.svg" alt="" class="w-14 h-14 mx-auto cursor-pointer">
155
+ </p>
156
+ <p class="text-gray-600">Click to record !</p>
157
+ </div>
158
+ </template>
159
+
160
+ <template x-if="recording">
161
+ <div class="flex flex-col items-center justify-center gap-y-4 mt-4">
162
+
163
+ <p @click="stopasr"
164
+ class="mt-2 border border-red-100 rounded-full duration-300 hover:scale-105 hover:border-red-400">
165
+ <img src="./images/speaking.svg" alt=""
166
+ class="w-14 h-14 mx-auto cursor-pointer animate-ping">
167
+ </p>
168
+ <div class="flex items-center text-gray-600 gap-x-4">
169
+ <p>Click to stop recording !</p>
170
+ </div>
171
+ </div>
172
+ </template>
173
+ </div>
174
+ </div>
175
+ </div>
176
+ </div>
177
+ </body>
178
+
179
+ </html>
record.svg ADDED
requirements.txt ADDED
@@ -0,0 +1,7 @@
1
+ sherpa-onnx == 1.10.24
2
+ soundfile == 0.12.1
3
+ fastapi == 0.114.1
4
+ uvicorn == 0.30.6
5
+ scipy == 1.13.1
6
+ numpy == 1.26.4
7
+ websockets == 13.0.1
sherpa_examples.py ADDED
@@ -0,0 +1,274 @@
1
+ #!/bin/env python3
2
+ """
3
+ Real-time ASR (microphone) and offline TTS examples using sherpa-onnx
4
+ """
5
+
6
+ import argparse
7
+ import logging
8
+ import sherpa_onnx
9
+ import os
10
+ import time
11
+ import struct
12
+ import asyncio
13
+ import soundfile
14
+
15
+ try:
16
+ import pyaudio
17
+ except ImportError:
18
+ raise ImportError('Please install pyaudio with `pip install pyaudio`')
19
+
20
+ logger = logging.getLogger(__name__)
21
+ SAMPLE_RATE = 16000
22
+
23
+ pactx = pyaudio.PyAudio()
24
+ models_root: str = None
25
+ num_threads: int = 1
26
+
27
+
28
+ def create_zipformer(args) -> sherpa_onnx.OnlineRecognizer:
29
+ d = os.path.join(
30
+ models_root, 'sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20')
31
+ encoder = os.path.join(d, "encoder-epoch-99-avg-1.onnx")
32
+ decoder = os.path.join(d, "decoder-epoch-99-avg-1.onnx")
33
+ joiner = os.path.join(d, "joiner-epoch-99-avg-1.onnx")
34
+ tokens = os.path.join(d, "tokens.txt")
35
+
36
+ recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
37
+ tokens=tokens,
38
+ encoder=encoder,
39
+ decoder=decoder,
40
+ joiner=joiner,
41
+ provider=args.provider,
42
+ num_threads=num_threads,
43
+ sample_rate=SAMPLE_RATE,
44
+ feature_dim=80,
45
+ enable_endpoint_detection=True,
46
+ rule1_min_trailing_silence=2.4,
47
+ rule2_min_trailing_silence=1.2,
48
+ rule3_min_utterance_length=20, # it essentially disables this rule
49
+ )
50
+ return recognizer
51
+
52
+
53
+ def create_sensevoice(args) -> sherpa_onnx.OfflineRecognizer:
54
+ model = os.path.join(
55
+ models_root, 'sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17', 'model.onnx')
56
+ tokens = os.path.join(
57
+ models_root, 'sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17', 'tokens.txt')
58
+ recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
59
+ model=model,
60
+ tokens=tokens,
61
+ num_threads=num_threads,
62
+ use_itn=True,
63
+ debug=0,
64
+ language=args.lang,
65
+ )
66
+ return recognizer
67
+
68
+
69
+ async def run_online(buf, recognizer):
70
+ stream = recognizer.create_stream()
71
+ last_result = ""
72
+ segment_id = 0
73
+ logger.info('Start real-time recognizer')
74
+ while True:
75
+ samples = await buf.get()
76
+ stream.accept_waveform(SAMPLE_RATE, samples)
77
+ while recognizer.is_ready(stream):
78
+ recognizer.decode_stream(stream)
79
+
80
+ is_endpoint = recognizer.is_endpoint(stream)
81
+ result = recognizer.get_result(stream)
82
+
83
+ if result and (last_result != result):
84
+ last_result = result
85
+ logger.info(f' > {segment_id}:{result}')
86
+
87
+ if is_endpoint:
88
+ if result:
89
+ logger.info(f'{segment_id}: {result}')
90
+ segment_id += 1
91
+ recognizer.reset(stream)
92
+
93
+
94
+ async def run_offline(buf, recognizer):
95
+ config = sherpa_onnx.VadModelConfig()
96
+ config.silero_vad.model = os.path.join(
97
+ models_root, 'silero_vad', 'silero_vad.onnx')
98
+ config.silero_vad.min_silence_duration = 0.25
99
+ config.sample_rate = SAMPLE_RATE
100
+ vad = sherpa_onnx.VoiceActivityDetector(
101
+ config, buffer_size_in_seconds=100)
102
+
103
+ logger.info('Start offline recognizer with VAD')
104
+ texts = []
105
+ while True:
106
+ samples = await buf.get()
107
+ vad.accept_waveform(samples)
108
+ while not vad.empty():
109
+ stream = recognizer.create_stream()
110
+ stream.accept_waveform(SAMPLE_RATE, vad.front.samples)
111
+
112
+ vad.pop()
113
+ recognizer.decode_stream(stream)
114
+
115
+ text = stream.result.text.strip().lower()
116
+ if len(text):
117
+ idx = len(texts)
118
+ texts.append(text)
119
+ logger.info(f"{idx}: {text}")
120
+
121
+
122
+ async def handle_asr(args):
123
+ action_func = None
124
+ if args.model == 'zipformer':
125
+ recognizer = create_zipformer(args)
126
+ action_func = run_online
127
+ elif args.model == 'sensevoice':
128
+ recognizer = create_sensevoice(args)
129
+ action_func = run_offline
130
+ else:
131
+ raise ValueError(f'Unknown model: {args.model}')
132
+ buf = asyncio.Queue()
133
+ recorder_task = asyncio.create_task(run_record(buf))
134
+ asr_task = asyncio.create_task(action_func(buf, recognizer))
135
+ await asyncio.gather(asr_task, recorder_task)
136
+
137
+
138
+ async def handle_tts(args):
139
+ model = os.path.join(
140
+ models_root, 'vits-melo-tts-zh_en', 'model.onnx')
141
+ lexicon = os.path.join(
142
+ models_root, 'vits-melo-tts-zh_en', 'lexicon.txt')
143
+ dict_dir = os.path.join(
144
+ models_root, 'vits-melo-tts-zh_en', 'dict')
145
+ tokens = os.path.join(
146
+ models_root, 'vits-melo-tts-zh_en', 'tokens.txt')
147
+ tts_config = sherpa_onnx.OfflineTtsConfig(
148
+ model=sherpa_onnx.OfflineTtsModelConfig(
149
+ vits=sherpa_onnx.OfflineTtsVitsModelConfig(
150
+ model=model,
151
+ lexicon=lexicon,
152
+ dict_dir=dict_dir,
153
+ tokens=tokens,
154
+ ),
155
+ provider=args.provider,
156
+ debug=0,
157
+ num_threads=num_threads,
158
+ ),
159
+ max_num_sentences=args.max_num_sentences,
160
+ )
161
+ if not tts_config.validate():
162
+ raise ValueError("Please check your config")
163
+
164
+ tts = sherpa_onnx.OfflineTts(tts_config)
165
+
166
+ start = time.time()
167
+ audio = tts.generate(args.text, sid=args.sid,
168
+ speed=args.speed)
169
+ elapsed_seconds = time.time() - start
170
+ audio_duration = len(audio.samples) / audio.sample_rate
171
+ real_time_factor = elapsed_seconds / audio_duration
172
+
173
+ if args.output:
174
+ logger.info(f"Saved to {args.output}")
175
+ soundfile.write(
176
+ args.output,
177
+ audio.samples,
178
+ samplerate=audio.sample_rate,
179
+ subtype="PCM_16",
180
+ )
181
+
182
+ logger.info(f"The text is '{args.text}'")
183
+ logger.info(f"Elapsed seconds: {elapsed_seconds:.3f}")
184
+ logger.info(f"Audio duration in seconds: {audio_duration:.3f}")
185
+ logger.info(
186
+ f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}")
187
+
188
+
189
+ async def run_record(buf: asyncio.Queue[list[float]]):
190
+ loop = asyncio.get_event_loop()
191
+
192
+ def on_input(in_data, frame_count, time_info, status):
193
+ samples = [
194
+ v/32768.0 for v in list(struct.unpack('<' + 'h' * frame_count, in_data))]
195
+ loop.call_soon_threadsafe(buf.put_nowait, samples)  # PyAudio runs this callback on its own thread
196
+ return (None, pyaudio.paContinue)
197
+
198
+ frame_size = 320
199
+ recorder = pactx.open(format=pyaudio.paInt16, channels=1,
200
+ rate=SAMPLE_RATE, input=True,
201
+ frames_per_buffer=frame_size,
202
+ stream_callback=on_input)
203
+ recorder.start_stream()
204
+ logger.info('Start recording')
205
+
206
+ while recorder.is_active():
207
+ await asyncio.sleep(0.1)
208
+
209
+
210
+ async def main():
211
+ parser = argparse.ArgumentParser()
212
+ parser.add_argument('--provider', default='cpu',
213
+ help='onnxruntime provider, default is cpu, use cuda for GPU')
214
+
215
+ subparsers = parser.add_subparsers(help='commands help')
216
+
217
+ asr_parser = subparsers.add_parser('asr', help='run asr mode')
218
+ asr_parser.add_argument('--model', default='zipformer',
219
+ help='model name, default is zipformer')
220
+ asr_parser.add_argument('--lang', default='zh',
221
+ help='language, default is zh')
222
+ asr_parser.set_defaults(func=handle_asr)
223
+
224
+ tts_parser = subparsers.add_parser('tts', help='run tts mode')
225
+ tts_parser.add_argument('--sid', type=int, default=0, help="""Speaker ID. Used only for multi-speaker models, e.g.
226
+ models trained using the VCTK dataset. Not used for single-speaker
227
+ models, e.g., models trained using the LJ speech dataset.
228
+ """)
229
+ tts_parser.add_argument('--output', type=str, default='output.wav',
230
+ help='output file name, default is output.wav')
231
+ tts_parser.add_argument(
232
+ "--speed",
233
+ type=float,
234
+ default=1.0,
235
+ help="Speech speed. Larger->faster; smaller->slower",
236
+ )
237
+ tts_parser.add_argument(
238
+ "--max-num-sentences",
239
+ type=int,
240
+ default=2,
241
+ help="""Max number of sentences in a batch to avoid OOM if the input
242
+ text is very long. Set it to -1 to process all the sentences in a
243
+ single batch. A smaller value does not mean it is slower compared
244
+ to a larger one on CPU.
245
+ """,
246
+ )
247
+ tts_parser.add_argument(
248
+ "text",
249
+ type=str,
250
+ help="The input text to generate audio for",
251
+ )
252
+ tts_parser.set_defaults(func=handle_tts)
253
+
254
+ args = parser.parse_args()
255
+
256
+ if hasattr(args, 'func'):
257
+ await args.func(args)
258
+ else:
259
+ parser.print_help()
260
+
261
+ if __name__ == '__main__':
262
+ logging.basicConfig(
263
+ format='%(levelname)s: %(asctime)s %(name)s:%(lineno)s %(message)s')
264
+ logging.getLogger().setLevel(logging.INFO)
265
+ painfo = pactx.get_default_input_device_info()
266
+ assert painfo['maxInputChannels'] >= 1, 'No input device'
267
+ logger.info('Default input device: %s', painfo['name'])
268
+
269
+ for d in ['.', '..', '../..']:
270
+ if os.path.isdir(f'{d}/models'):
271
+ models_root = f'{d}/models'
272
+ break
273
+ assert models_root, 'Could not find models directory'
274
+ asyncio.run(main())
speaking.svg ADDED
tts.py ADDED
@@ -0,0 +1,216 @@
1
+ from typing import *
2
+ import os
3
+ import time
4
+ import sherpa_onnx
5
+ import logging
6
+ import numpy as np
7
+ import asyncio
8
+ import time
9
+ import soundfile
10
+ from scipy.signal import resample
11
+ import io
12
+ import re
13
+
14
+ logger = logging.getLogger(__file__)
15
+
16
+ splitter = re.compile(r'[,,。.!?!?;;、\n]')
17
+ _tts_engines = {}
18
+
19
+ tts_configs = {
20
+ 'vits-zh-hf-theresa': {
21
+ 'model': 'theresa.onnx',
22
+ 'lexicon': 'lexicon.txt',
23
+ 'dict_dir': 'dict',
24
+ 'tokens': 'tokens.txt',
25
+ 'sample_rate': 22050,
26
+ # 'rule_fsts': ['phone.fst', 'date.fst', 'number.fst'],
27
+ },
28
+ 'vits-melo-tts-zh_en': {
29
+ 'model': 'model.onnx',
30
+ 'lexicon': 'lexicon.txt',
31
+ 'dict_dir': 'dict',
32
+ 'tokens': 'tokens.txt',
33
+ 'sample_rate': 44100,
34
+ 'rule_fsts': ['phone.fst', 'date.fst', 'number.fst'],
35
+ },
36
+ }
37
+
38
+
39
+ def load_tts_model(name: str, model_root: str, provider: str, num_threads: int = 1, max_num_sentences: int = 20) -> sherpa_onnx.OfflineTtsConfig:
40
+ cfg = tts_configs[name]
41
+ fsts = []
42
+ model_dir = os.path.join(model_root, name)
43
+ for f in cfg.get('rule_fsts', ''):
44
+ fsts.append(os.path.join(model_dir, f))
45
+ tts_rule_fsts = ','.join(fsts) if fsts else ''
46
+
47
+ model_config = sherpa_onnx.OfflineTtsModelConfig(
48
+ vits=sherpa_onnx.OfflineTtsVitsModelConfig(
49
+ model=os.path.join(model_dir, cfg['model']),
50
+ lexicon=os.path.join(model_dir, cfg['lexicon']),
51
+ dict_dir=os.path.join(model_dir, cfg['dict_dir']),
52
+ tokens=os.path.join(model_dir, cfg['tokens']),
53
+ ),
54
+ provider=provider,
55
+ debug=0,
56
+ num_threads=num_threads,
57
+ )
58
+ tts_config = sherpa_onnx.OfflineTtsConfig(
59
+ model=model_config,
60
+ rule_fsts=tts_rule_fsts,
61
+ max_num_sentences=max_num_sentences)
62
+
63
+ if not tts_config.validate():
64
+ raise ValueError("tts: invalid config")
65
+
66
+ return tts_config
67
+
68
+
69
+ def get_tts_engine(args) -> Tuple[sherpa_onnx.OfflineTts, int]:
70
+ sample_rate = tts_configs[args.tts_model]['sample_rate']
71
+ cache_engine = _tts_engines.get(args.tts_model)
72
+ if cache_engine:
73
+ return cache_engine, sample_rate
74
+ st = time.time()
75
+ tts_config = load_tts_model(
76
+ args.tts_model, args.models_root, args.tts_provider)
77
+
78
+ cache_engine = sherpa_onnx.OfflineTts(tts_config)
79
+ elapsed = time.time() - st
80
+ logger.info(f"tts: loaded {args.tts_model} in {elapsed:.2f}s")
81
+ _tts_engines[args.tts_model] = cache_engine
82
+
83
+ return cache_engine, sample_rate
84
+
85
+
86
+ class TTSResult:
87
+ def __init__(self, pcm_bytes: bytes, finished: bool):
88
+ self.pcm_bytes = pcm_bytes
89
+ self.finished = finished
90
+ self.progress: float = 0.0
91
+ self.elapsed: float = 0.0
92
+ self.audio_duration: float = 0.0
93
+ self.audio_size: int = 0
94
+
95
+ def to_dict(self):
96
+ return {
97
+ "progress": self.progress,
98
+ "elapsed": f'{int(self.elapsed * 1000)}ms',
99
+ "duration": f'{self.audio_duration:.2f}s',
100
+ "size": self.audio_size
101
+ }
102
+
103
+
104
+ class TTSStream:
105
+ def __init__(self, engine, sid: int, speed: float = 1.0, sample_rate: int = 16000, original_sample_rate: int = 16000):
106
+ self.engine = engine
107
+ self.sid = sid
108
+ self.speed = speed
109
+ self.outbuf: asyncio.Queue[TTSResult | None] = asyncio.Queue()
110
+ self.is_closed = False
111
+ self.target_sample_rate = sample_rate
112
+ self.original_sample_rate = original_sample_rate
113
+
114
+ def on_process(self, chunk: np.ndarray, progress: float):
115
+ if self.is_closed:
116
+ return 0
117
+
118
+ # resample to target sample rate
119
+ if self.target_sample_rate != self.original_sample_rate:
120
+ num_samples = int(
121
+ len(chunk) * self.target_sample_rate / self.original_sample_rate)
122
+ resampled_chunk = resample(chunk, num_samples)
123
+ chunk = resampled_chunk.astype(np.float32)
124
+
125
+ scaled_chunk = chunk * 32768.0
126
+ clipped_chunk = np.clip(scaled_chunk, -32768, 32767)
127
+ int16_chunk = clipped_chunk.astype(np.int16)
128
+ samples = int16_chunk.tobytes()
129
+ self.outbuf.put_nowait(TTSResult(samples, False))
130
+ return 0 if self.is_closed else 1  # returning 0 asks sherpa-onnx to stop generating
131
+
132
+ async def write(self, text: str, split: bool, pause: float = 0.2):
133
+ start = time.time()
134
+ if split:
135
+ texts = re.split(splitter, text)
136
+ else:
137
+ texts = [text]
138
+
139
+ audio_duration = 0.0
140
+ audio_size = 0
141
+
142
+ for idx, text in enumerate(texts):
143
+ text = text.strip()
144
+ if not text:
145
+ continue
146
+ sub_start = time.time()
147
+
148
+ audio = await asyncio.to_thread(self.engine.generate,
149
+ text, self.sid, self.speed,
150
+ self.on_process)
151
+
152
+ if not audio or not audio.sample_rate or not audio.samples:
153
+ logger.error(f"tts: failed to generate audio for "
154
+ f"'{text}' (audio={audio})")
155
+ continue
156
+
157
+ if split and idx < len(texts) - 1: # add a pause between sentences
158
+ noise = np.zeros(int(audio.sample_rate * pause))
159
+ self.on_process(noise, 1.0)
160
+ audio.samples = np.concatenate([audio.samples, noise])
161
+
162
+ audio_duration += len(audio.samples) / audio.sample_rate
163
+ audio_size += len(audio.samples)
164
+ elapsed_seconds = time.time() - sub_start
165
+ logger.info(f"tts: generated audio for '{text}', "
166
+ f"audio duration: {audio_duration:.2f}s, "
167
+ f"elapsed: {elapsed_seconds:.2f}s")
168
+
169
+ elapsed_seconds = time.time() - start
170
+ logger.info(f"tts: generated audio in {elapsed_seconds:.2f}s, "
171
+ f"audio duration: {audio_duration:.2f}s")
172
+
173
+ r = TTSResult(None, True)
174
+ r.elapsed = elapsed_seconds
175
+ r.audio_duration = audio_duration
176
+ r.progress = 1.0
177
+ r.finished = True
178
+ await self.outbuf.put(r)
179
+
180
+ async def close(self):
181
+ self.is_closed = True
182
+ self.outbuf.put_nowait(None)
183
+ logger.info("tts: stream closed")
184
+
185
+ async def read(self) -> TTSResult:
186
+ return await self.outbuf.get()
187
+
188
+ async def generate(self, text: str) -> io.BytesIO:
189
+ start = time.time()
190
+ audio = await asyncio.to_thread(self.engine.generate,
191
+ text, self.sid, self.speed)
192
+ elapsed_seconds = time.time() - start
193
+ audio_duration = len(audio.samples) / audio.sample_rate
194
+
195
+ logger.info(f"tts: generated audio in {elapsed_seconds:.2f}s, "
196
+ f"audio duration: {audio_duration:.2f}s, "
197
+ f"sample rate: {audio.sample_rate}")
198
+
199
+ if self.target_sample_rate != audio.sample_rate:
200
+ audio.samples = resample(audio.samples,
201
+ int(len(audio.samples) * self.target_sample_rate / audio.sample_rate))
202
+ audio.sample_rate = self.target_sample_rate
203
+
204
+ output = io.BytesIO()
205
+ soundfile.write(output,
206
+ audio.samples,
207
+ samplerate=audio.sample_rate,
208
+ subtype="PCM_16",
209
+ format="WAV")
210
+ output.seek(0)
211
+ return output
212
+
213
+
214
+ async def start_tts_stream(sid: int, sample_rate: int, speed: float, args) -> TTSStream:
215
+ engine, original_sample_rate = get_tts_engine(args)
216
+ return TTSStream(engine, sid, speed, sample_rate, original_sample_rate)
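get_tts_engine only reads tts_model, models_root and tts_provider from the args object, mirroring app.py's CLI, so TTSStream can also be used one-shot outside the server. A minimal sketch (not part of this commit), assuming the vits-zh-hf-theresa model directory exists under ./models and using a hypothetical output path:

# Hypothetical standalone use of TTSStream.generate; values and paths are assumptions.
import argparse
import asyncio

from voiceapi.tts import start_tts_stream  # import path as used by app.py

async def main():
    args = argparse.Namespace(models_root="./models",
                              tts_model="vits-zh-hf-theresa",
                              tts_provider="cpu")
    stream = await start_tts_stream(sid=0, sample_rate=16000, speed=1.0, args=args)

    wav_io = await stream.generate("你好,世界")   # returns an in-memory WAV (io.BytesIO)
    with open("hello.wav", "wb") as f:
        f.write(wav_io.read())

asyncio.run(main())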
voice.png ADDED