Update api_v2.py
Browse files
api_v2.py
CHANGED
@@ -1,500 +1,432 @@
|
|
|
|
1 |
"""
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
GET:
|
17 |
-
```
|
18 |
-
http://127.0.0.1:9880/tts?text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_lang=zh&ref_audio_path=archive_jingyuan_1.wav&prompt_lang=zh&prompt_text=我是「罗浮」云骑将军景元。不必拘谨,「将军」只是一时的身份,你称呼我景元便可&text_split_method=cut5&batch_size=1&media_type=wav&streaming_mode=true
|
19 |
-
```
|
20 |
-
|
21 |
-
POST:
|
22 |
-
```json
|
23 |
-
{
|
24 |
-
"text": "", # str.(required) text to be synthesized
|
25 |
-
"text_lang: "", # str.(required) language of the text to be synthesized
|
26 |
-
"ref_audio_path": "", # str.(required) reference audio path
|
27 |
-
"aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion
|
28 |
-
"prompt_text": "", # str.(optional) prompt text for the reference audio
|
29 |
-
"prompt_lang": "", # str.(required) language of the prompt text for the reference audio
|
30 |
-
"top_k": 5, # int. top k sampling
|
31 |
-
"top_p": 1, # float. top p sampling
|
32 |
-
"temperature": 1, # float. temperature for sampling
|
33 |
-
"text_split_method": "cut0", # str. text split method, see text_segmentation_method.py for details.
|
34 |
-
"batch_size": 1, # int. batch size for inference
|
35 |
-
"batch_threshold": 0.75, # float. threshold for batch splitting.
|
36 |
-
"split_bucket: True, # bool. whether to split the batch into multiple buckets.
|
37 |
-
"speed_factor":1.0, # float. control the speed of the synthesized audio.
|
38 |
-
"streaming_mode": False, # bool. whether to return a streaming response.
|
39 |
-
"seed": -1, # int. random seed for reproducibility.
|
40 |
-
"parallel_infer": True, # bool. whether to use parallel inference.
|
41 |
-
"repetition_penalty": 1.35 # float. repetition penalty for T2S model.
|
42 |
-
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
|
43 |
-
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
|
44 |
-
}
|
45 |
-
```
|
46 |
-
|
47 |
-
RESP:
|
48 |
-
成功: 直接返回 wav 音频流, http code 200
|
49 |
-
失败: 返回包含错误信息的 json, http code 400
|
50 |
-
|
51 |
-
### 命令控制
|
52 |
-
|
53 |
-
endpoint: `/control`
|
54 |
-
|
55 |
-
command:
|
56 |
-
"restart": 重新运行
|
57 |
-
"exit": 结束运行
|
58 |
-
|
59 |
-
GET:
|
60 |
-
```
|
61 |
-
http://127.0.0.1:9880/control?command=restart
|
62 |
-
```
|
63 |
-
POST:
|
64 |
-
```json
|
65 |
-
{
|
66 |
-
"command": "restart"
|
67 |
-
}
|
68 |
-
```
|
69 |
-
|
70 |
-
RESP: 无
|
71 |
-
|
72 |
-
|
73 |
-
### 切换GPT模型
|
74 |
-
|
75 |
-
endpoint: `/set_gpt_weights`
|
76 |
-
|
77 |
-
GET:
|
78 |
-
```
|
79 |
-
http://127.0.0.1:9880/set_gpt_weights?weights_path=GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
|
80 |
-
```
|
81 |
-
RESP:
|
82 |
-
成功: 返回"success", http code 200
|
83 |
-
失败: 返回包含错误信息的 json, http code 400
|
84 |
-
|
85 |
-
|
86 |
-
### 切换Sovits模型
|
87 |
-
|
88 |
-
endpoint: `/set_sovits_weights`
|
89 |
-
|
90 |
-
GET:
|
91 |
-
```
|
92 |
-
http://127.0.0.1:9880/set_sovits_weights?weights_path=GPT_SoVITS/pretrained_models/s2G488k.pth
|
93 |
-
```
|
94 |
-
|
95 |
-
RESP:
|
96 |
-
成功: 返回"success", http code 200
|
97 |
-
失败: 返回包含错误信息的 json, http code 400
|
98 |
-
|
99 |
"""
|
100 |
|
|
|
|
|
|
|
101 |
import os
|
|
|
|
|
102 |
import sys
|
103 |
import traceback
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
sys.path.append(now_dir)
|
108 |
-
sys.path.append("%s/GPT_SoVITS" % (now_dir))
|
109 |
|
110 |
-
import argparse
|
111 |
-
import subprocess
|
112 |
-
import wave
|
113 |
-
import signal
|
114 |
import numpy as np
|
|
|
115 |
import soundfile as sf
|
116 |
-
|
117 |
-
from fastapi.responses import StreamingResponse, JSONResponse
|
118 |
import uvicorn
|
119 |
-
from
|
120 |
-
from
|
121 |
-
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
|
122 |
-
from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names
|
123 |
from pydantic import BaseModel
|
124 |
|
125 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
i18n = I18nAuto()
|
127 |
cut_method_names = get_cut_method_names()
|
128 |
|
129 |
-
parser = argparse.ArgumentParser(description="GPT
|
130 |
-
parser.add_argument(
|
131 |
-
|
132 |
-
|
|
|
|
|
133 |
args = parser.parse_args()
|
134 |
-
config_path = args.tts_config
|
135 |
-
# device = args.device
|
136 |
-
port = args.port
|
137 |
-
host = args.bind_addr
|
138 |
-
argv = sys.argv
|
139 |
|
140 |
-
|
141 |
-
|
|
|
|
|
|
|
|
|
|
|
142 |
|
143 |
tts_config = TTS_Config(config_path)
|
144 |
print(tts_config)
|
145 |
-
|
146 |
|
147 |
APP = FastAPI()
|
148 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
149 |
|
150 |
-
class TTS_Request(BaseModel):
|
151 |
-
text: str = None
|
152 |
-
text_lang: str = None
|
153 |
-
ref_audio_path: str = None
|
154 |
-
aux_ref_audio_paths: list = None
|
155 |
-
prompt_lang: str = None
|
156 |
-
prompt_text: str = ""
|
157 |
-
top_k: int = 5
|
158 |
-
top_p: float = 1
|
159 |
-
temperature: float = 1
|
160 |
-
text_split_method: str = "cut5"
|
161 |
-
batch_size: int = 1
|
162 |
-
batch_threshold: float = 0.75
|
163 |
-
split_bucket: bool = True
|
164 |
-
speed_factor: float = 1.0
|
165 |
-
fragment_interval: float = 0.3
|
166 |
-
seed: int = -1
|
167 |
-
media_type: str = "wav"
|
168 |
-
streaming_mode: bool = False
|
169 |
-
parallel_infer: bool = True
|
170 |
-
repetition_penalty: float = 1.35
|
171 |
-
sample_steps: int = 32
|
172 |
-
super_sampling: bool = False
|
173 |
|
|
|
|
|
|
|
174 |
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
return io_buffer
|
180 |
|
181 |
|
182 |
-
def
|
183 |
-
|
184 |
-
return
|
185 |
|
186 |
|
187 |
-
def
|
188 |
-
|
189 |
-
|
190 |
-
return io_buffer
|
191 |
|
192 |
|
193 |
-
def
|
194 |
-
|
195 |
[
|
196 |
"ffmpeg",
|
197 |
"-f",
|
198 |
-
"s16le",
|
199 |
"-ar",
|
200 |
-
str(rate),
|
201 |
"-ac",
|
202 |
-
"1",
|
203 |
"-i",
|
204 |
-
"pipe:0",
|
205 |
"-c:a",
|
206 |
-
"aac",
|
207 |
"-b:a",
|
208 |
-
"192k",
|
209 |
-
"-vn",
|
210 |
"-f",
|
211 |
-
"adts",
|
212 |
-
"pipe:1",
|
213 |
],
|
214 |
stdin=subprocess.PIPE,
|
215 |
stdout=subprocess.PIPE,
|
216 |
stderr=subprocess.PIPE,
|
217 |
)
|
218 |
-
out, _ =
|
219 |
-
|
220 |
-
return
|
221 |
-
|
222 |
-
|
223 |
-
def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str):
|
224 |
-
if media_type == "ogg":
|
225 |
-
io_buffer = pack_ogg(io_buffer, data, rate)
|
226 |
-
elif media_type == "aac":
|
227 |
-
io_buffer = pack_aac(io_buffer, data, rate)
|
228 |
-
elif media_type == "wav":
|
229 |
-
io_buffer = pack_wav(io_buffer, data, rate)
|
230 |
-
else:
|
231 |
-
io_buffer = pack_raw(io_buffer, data, rate)
|
232 |
-
io_buffer.seek(0)
|
233 |
-
return io_buffer
|
234 |
|
235 |
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
vfout.writeframes(frame_input)
|
247 |
|
248 |
-
wav_buf.seek(0)
|
249 |
-
return wav_buf.read()
|
250 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
251 |
|
252 |
-
def handle_control(command: str):
|
253 |
-
if command == "restart":
|
254 |
-
os.execl(sys.executable, sys.executable, *argv)
|
255 |
-
elif command == "exit":
|
256 |
-
os.kill(os.getpid(), signal.SIGTERM)
|
257 |
-
exit(0)
|
258 |
-
|
259 |
-
|
260 |
-
def check_params(req: dict):
|
261 |
-
text: str = req.get("text", "")
|
262 |
-
text_lang: str = req.get("text_lang", "")
|
263 |
-
ref_audio_path: str = req.get("ref_audio_path", "")
|
264 |
-
streaming_mode: bool = req.get("streaming_mode", False)
|
265 |
-
media_type: str = req.get("media_type", "wav")
|
266 |
-
prompt_lang: str = req.get("prompt_lang", "")
|
267 |
-
text_split_method: str = req.get("text_split_method", "cut5")
|
268 |
-
|
269 |
-
if ref_audio_path in [None, ""]:
|
270 |
-
return JSONResponse(status_code=400, content={"message": "ref_audio_path is required"})
|
271 |
-
if text in [None, ""]:
|
272 |
-
return JSONResponse(status_code=400, content={"message": "text is required"})
|
273 |
-
if text_lang in [None, ""]:
|
274 |
-
return JSONResponse(status_code=400, content={"message": "text_lang is required"})
|
275 |
-
elif text_lang.lower() not in tts_config.languages:
|
276 |
-
return JSONResponse(
|
277 |
-
status_code=400,
|
278 |
-
content={"message": f"text_lang: {text_lang} is not supported in version {tts_config.version}"},
|
279 |
-
)
|
280 |
-
if prompt_lang in [None, ""]:
|
281 |
-
return JSONResponse(status_code=400, content={"message": "prompt_lang is required"})
|
282 |
-
elif prompt_lang.lower() not in tts_config.languages:
|
283 |
-
return JSONResponse(
|
284 |
-
status_code=400,
|
285 |
-
content={"message": f"prompt_lang: {prompt_lang} is not supported in version {tts_config.version}"},
|
286 |
-
)
|
287 |
-
if media_type not in ["wav", "raw", "ogg", "aac"]:
|
288 |
-
return JSONResponse(status_code=400, content={"message": f"media_type: {media_type} is not supported"})
|
289 |
-
elif media_type == "ogg" and not streaming_mode:
|
290 |
-
return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"})
|
291 |
-
|
292 |
-
if text_split_method not in cut_method_names:
|
293 |
-
return JSONResponse(
|
294 |
-
status_code=400, content={"message": f"text_split_method:{text_split_method} is not supported"}
|
295 |
-
)
|
296 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
297 |
return None
|
298 |
|
299 |
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
"text_lang: "", # str.(required) language of the text to be synthesized
|
309 |
-
"ref_audio_path": "", # str.(required) reference audio path
|
310 |
-
"aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker synthesis
|
311 |
-
"prompt_text": "", # str.(optional) prompt text for the reference audio
|
312 |
-
"prompt_lang": "", # str.(required) language of the prompt text for the reference audio
|
313 |
-
"top_k": 5, # int. top k sampling
|
314 |
-
"top_p": 1, # float. top p sampling
|
315 |
-
"temperature": 1, # float. temperature for sampling
|
316 |
-
"text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details.
|
317 |
-
"batch_size": 1, # int. batch size for inference
|
318 |
-
"batch_threshold": 0.75, # float. threshold for batch splitting.
|
319 |
-
"split_bucket: True, # bool. whether to split the batch into multiple buckets.
|
320 |
-
"speed_factor":1.0, # float. control the speed of the synthesized audio.
|
321 |
-
"fragment_interval":0.3, # float. to control the interval of the audio fragment.
|
322 |
-
"seed": -1, # int. random seed for reproducibility.
|
323 |
-
"media_type": "wav", # str. media type of the output audio, support "wav", "raw", "ogg", "aac".
|
324 |
-
"streaming_mode": False, # bool. whether to return a streaming response.
|
325 |
-
"parallel_infer": True, # bool.(optional) whether to use parallel inference.
|
326 |
-
"repetition_penalty": 1.35 # float.(optional) repetition penalty for T2S model.
|
327 |
-
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
|
328 |
-
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
|
329 |
-
}
|
330 |
-
returns:
|
331 |
-
StreamingResponse: audio stream response.
|
332 |
-
"""
|
333 |
|
334 |
streaming_mode = req.get("streaming_mode", False)
|
335 |
-
return_fragment = req.get("return_fragment", False)
|
336 |
media_type = req.get("media_type", "wav")
|
337 |
|
338 |
-
|
339 |
-
if check_res is not None:
|
340 |
-
return check_res
|
341 |
-
|
342 |
-
if streaming_mode or return_fragment:
|
343 |
-
req["return_fragment"] = True
|
344 |
-
|
345 |
try:
|
346 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
347 |
|
348 |
if streaming_mode:
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
return JSONResponse(status_code=400, content={"message": "tts failed", "Exception": str(
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
382 |
|
383 |
@APP.get("/tts")
|
384 |
-
async def
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
prompt_text: str = "",
|
391 |
-
top_k: int = 5,
|
392 |
-
top_p: float = 1,
|
393 |
-
temperature: float = 1,
|
394 |
-
text_split_method: str = "cut0",
|
395 |
-
batch_size: int = 1,
|
396 |
-
batch_threshold: float = 0.75,
|
397 |
-
split_bucket: bool = True,
|
398 |
-
speed_factor: float = 1.0,
|
399 |
-
fragment_interval: float = 0.3,
|
400 |
-
seed: int = -1,
|
401 |
-
media_type: str = "wav",
|
402 |
-
streaming_mode: bool = False,
|
403 |
-
parallel_infer: bool = True,
|
404 |
-
repetition_penalty: float = 1.35,
|
405 |
-
sample_steps: int = 32,
|
406 |
-
super_sampling: bool = False,
|
407 |
-
):
|
408 |
-
req = {
|
409 |
-
"text": text,
|
410 |
-
"text_lang": text_lang.lower(),
|
411 |
-
"ref_audio_path": ref_audio_path,
|
412 |
-
"aux_ref_audio_paths": aux_ref_audio_paths,
|
413 |
-
"prompt_text": prompt_text,
|
414 |
-
"prompt_lang": prompt_lang.lower(),
|
415 |
-
"top_k": top_k,
|
416 |
-
"top_p": top_p,
|
417 |
-
"temperature": temperature,
|
418 |
-
"text_split_method": text_split_method,
|
419 |
-
"batch_size": int(batch_size),
|
420 |
-
"batch_threshold": float(batch_threshold),
|
421 |
-
"speed_factor": float(speed_factor),
|
422 |
-
"split_bucket": split_bucket,
|
423 |
-
"fragment_interval": fragment_interval,
|
424 |
-
"seed": seed,
|
425 |
-
"media_type": media_type,
|
426 |
-
"streaming_mode": streaming_mode,
|
427 |
-
"parallel_infer": parallel_infer,
|
428 |
-
"repetition_penalty": float(repetition_penalty),
|
429 |
-
"sample_steps": int(sample_steps),
|
430 |
-
"super_sampling": super_sampling,
|
431 |
-
}
|
432 |
-
return await tts_handle(req)
|
433 |
|
434 |
|
435 |
@APP.post("/tts")
|
436 |
-
async def
|
437 |
-
|
438 |
-
|
|
|
|
|
|
|
|
|
439 |
|
440 |
|
441 |
-
@APP.get("/
|
442 |
-
async def
|
443 |
-
|
444 |
-
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
|
|
|
|
|
|
|
449 |
|
450 |
-
# @APP.post("/set_refer_audio")
|
451 |
-
# async def set_refer_aduio_post(audio_file: UploadFile = File(...)):
|
452 |
-
# try:
|
453 |
-
# # 检查文件类型,确保是音频文件
|
454 |
-
# if not audio_file.content_type.startswith("audio/"):
|
455 |
-
# return JSONResponse(status_code=400, content={"message": "file type is not supported"})
|
456 |
|
457 |
-
|
458 |
-
|
459 |
-
|
460 |
-
|
461 |
-
# buffer.write(await audio_file.read())
|
462 |
|
463 |
-
|
464 |
-
|
465 |
-
|
466 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
467 |
|
468 |
|
469 |
@APP.get("/set_gpt_weights")
|
470 |
-
async def set_gpt_weights(weights_path: str = None):
|
|
|
|
|
471 |
try:
|
472 |
-
|
473 |
-
|
474 |
-
|
475 |
-
|
476 |
-
return JSONResponse(status_code=400, content={"message": "change gpt weight failed", "Exception": str(e)})
|
477 |
-
|
478 |
-
return JSONResponse(status_code=200, content={"message": "success"})
|
479 |
|
480 |
|
481 |
@APP.get("/set_sovits_weights")
|
482 |
-
async def set_sovits_weights(weights_path: str = None):
|
|
|
|
|
483 |
try:
|
484 |
-
|
485 |
-
|
486 |
-
|
487 |
-
|
488 |
-
|
489 |
-
return JSONResponse(status_code=200, content={"message": "success"})
|
490 |
|
|
|
|
|
|
|
491 |
|
492 |
if __name__ == "__main__":
|
493 |
try:
|
494 |
-
|
495 |
-
|
496 |
-
uvicorn.run(app=APP, host=host, port=port, workers=1)
|
497 |
-
except Exception:
|
498 |
traceback.print_exc()
|
499 |
os.kill(os.getpid(), signal.SIGTERM)
|
500 |
-
exit(0)
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
"""
|
3 |
+
Updated FastAPI backend for GPT-SoVITS (*April 2025*)
|
4 |
+
---------------------------------------------------
|
5 |
+
Changes compared with the previous version shipped on 30 Apr 2025
|
6 |
+
=================================================================
|
7 |
+
1. **URL / S3 audio support** — `process_audio_path()` downloads `ref_audio_path` and
|
8 |
+
each entry in `aux_ref_audio_paths` when they are HTTP(S) or S3 URLs, storing them
|
9 |
+
as temporary files that are cleaned up afterwards.
|
10 |
+
2. **CUDA memory hygiene** — `torch.cuda.empty_cache()` is invoked after each request
|
11 |
+
(success *or* error) to release GPU memory.
|
12 |
+
3. **Temporary‑file cleanup** — all files created by `process_audio_path()` are
|
13 |
+
removed in `finally` blocks so they are guaranteed to disappear no matter how the
|
14 |
+
request terminates.
|
15 |
+
|
16 |
+
The public API surface (**end‑points and query parameters**) is unchanged.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
"""
|
18 |
|
19 |
+
from __future__ import annotations
|
20 |
+
|
21 |
+
import argparse
|
22 |
import os
|
23 |
+
import signal
|
24 |
+
import subprocess
|
25 |
import sys
|
26 |
import traceback
|
27 |
+
import urllib.parse
|
28 |
+
from io import BytesIO
|
29 |
+
from typing import Generator, List, Tuple
|
|
|
|
|
30 |
|
|
|
|
|
|
|
|
|
31 |
import numpy as np
|
32 |
+
import requests
|
33 |
import soundfile as sf
|
34 |
+
import torch
|
|
|
35 |
import uvicorn
|
36 |
+
from fastapi import FastAPI, HTTPException, Response
|
37 |
+
from fastapi.responses import JSONResponse, StreamingResponse
|
|
|
|
|
38 |
from pydantic import BaseModel
|
39 |
|
40 |
+
# ---------------------------------------------------------------------------
|
41 |
+
# Local package imports – keep *after* sys.path manipulation so relative import
|
42 |
+
# resolution continues to work when this file is executed from any directory.
|
43 |
+
# ---------------------------------------------------------------------------
|
44 |
+
|
45 |
+
NOW_DIR = os.getcwd()
|
46 |
+
sys.path.extend([NOW_DIR, f"{NOW_DIR}/GPT_SoVITS"])
|
47 |
+
|
48 |
+
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config # noqa: E402
|
49 |
+
from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import ( # noqa: E402
|
50 |
+
get_method_names as get_cut_method_names,
|
51 |
+
)
|
52 |
+
from tools.i18n.i18n import I18nAuto # noqa: E402
|
53 |
+
|
54 |
+
# ---------------------------------------------------------------------------
|
55 |
+
# CLI arguments & global objects
|
56 |
+
# ---------------------------------------------------------------------------
|
57 |
+
|
58 |
i18n = I18nAuto()
|
59 |
cut_method_names = get_cut_method_names()
|
60 |
|
61 |
+
parser = argparse.ArgumentParser(description="GPT‑SoVITS API")
|
62 |
+
parser.add_argument(
|
63 |
+
"-c", "--tts_config", default="GPT_SoVITS/configs/tts_infer.yaml", help="TTS‑infer config path"
|
64 |
+
)
|
65 |
+
parser.add_argument("-a", "--bind_addr", default="127.0.0.1", help="Bind address (default 127.0.0.1)")
|
66 |
+
parser.add_argument("-p", "--port", type=int, default=9880, help="Port (default 9880)")
|
67 |
args = parser.parse_args()
|
|
|
|
|
|
|
|
|
|
|
68 |
|
69 |
+
config_path = args.tts_config or "GPT-SoVITS/configs/tts_infer.yaml"
|
70 |
+
PORT = args.port
|
71 |
+
HOST = None if args.bind_addr == "None" else args.bind_addr
|
72 |
+
|
73 |
+
# ---------------------------------------------------------------------------
|
74 |
+
# TTS initialisation
|
75 |
+
# ---------------------------------------------------------------------------
|
76 |
|
77 |
tts_config = TTS_Config(config_path)
|
78 |
print(tts_config)
|
79 |
+
TTS_PIPELINE = TTS(tts_config)
|
80 |
|
81 |
APP = FastAPI()
|
82 |
|
83 |
+
# ---------------------------------------------------------------------------
|
84 |
+
# Helper utilities
|
85 |
+
# ---------------------------------------------------------------------------
|
86 |
+
|
87 |
+
TEMP_DIR = os.path.join(NOW_DIR, "_tmp_audio")
|
88 |
+
os.makedirs(TEMP_DIR, exist_ok=True)
|
89 |
+
|
90 |
+
def _empty_cuda_cache() -> None:
|
91 |
+
"""Release GPU memory if CUDA is available."""
|
92 |
+
if torch.cuda.is_available():
|
93 |
+
torch.cuda.empty_cache()
|
94 |
+
|
95 |
+
|
96 |
+
def _download_to_temp(url: str) -> str:
    """Fetch *url* (HTTP(S) or ``s3://``) into ``TEMP_DIR``; return the local path.

    The file name is taken from the URL path; when the path has no basename a
    hash-derived ``audio_<n>.wav`` name is used instead.
    """
    parts = urllib.parse.urlparse(url)
    name = os.path.basename(parts.path)
    if not name:
        name = f"audio_{abs(hash(url))}.wav"
    target = os.path.join(TEMP_DIR, name)

    if url.startswith("s3://"):
        # boto3 is loaded lazily so it is only required when S3 URLs are used.
        import importlib

        s3_client = importlib.import_module("boto3").client("s3")  # pylint: disable=import-error
        s3_client.download_file(parts.netloc, parts.path.lstrip("/"), target)
        return target

    with requests.get(url, stream=True, timeout=30) as resp:
        resp.raise_for_status()
        with open(target, "wb") as handle:
            for block in resp.iter_content(chunk_size=8192):
                handle.write(block)
    return target
|
117 |
+
|
118 |
+
|
119 |
+
def process_audio_path(audio_path: str | None) -> Tuple[str | None, bool]:
    """Resolve *audio_path* to a local file.

    Returns ``(local_path, is_temporary)``.  Remote HTTP(S)/S3 inputs are
    downloaded to a temp file (``is_temporary=True``); empty values and local
    paths are passed through untouched.
    """
    if not audio_path:
        return audio_path, False

    is_remote = audio_path.startswith(("http://", "https://", "s3://"))
    if not is_remote:
        return audio_path, False

    try:
        return _download_to_temp(audio_path), True
    except Exception as exc:  # pragma: no‑cover
        raise HTTPException(status_code=400, detail=f"Failed to download audio: {exc}") from exc
|
131 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
|
133 |
+
# ---------------------------------------------------------------------------
|
134 |
+
# Audio (de)serialisation helpers
|
135 |
+
# ---------------------------------------------------------------------------
|
136 |
|
137 |
+
def _pack_ogg(buf: BytesIO, data: np.ndarray, rate: int):
    """Encode mono *data* at *rate* Hz into *buf* as OGG and return *buf*."""
    writer = sf.SoundFile(buf, mode="w", samplerate=rate, channels=1, format="ogg")
    with writer as out:
        out.write(data)
    return buf
|
|
|
141 |
|
142 |
|
143 |
+
def _pack_raw(buf: BytesIO, data: np.ndarray, _rate: int):
|
144 |
+
buf.write(data.tobytes())
|
145 |
+
return buf
|
146 |
|
147 |
|
148 |
+
def _pack_wav(buf: BytesIO, data: np.ndarray, rate: int):
    """Serialise *data* into *buf* as a complete WAV file at *rate* Hz.

    Delegates to ``soundfile.write``; *buf* is returned for chaining.
    """
    sf.write(buf, data, rate, format="wav")
    return buf
|
|
|
151 |
|
152 |
|
153 |
+
def _pack_aac(buf: BytesIO, data: np.ndarray, rate: int):
    """Encode raw PCM *data* to ADTS/AAC by piping through an ffmpeg subprocess.

    Input is interpreted as signed 16-bit little-endian mono at *rate* Hz;
    the encoded stream is appended to *buf*, which is returned.
    """
    cmd = ["ffmpeg"]
    cmd += ["-f", "s16le"]           # stdin format: raw signed 16-bit LE PCM
    cmd += ["-ar", str(rate)]        # input sample rate
    cmd += ["-ac", "1"]              # mono
    cmd += ["-i", "pipe:0"]          # read from stdin
    cmd += ["-c:a", "aac"]
    cmd += ["-b:a", "192k"]
    cmd += ["-vn"]                   # audio only
    cmd += ["-f", "adts"]            # raw AAC-in-ADTS container
    cmd += ["-pipe:1"[1:]]           # write to stdout ("pipe:1")
    proc = subprocess.Popen(
        cmd,
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    encoded, _ = proc.communicate(input=data.tobytes())
    buf.write(encoded)
    return buf
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
181 |
|
182 |
|
183 |
+
def _pack_audio(buf: BytesIO, data: np.ndarray, rate: int, media_type: str):
    """Serialise *data* into *buf* in *media_type* format and rewind *buf*.

    Unknown media types fall back to raw PCM bytes.
    """
    if media_type == "ogg":
        buf = _pack_ogg(buf, data, rate)
    elif media_type == "aac":
        buf = _pack_aac(buf, data, rate)
    elif media_type == "wav":
        buf = _pack_wav(buf, data, rate)
    else:
        # "raw" and anything unrecognised: bare PCM.
        buf = _pack_raw(buf, data, rate)
    buf.seek(0)
    return buf
|
|
|
193 |
|
|
|
|
|
194 |
|
195 |
+
# ---------------------------------------------------------------------------
|
196 |
+
# Schemas
|
197 |
+
# ---------------------------------------------------------------------------
|
198 |
+
|
199 |
+
class TTSRequest(BaseModel):
    """Body schema for ``POST /tts``.

    Mirrors the query parameters of ``GET /tts``.  Required fields default to
    ``None`` here; actual presence/validity is enforced by the request handler.
    """

    # Required (validated downstream): text to synthesise and its language.
    text: str | None = None
    text_lang: str | None = None
    # Required: reference audio that defines the target voice (path or URL).
    ref_audio_path: str | None = None
    # Optional extra reference clips (multi-speaker tone fusion).
    aux_ref_audio_paths: List[str] | None = None
    # Required: language of the prompt text for the reference audio.
    prompt_lang: str | None = None
    prompt_text: str = ""
    # Sampling controls for the T2S model.
    top_k: int = 5
    top_p: float = 1.0
    temperature: float = 1.0
    # Text segmentation method; see text_segmentation_method.py for names.
    text_split_method: str = "cut5"
    batch_size: int = 1
    batch_threshold: float = 0.75  # threshold for batch splitting
    split_bucket: bool = True
    speed_factor: float = 1.0  # playback-speed control of the synthesised audio
    fragment_interval: float = 0.3  # gap between audio fragments
    seed: int = -1  # presumably -1 means "unseeded/random" — confirm
    media_type: str = "wav"  # one of "wav", "raw", "ogg", "aac"
    streaming_mode: bool = False
    parallel_infer: bool = True
    repetition_penalty: float = 1.35  # repetition penalty for the T2S model
    sample_steps: int = 32  # sampling steps for VITS model V3
    super_sampling: bool = False  # audio super-sampling for VITS model V3
|
222 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
223 |
|
224 |
+
# ---------------------------------------------------------------------------
|
225 |
+
# Validation helpers
|
226 |
+
# ---------------------------------------------------------------------------
|
227 |
+
|
228 |
+
def _validate_request(req: dict):
    """Validate a /tts request dict.

    Returns a human-readable message for the first failing check, or ``None``
    when the request is acceptable.  Order matters: required fields first,
    then language support, then output-format rules.
    """
    text = req.get("text")
    text_lang = req.get("text_lang")
    prompt_lang = req.get("prompt_lang")

    if not text:
        return "text is required"
    if not text_lang:
        return "text_lang is required"
    if text_lang.lower() not in tts_config.languages:
        return f"text_lang {req['text_lang']} not supported"
    if not prompt_lang:
        return "prompt_lang is required"
    if prompt_lang.lower() not in tts_config.languages:
        return f"prompt_lang {req['prompt_lang']} not supported"
    if not req.get("ref_audio_path"):
        return "ref_audio_path is required"

    media = req.get("media_type", "wav")
    if media not in {"wav", "raw", "ogg", "aac"}:
        return f"media_type {media} not supported"
    if media == "ogg" and not req.get("streaming_mode"):
        return "ogg is only supported in streaming mode"

    if req.get("text_split_method", "cut5") not in cut_method_names:
        return f"text_split_method {req['text_split_method']} not supported"
    return None
|
249 |
|
250 |
|
251 |
+
# ---------------------------------------------------------------------------
|
252 |
+
# Core handler
|
253 |
+
# ---------------------------------------------------------------------------
|
254 |
+
|
255 |
+
async def _tts_handle(req: dict):
    """Core handler shared by GET and POST /tts.

    Validates *req*, resolves remote reference-audio paths to local temp
    files, runs the TTS pipeline, and returns either a streaming response
    (chunked audio) or a single complete audio payload.  All temp files
    created here are removed via ``_cleanup`` on every exit path.
    """
    error = _validate_request(req)
    if error:
        return JSONResponse(status_code=400, content={"message": error})

    streaming_mode = req.get("streaming_mode", False)
    media_type = req.get("media_type", "wav")

    # Every temp file downloaded for this request; cleaned up on all paths.
    temp_files: List[str] = []
    try:
        # --- resolve & download audio paths ----------------------------------
        ref_path, is_tmp = process_audio_path(req["ref_audio_path"])
        req["ref_audio_path"] = ref_path
        if is_tmp:
            temp_files.append(ref_path)

        if req.get("aux_ref_audio_paths"):
            resolved: List[str] = []
            for p in req["aux_ref_audio_paths"]:
                lp, tmp = process_audio_path(p)
                resolved.append(lp)
                if tmp:
                    temp_files.append(lp)
            req["aux_ref_audio_paths"] = resolved

        # --- run inference ----------------------------------------------------
        # NOTE(review): TTS_PIPELINE.run appears to return a sync generator of
        # (sample_rate, chunk) pairs; iterating it inside the async generator
        # below blocks the event loop while each chunk is produced — confirm.
        generator = TTS_PIPELINE.run(req)

        if streaming_mode:

            async def _gen(gen: Generator, _media_type: str):
                first = True
                try:
                    for sr, chunk in gen:
                        if first and _media_type == "wav":
                            # Prepend a WAV header so clients can play immediately.
                            header = _wave_header_chunk(sample_rate=sr)
                            yield header
                            # Subsequent chunks are raw PCM appended after the header.
                            _media_type = "raw"
                        first = False
                        yield _pack_audio(BytesIO(), chunk, sr, _media_type).getvalue()
                finally:
                    # Runs when the stream finishes OR the client disconnects.
                    _cleanup(temp_files)

            return StreamingResponse(_gen(generator, media_type), media_type=f"audio/{media_type}")

        # non‑streaming --------------------------------------------------------
        # Only the first yielded item is used in non-streaming mode.
        sr, data = next(generator)
        payload = _pack_audio(BytesIO(), data, sr, media_type).getvalue()
        resp = Response(payload, media_type=f"audio/{media_type}")
        _cleanup(temp_files)
        return resp

    except Exception as exc:  # noqa: BLE001
        _cleanup(temp_files)
        return JSONResponse(status_code=400, content={"message": "tts failed", "Exception": str(exc)})
|
309 |
+
|
310 |
+
|
311 |
+
# ---------------------------------------------------------------------------
|
312 |
+
# Cleanup helpers
|
313 |
+
# ---------------------------------------------------------------------------
|
314 |
+
|
315 |
+
def _cleanup(temp_files: List[str]):
    """Best-effort removal of downloaded temp files, then release CUDA memory.

    Missing files are ignored; any other removal error is logged and skipped
    so cleanup never raises into the request path.
    """
    for path in temp_files:
        try:
            os.remove(path)
            # print(f"[cleanup] removed {path}")
        except FileNotFoundError:
            continue
        except Exception as exc:  # pragma: no‑cover
            print(f"[cleanup‑warning] {exc}")
    _empty_cuda_cache()
|
325 |
+
|
326 |
+
|
327 |
+
# ---------------------------------------------------------------------------
|
328 |
+
# WAV header helper (for streaming WAV)
|
329 |
+
# ---------------------------------------------------------------------------
|
330 |
+
|
331 |
+
import wave # placed here to keep top import section tidy
|
332 |
+
|
333 |
+
def _wave_header_chunk(frame: bytes = b"", *, channels: int = 1, width: int = 2, sample_rate: int = 32_000):
|
334 |
+
buf = BytesIO()
|
335 |
+
with wave.open(buf, "wb") as wav:
|
336 |
+
wav.setnchannels(channels)
|
337 |
+
wav.setsampwidth(width)
|
338 |
+
wav.setframerate(sample_rate)
|
339 |
+
wav.writeframes(frame)
|
340 |
+
buf.seek(0)
|
341 |
+
return buf.read()
|
342 |
+
|
343 |
+
|
344 |
+
# ---------------------------------------------------------------------------
|
345 |
+
# End‑points
|
346 |
+
# ---------------------------------------------------------------------------
|
347 |
|
348 |
@APP.get("/tts")
async def tts_get(
    text: str = None,
    text_lang: str = None,
    ref_audio_path: str = None,
    aux_ref_audio_paths: list = None,
    prompt_lang: str = None,
    prompt_text: str = "",
    top_k: int = 5,
    top_p: float = 1.0,
    temperature: float = 1.0,
    text_split_method: str = "cut5",
    batch_size: int = 1,
    batch_threshold: float = 0.75,
    split_bucket: bool = True,
    speed_factor: float = 1.0,
    fragment_interval: float = 0.3,
    seed: int = -1,
    media_type: str = "wav",
    streaming_mode: bool = False,
    parallel_infer: bool = True,
    repetition_penalty: float = 1.35,
    sample_steps: int = 32,
    super_sampling: bool = False,
):
    """GET /tts — synthesise speech from query parameters.

    Bug fix: FastAPI only injects query parameters that are declared
    explicitly in the signature — the previous ``**query`` catch-all received
    an empty dict, so every GET request failed validation.  Parameters and
    defaults mirror the ``TTSRequest`` model used by POST /tts.
    """
    req = {
        "text": text,
        # Normalise language codes to lower-case, guarding against None.
        "text_lang": text_lang.lower() if text_lang else text_lang,
        "ref_audio_path": ref_audio_path,
        "aux_ref_audio_paths": aux_ref_audio_paths,
        "prompt_text": prompt_text,
        "prompt_lang": prompt_lang.lower() if prompt_lang else prompt_lang,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "text_split_method": text_split_method,
        "batch_size": int(batch_size),
        "batch_threshold": float(batch_threshold),
        "split_bucket": split_bucket,
        "speed_factor": float(speed_factor),
        "fragment_interval": fragment_interval,
        "seed": seed,
        "media_type": media_type,
        "streaming_mode": streaming_mode,
        "parallel_infer": parallel_infer,
        "repetition_penalty": float(repetition_penalty),
        "sample_steps": int(sample_steps),
        "super_sampling": super_sampling,
    }
    return await _tts_handle(req)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
355 |
|
356 |
|
357 |
@APP.post("/tts")
async def tts_post(request: TTSRequest):
    """POST /tts — delegate the parsed request body to the shared handler.

    Language codes are lower-cased before validation, matching the GET route.
    """
    payload = request.dict()
    for key in ("text_lang", "prompt_lang"):
        value = payload.get(key)
        if value:
            payload[key] = value.lower()
    return await _tts_handle(payload)
|
365 |
|
366 |
|
367 |
+
@APP.get("/control")
async def control(command: str | None = None):
    """Process-control endpoint.

    ``restart`` re-execs the interpreter with the same argv; ``exit`` sends
    SIGTERM to this process.  Any other value is rejected with HTTP 400.
    """
    if not command:
        raise HTTPException(status_code=400, detail="command is required")
    if command == "exit":
        os.kill(os.getpid(), signal.SIGTERM)
    elif command == "restart":
        # Replace the current process image with a fresh interpreter.
        os.execl(sys.executable, sys.executable, *sys.argv)
    else:
        raise HTTPException(status_code=400, detail="unsupported command")
    # Only reachable for "exit" before the signal is delivered.
    return {"message": "ok"}
|
378 |
|
|
|
|
|
|
|
|
|
|
|
|
|
379 |
|
380 |
+
@APP.get("/set_refer_audio")
async def set_refer_audio(refer_audio_path: str | None = None):
    """Point the TTS pipeline at a new reference audio (local path or URL).

    Remote inputs are downloaded to a temp file which is always removed in
    the ``finally`` block, and CUDA memory is released on every exit path.
    """
    if not refer_audio_path:
        return JSONResponse(status_code=400, content={"message": "refer_audio_path is required"})

    temp_file = None
    try:
        local_path, is_tmp = process_audio_path(refer_audio_path)
        if is_tmp:
            temp_file = local_path
        TTS_PIPELINE.set_ref_audio(local_path)
        return {"message": "success"}
    finally:
        if temp_file:
            try:
                os.remove(temp_file)
            except FileNotFoundError:
                pass
        _empty_cuda_cache()
|
398 |
|
399 |
|
400 |
@APP.get("/set_gpt_weights")
async def set_gpt_weights(weights_path: str | None = None):
    """Hot-swap the GPT (text-to-semantic) checkpoint used by the pipeline."""
    if not weights_path:
        return JSONResponse(status_code=400, content={"message": "gpt weight path is required"})
    try:
        TTS_PIPELINE.init_t2s_weights(weights_path)
    except Exception as exc:  # noqa: BLE001
        return JSONResponse(status_code=400, content={"message": str(exc)})
    return {"message": "success"}
|
|
|
|
|
|
|
409 |
|
410 |
|
411 |
@APP.get("/set_sovits_weights")
async def set_sovits_weights(weights_path: str | None = None):
    """Hot-swap the SoVITS (vocoder) checkpoint used by the pipeline."""
    if not weights_path:
        return JSONResponse(status_code=400, content={"message": "sovits weight path is required"})
    try:
        TTS_PIPELINE.init_vits_weights(weights_path)
    except Exception as exc:  # noqa: BLE001
        return JSONResponse(status_code=400, content={"message": str(exc)})
    return {"message": "success"}
|
420 |
+
|
|
|
421 |
|
422 |
+
# ---------------------------------------------------------------------------
|
423 |
+
# Main entry point
|
424 |
+
# ---------------------------------------------------------------------------
|
425 |
|
426 |
if __name__ == "__main__":
    try:
        # Single uvicorn worker — NOTE(review): presumably because the module
        # holds a process-global TTS_PIPELINE (model on GPU) that cannot be
        # shared across forked workers; confirm before raising this.
        uvicorn.run(app=APP, host=HOST, port=PORT, workers=1)
    except Exception:  # pragma: no‑cover
        traceback.print_exc()
        # Ensure the process actually terminates even if uvicorn left
        # non-daemon threads running.
        os.kill(os.getpid(), signal.SIGTERM)
        sys.exit(0)
|