kevinwang676 committed on
Commit
8a8224c
·
verified ·
1 Parent(s): dc3921c

Update api_v2.py

Files changed (1)
  1. api_v2.py +346 -414
api_v2.py CHANGED
@@ -1,500 +1,432 @@
 
1
  """
2
- # WebAPI documentation
3
-
4
- ` python api_v2.py -a 127.0.0.1 -p 9880 -c GPT_SoVITS/configs/tts_infer.yaml `
5
-
6
- ## Command-line arguments:
7
- `-a` - `bind address, default "127.0.0.1"`
8
- `-p` - `bind port, default 9880`
9
- `-c` - `TTS config file path, default "GPT_SoVITS/configs/tts_infer.yaml"`
10
-
11
- ## Usage:
12
-
13
- ### Inference
14
-
15
- endpoint: `/tts`
16
- GET:
17
- ```
18
- http://127.0.0.1:9880/tts?text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_lang=zh&ref_audio_path=archive_jingyuan_1.wav&prompt_lang=zh&prompt_text=我是「罗浮」云骑将军景元。不必拘谨,「将军」只是一时的身份,你称呼我景元便可&text_split_method=cut5&batch_size=1&media_type=wav&streaming_mode=true
19
- ```
20
-
21
- POST:
22
- ```json
23
- {
24
- "text": "", # str.(required) text to be synthesized
25
- "text_lang": "", # str.(required) language of the text to be synthesized
26
- "ref_audio_path": "", # str.(required) reference audio path
27
- "aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion
28
- "prompt_text": "", # str.(optional) prompt text for the reference audio
29
- "prompt_lang": "", # str.(required) language of the prompt text for the reference audio
30
- "top_k": 5, # int. top k sampling
31
- "top_p": 1, # float. top p sampling
32
- "temperature": 1, # float. temperature for sampling
33
- "text_split_method": "cut0", # str. text split method, see text_segmentation_method.py for details.
34
- "batch_size": 1, # int. batch size for inference
35
- "batch_threshold": 0.75, # float. threshold for batch splitting.
36
- "split_bucket": True, # bool. whether to split the batch into multiple buckets.
37
- "speed_factor":1.0, # float. control the speed of the synthesized audio.
38
- "streaming_mode": False, # bool. whether to return a streaming response.
39
- "seed": -1, # int. random seed for reproducibility.
40
- "parallel_infer": True, # bool. whether to use parallel inference.
41
- "repetition_penalty": 1.35, # float. repetition penalty for T2S model.
42
- "sample_steps": 32, # int. number of sampling steps for VITS model V3.
43
- "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
44
- }
45
- ```
46
-
47
- RESP:
48
- Success: returns the wav audio stream directly, HTTP code 200
49
- Failure: returns JSON with the error message, HTTP code 400
50
-
51
- ### Command control
52
-
53
- endpoint: `/control`
54
-
55
- command:
56
- "restart": restart the service
57
- "exit": shut down the service
58
-
59
- GET:
60
- ```
61
- http://127.0.0.1:9880/control?command=restart
62
- ```
63
- POST:
64
- ```json
65
- {
66
- "command": "restart"
67
- }
68
- ```
69
-
70
- RESP: none
71
-
72
-
73
- ### Switch GPT model
74
-
75
- endpoint: `/set_gpt_weights`
76
-
77
- GET:
78
- ```
79
- http://127.0.0.1:9880/set_gpt_weights?weights_path=GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
80
- ```
81
- RESP:
82
- Success: returns "success", HTTP code 200
83
- Failure: returns JSON with the error message, HTTP code 400
84
-
85
-
86
- ### Switch SoVITS model
87
-
88
- endpoint: `/set_sovits_weights`
89
-
90
- GET:
91
- ```
92
- http://127.0.0.1:9880/set_sovits_weights?weights_path=GPT_SoVITS/pretrained_models/s2G488k.pth
93
- ```
94
-
95
- RESP:
96
- Success: returns "success", HTTP code 200
97
- Failure: returns JSON with the error message, HTTP code 400
98
-
99
  """
100
 
 
 
 
101
  import os
 
 
102
  import sys
103
  import traceback
104
- from typing import Generator
105
-
106
- now_dir = os.getcwd()
107
- sys.path.append(now_dir)
108
- sys.path.append("%s/GPT_SoVITS" % (now_dir))
109
 
110
- import argparse
111
- import subprocess
112
- import wave
113
- import signal
114
  import numpy as np
 
115
  import soundfile as sf
116
- from fastapi import FastAPI, Response
117
- from fastapi.responses import StreamingResponse, JSONResponse
118
  import uvicorn
119
- from io import BytesIO
120
- from tools.i18n.i18n import I18nAuto
121
- from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
122
- from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names
123
  from pydantic import BaseModel
124
 
125
- # print(sys.path)
 
126
  i18n = I18nAuto()
127
  cut_method_names = get_cut_method_names()
128
 
129
- parser = argparse.ArgumentParser(description="GPT-SoVITS api")
130
- parser.add_argument("-c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="path to the tts_infer config")
131
- parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1")
132
- parser.add_argument("-p", "--port", type=int, default="9880", help="default: 9880")
 
 
133
  args = parser.parse_args()
134
- config_path = args.tts_config
135
- # device = args.device
136
- port = args.port
137
- host = args.bind_addr
138
- argv = sys.argv
139
 
140
- if config_path in [None, ""]:
141
- config_path = "GPT-SoVITS/configs/tts_infer.yaml"
 
142
 
143
  tts_config = TTS_Config(config_path)
144
  print(tts_config)
145
- tts_pipeline = TTS(tts_config)
146
 
147
  APP = FastAPI()
148
 
 
149
 
150
- class TTS_Request(BaseModel):
151
- text: str = None
152
- text_lang: str = None
153
- ref_audio_path: str = None
154
- aux_ref_audio_paths: list = None
155
- prompt_lang: str = None
156
- prompt_text: str = ""
157
- top_k: int = 5
158
- top_p: float = 1
159
- temperature: float = 1
160
- text_split_method: str = "cut5"
161
- batch_size: int = 1
162
- batch_threshold: float = 0.75
163
- split_bucket: bool = True
164
- speed_factor: float = 1.0
165
- fragment_interval: float = 0.3
166
- seed: int = -1
167
- media_type: str = "wav"
168
- streaming_mode: bool = False
169
- parallel_infer: bool = True
170
- repetition_penalty: float = 1.35
171
- sample_steps: int = 32
172
- super_sampling: bool = False
173
 
 
 
 
174
 
175
- ### modify from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files
176
- def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int):
177
- with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file:
178
- audio_file.write(data)
179
- return io_buffer
180
 
181
 
182
- def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int):
183
- io_buffer.write(data.tobytes())
184
- return io_buffer
185
 
186
 
187
- def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int):
188
- io_buffer = BytesIO()
189
- sf.write(io_buffer, data, rate, format="wav")
190
- return io_buffer
191
 
192
 
193
- def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int):
194
- process = subprocess.Popen(
195
  [
196
  "ffmpeg",
197
  "-f",
198
- "s16le", # 输入16位有符号小端整数PCM
199
  "-ar",
200
- str(rate), # 设置采样率
201
  "-ac",
202
- "1", # 单声道
203
  "-i",
204
- "pipe:0", # 从管道读取输入
205
  "-c:a",
206
- "aac", # 音频编码器为AAC
207
  "-b:a",
208
- "192k", # 比特率
209
- "-vn", # 不包含视频
210
  "-f",
211
- "adts", # 输出AAC数据流格式
212
- "pipe:1", # 将输出写入管道
213
  ],
214
  stdin=subprocess.PIPE,
215
  stdout=subprocess.PIPE,
216
  stderr=subprocess.PIPE,
217
  )
218
- out, _ = process.communicate(input=data.tobytes())
219
- io_buffer.write(out)
220
- return io_buffer
221
-
222
-
223
- def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str):
224
- if media_type == "ogg":
225
- io_buffer = pack_ogg(io_buffer, data, rate)
226
- elif media_type == "aac":
227
- io_buffer = pack_aac(io_buffer, data, rate)
228
- elif media_type == "wav":
229
- io_buffer = pack_wav(io_buffer, data, rate)
230
- else:
231
- io_buffer = pack_raw(io_buffer, data, rate)
232
- io_buffer.seek(0)
233
- return io_buffer
234
 
235
 
236
- # from https://huggingface.co/spaces/coqui/voice-chat-with-mistral/blob/main/app.py
237
- def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000):
238
- # This will create a wave header then append the frame input
239
- # It should be first on a streaming wav file
240
- # Other frames better should not have it (else you will hear some artifacts each chunk start)
241
- wav_buf = BytesIO()
242
- with wave.open(wav_buf, "wb") as vfout:
243
- vfout.setnchannels(channels)
244
- vfout.setsampwidth(sample_width)
245
- vfout.setframerate(sample_rate)
246
- vfout.writeframes(frame_input)
247
 
248
- wav_buf.seek(0)
249
- return wav_buf.read()
250
 
 
251
 
252
- def handle_control(command: str):
253
- if command == "restart":
254
- os.execl(sys.executable, sys.executable, *argv)
255
- elif command == "exit":
256
- os.kill(os.getpid(), signal.SIGTERM)
257
- exit(0)
258
-
259
-
260
- def check_params(req: dict):
261
- text: str = req.get("text", "")
262
- text_lang: str = req.get("text_lang", "")
263
- ref_audio_path: str = req.get("ref_audio_path", "")
264
- streaming_mode: bool = req.get("streaming_mode", False)
265
- media_type: str = req.get("media_type", "wav")
266
- prompt_lang: str = req.get("prompt_lang", "")
267
- text_split_method: str = req.get("text_split_method", "cut5")
268
-
269
- if ref_audio_path in [None, ""]:
270
- return JSONResponse(status_code=400, content={"message": "ref_audio_path is required"})
271
- if text in [None, ""]:
272
- return JSONResponse(status_code=400, content={"message": "text is required"})
273
- if text_lang in [None, ""]:
274
- return JSONResponse(status_code=400, content={"message": "text_lang is required"})
275
- elif text_lang.lower() not in tts_config.languages:
276
- return JSONResponse(
277
- status_code=400,
278
- content={"message": f"text_lang: {text_lang} is not supported in version {tts_config.version}"},
279
- )
280
- if prompt_lang in [None, ""]:
281
- return JSONResponse(status_code=400, content={"message": "prompt_lang is required"})
282
- elif prompt_lang.lower() not in tts_config.languages:
283
- return JSONResponse(
284
- status_code=400,
285
- content={"message": f"prompt_lang: {prompt_lang} is not supported in version {tts_config.version}"},
286
- )
287
- if media_type not in ["wav", "raw", "ogg", "aac"]:
288
- return JSONResponse(status_code=400, content={"message": f"media_type: {media_type} is not supported"})
289
- elif media_type == "ogg" and not streaming_mode:
290
- return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"})
291
-
292
- if text_split_method not in cut_method_names:
293
- return JSONResponse(
294
- status_code=400, content={"message": f"text_split_method:{text_split_method} is not supported"}
295
- )
296
 
 
297
  return None
298
 
299
 
300
- async def tts_handle(req: dict):
301
- """
302
- Text to speech handler.
303
-
304
- Args:
305
- req (dict):
306
- {
307
- "text": "", # str.(required) text to be synthesized
308
- "text_lang": "", # str.(required) language of the text to be synthesized
309
- "ref_audio_path": "", # str.(required) reference audio path
310
- "aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker synthesis
311
- "prompt_text": "", # str.(optional) prompt text for the reference audio
312
- "prompt_lang": "", # str.(required) language of the prompt text for the reference audio
313
- "top_k": 5, # int. top k sampling
314
- "top_p": 1, # float. top p sampling
315
- "temperature": 1, # float. temperature for sampling
316
- "text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details.
317
- "batch_size": 1, # int. batch size for inference
318
- "batch_threshold": 0.75, # float. threshold for batch splitting.
319
- "split_bucket": True, # bool. whether to split the batch into multiple buckets.
320
- "speed_factor":1.0, # float. control the speed of the synthesized audio.
321
- "fragment_interval":0.3, # float. to control the interval of the audio fragment.
322
- "seed": -1, # int. random seed for reproducibility.
323
- "media_type": "wav", # str. media type of the output audio, support "wav", "raw", "ogg", "aac".
324
- "streaming_mode": False, # bool. whether to return a streaming response.
325
- "parallel_infer": True, # bool.(optional) whether to use parallel inference.
326
- "repetition_penalty": 1.35, # float.(optional) repetition penalty for T2S model.
327
- "sample_steps": 32, # int. number of sampling steps for VITS model V3.
328
- "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
329
- }
330
- returns:
331
- StreamingResponse: audio stream response.
332
- """
333
 
334
  streaming_mode = req.get("streaming_mode", False)
335
- return_fragment = req.get("return_fragment", False)
336
  media_type = req.get("media_type", "wav")
337
 
338
- check_res = check_params(req)
339
- if check_res is not None:
340
- return check_res
341
-
342
- if streaming_mode or return_fragment:
343
- req["return_fragment"] = True
344
-
345
  try:
346
- tts_generator = tts_pipeline.run(req)
 
347
 
348
  if streaming_mode:
349
-
350
- def streaming_generator(tts_generator: Generator, media_type: str):
351
- if_frist_chunk = True
352
- for sr, chunk in tts_generator:
353
- if if_frist_chunk and media_type == "wav":
354
- yield wave_header_chunk(sample_rate=sr)
355
- media_type = "raw"
356
- if_frist_chunk = False
357
- yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue()
358
-
359
- # _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}"
360
- return StreamingResponse(
361
- streaming_generator(
362
- tts_generator,
363
- media_type,
364
- ),
365
- media_type=f"audio/{media_type}",
366
- )
367
-
368
- else:
369
- sr, audio_data = next(tts_generator)
370
- audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
371
- return Response(audio_data, media_type=f"audio/{media_type}")
372
- except Exception as e:
373
- return JSONResponse(status_code=400, content={"message": "tts failed", "Exception": str(e)})
374
-
375
-
376
- @APP.get("/control")
377
- async def control(command: str = None):
378
- if command is None:
379
- return JSONResponse(status_code=400, content={"message": "command is required"})
380
- handle_control(command)
381
-
 
382
 
383
  @APP.get("/tts")
384
- async def tts_get_endpoint(
385
- text: str = None,
386
- text_lang: str = None,
387
- ref_audio_path: str = None,
388
- aux_ref_audio_paths: list = None,
389
- prompt_lang: str = None,
390
- prompt_text: str = "",
391
- top_k: int = 5,
392
- top_p: float = 1,
393
- temperature: float = 1,
394
- text_split_method: str = "cut0",
395
- batch_size: int = 1,
396
- batch_threshold: float = 0.75,
397
- split_bucket: bool = True,
398
- speed_factor: float = 1.0,
399
- fragment_interval: float = 0.3,
400
- seed: int = -1,
401
- media_type: str = "wav",
402
- streaming_mode: bool = False,
403
- parallel_infer: bool = True,
404
- repetition_penalty: float = 1.35,
405
- sample_steps: int = 32,
406
- super_sampling: bool = False,
407
- ):
408
- req = {
409
- "text": text,
410
- "text_lang": text_lang.lower(),
411
- "ref_audio_path": ref_audio_path,
412
- "aux_ref_audio_paths": aux_ref_audio_paths,
413
- "prompt_text": prompt_text,
414
- "prompt_lang": prompt_lang.lower(),
415
- "top_k": top_k,
416
- "top_p": top_p,
417
- "temperature": temperature,
418
- "text_split_method": text_split_method,
419
- "batch_size": int(batch_size),
420
- "batch_threshold": float(batch_threshold),
421
- "speed_factor": float(speed_factor),
422
- "split_bucket": split_bucket,
423
- "fragment_interval": fragment_interval,
424
- "seed": seed,
425
- "media_type": media_type,
426
- "streaming_mode": streaming_mode,
427
- "parallel_infer": parallel_infer,
428
- "repetition_penalty": float(repetition_penalty),
429
- "sample_steps": int(sample_steps),
430
- "super_sampling": super_sampling,
431
- }
432
- return await tts_handle(req)
433
 
434
 
435
  @APP.post("/tts")
436
- async def tts_post_endpoint(request: TTS_Request):
437
- req = request.dict()
438
- return await tts_handle(req)
 
 
 
 
439
 
440
 
441
- @APP.get("/set_refer_audio")
442
- async def set_refer_aduio(refer_audio_path: str = None):
443
- try:
444
- tts_pipeline.set_ref_audio(refer_audio_path)
445
- except Exception as e:
446
- return JSONResponse(status_code=400, content={"message": "set refer audio failed", "Exception": str(e)})
447
- return JSONResponse(status_code=200, content={"message": "success"})
448
-
 
 
 
449
 
450
- # @APP.post("/set_refer_audio")
451
- # async def set_refer_aduio_post(audio_file: UploadFile = File(...)):
452
- # try:
453
- # # Check the file type to make sure it is an audio file
454
- # if not audio_file.content_type.startswith("audio/"):
455
- # return JSONResponse(status_code=400, content={"message": "file type is not supported"})
456
 
457
- # os.makedirs("uploaded_audio", exist_ok=True)
458
- # save_path = os.path.join("uploaded_audio", audio_file.filename)
459
- # # Save the audio file to a directory on the server
460
- # with open(save_path , "wb") as buffer:
461
- # buffer.write(await audio_file.read())
462
 
463
- # tts_pipeline.set_ref_audio(save_path)
464
- # except Exception as e:
465
- # return JSONResponse(status_code=400, content={"message": f"set refer audio failed", "Exception": str(e)})
466
- # return JSONResponse(status_code=200, content={"message": "success"})
 
467
 
468
 
469
  @APP.get("/set_gpt_weights")
470
- async def set_gpt_weights(weights_path: str = None):
 
 
471
  try:
472
- if weights_path in ["", None]:
473
- return JSONResponse(status_code=400, content={"message": "gpt weight path is required"})
474
- tts_pipeline.init_t2s_weights(weights_path)
475
- except Exception as e:
476
- return JSONResponse(status_code=400, content={"message": "change gpt weight failed", "Exception": str(e)})
477
-
478
- return JSONResponse(status_code=200, content={"message": "success"})
479
 
480
 
481
  @APP.get("/set_sovits_weights")
482
- async def set_sovits_weights(weights_path: str = None):
 
 
483
  try:
484
- if weights_path in ["", None]:
485
- return JSONResponse(status_code=400, content={"message": "sovits weight path is required"})
486
- tts_pipeline.init_vits_weights(weights_path)
487
- except Exception as e:
488
- return JSONResponse(status_code=400, content={"message": "change sovits weight failed", "Exception": str(e)})
489
- return JSONResponse(status_code=200, content={"message": "success"})
490
 
 
 
 
491
 
492
  if __name__ == "__main__":
493
  try:
494
- if host == "None": # passing "-a None" on the command line lets the API listen on a dual stack (IPv4 and IPv6)
495
- host = None
496
- uvicorn.run(app=APP, host=host, port=port, workers=1)
497
- except Exception:
498
  traceback.print_exc()
499
  os.kill(os.getpid(), signal.SIGTERM)
500
- exit(0)
 
1
+ # -*- coding: utf-8 -*-
2
  """
3
+ Updated FastAPI backend for GPT-SoVITS (*April 2025*)
4
+ ---------------------------------------------------
5
+ Changes compared with the previous version shipped on 30 Apr 2025
6
+ =================================================================
7
+ 1. **URL / S3 audio support** — `process_audio_path()` downloads `ref_audio_path` and
8
+ each entry in `aux_ref_audio_paths` when they are HTTP(S) or S3 URLs, storing them
9
+ as temporary files that are cleaned up afterwards.
10
+ 2. **CUDA memory hygiene** — `torch.cuda.empty_cache()` is invoked after each request
11
+ (success *or* error) to release GPU memory.
12
+ 3. **Temporary-file cleanup** — all files created by `process_audio_path()` are
13
+ removed in `finally` blocks so they are guaranteed to disappear no matter how the
14
+ request terminates.
15
+
16
+ The public API surface (**endpoints and query parameters**) is unchanged.
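+
+ A typical call is therefore unchanged as well; the only difference is that the
+ reference audio may now be a remote URL. Minimal illustrative sketch (the URL
+ below is a placeholder, and the server is assumed to run on the default host/port):
+
+     import requests
+
+     payload = {
+         "text": "Hello from the updated API.",
+         "text_lang": "en",
+         "ref_audio_path": "https://example.com/reference.wav",
+         "prompt_text": "Transcript of the reference audio.",
+         "prompt_lang": "en",
+     }
+     audio_bytes = requests.post("http://127.0.0.1:9880/tts", json=payload).content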
 
17
  """
18
 
19
+ from __future__ import annotations
20
+
21
+ import argparse
22
  import os
23
+ import signal
24
+ import subprocess
25
  import sys
26
  import traceback
27
+ import urllib.parse
28
+ from io import BytesIO
29
+ from typing import Generator, List, Tuple
 
 
30
 
 
 
 
 
31
  import numpy as np
32
+ import requests
33
  import soundfile as sf
34
+ import torch
 
35
  import uvicorn
36
+ from fastapi import FastAPI, HTTPException, Request, Response
37
+ from fastapi.responses import JSONResponse, StreamingResponse
 
 
38
  from pydantic import BaseModel
39
 
40
+ # ---------------------------------------------------------------------------
41
+ # Local package imports – keep *after* sys.path manipulation so relative import
42
+ # resolution continues to work when this file is executed from any directory.
43
+ # ---------------------------------------------------------------------------
44
+
45
+ NOW_DIR = os.getcwd()
46
+ sys.path.extend([NOW_DIR, f"{NOW_DIR}/GPT_SoVITS"])
47
+
48
+ from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config # noqa: E402
49
+ from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import ( # noqa: E402
50
+ get_method_names as get_cut_method_names,
51
+ )
52
+ from tools.i18n.i18n import I18nAuto # noqa: E402
53
+
54
+ # ---------------------------------------------------------------------------
55
+ # CLI arguments & global objects
56
+ # ---------------------------------------------------------------------------
57
+
58
  i18n = I18nAuto()
59
  cut_method_names = get_cut_method_names()
60
 
61
+ parser = argparse.ArgumentParser(description="GPT-SoVITS API")
62
+ parser.add_argument(
63
+ "-c", "--tts_config", default="GPT_SoVITS/configs/tts_infer.yaml", help="TTS inference config path"
64
+ )
65
+ parser.add_argument("-a", "--bind_addr", default="127.0.0.1", help="Bind address (default 127.0.0.1)")
66
+ parser.add_argument("-p", "--port", type=int, default=9880, help="Port (default 9880)")
67
  args = parser.parse_args()
 
68
 
69
+ config_path = args.tts_config or "GPT_SoVITS/configs/tts_infer.yaml"
70
+ PORT = args.port
71
+ HOST = None if args.bind_addr == "None" else args.bind_addr
72
+
73
+ # ---------------------------------------------------------------------------
74
+ # TTS initialisation
75
+ # ---------------------------------------------------------------------------
76
 
77
  tts_config = TTS_Config(config_path)
78
  print(tts_config)
79
+ TTS_PIPELINE = TTS(tts_config)
80
 
81
  APP = FastAPI()
82
 
83
+ # ---------------------------------------------------------------------------
84
+ # Helper utilities
85
+ # ---------------------------------------------------------------------------
86
+
87
+ TEMP_DIR = os.path.join(NOW_DIR, "_tmp_audio")
88
+ os.makedirs(TEMP_DIR, exist_ok=True)
89
+
90
+ def _empty_cuda_cache() -> None:
91
+ """Release GPU memory if CUDA is available."""
92
+ if torch.cuda.is_available():
93
+ torch.cuda.empty_cache()
94
+
95
+
96
+ def _download_to_temp(url: str) -> str:
97
+ """Download *url* to a unique file inside ``TEMP_DIR`` and return the path."""
98
+ parsed = urllib.parse.urlparse(url)
99
+ filename = os.path.basename(parsed.path) or f"audio_{abs(hash(url))}.wav"
100
+ local_path = os.path.join(TEMP_DIR, filename)
101
+
102
+ if url.startswith("s3://"):
103
+ # Lazy-load boto3 if/when the first S3 request arrives.
104
+ import importlib
105
+
106
+ boto3 = importlib.import_module("boto3") # pylint: disable=import-error
107
+ s3_client = boto3.client("s3")
108
+ s3_client.download_file(parsed.netloc, parsed.path.lstrip("/"), local_path)
109
+ else:
110
+ with requests.get(url, stream=True, timeout=30) as r:
111
+ r.raise_for_status()
112
+ with open(local_path, "wb") as f_out:
113
+ for chunk in r.iter_content(chunk_size=8192):
114
+ f_out.write(chunk)
115
+
116
+ return local_path
117
+
118
+
119
+ def process_audio_path(audio_path: str | None) -> Tuple[str | None, bool]:
120
+ """Return a *local* path for *audio_path* and whether it is temporary."""
121
+ if not audio_path:
122
+ return audio_path, False
123
+
124
+ if audio_path.startswith(("http://", "https://", "s3://")):
125
+ try:
126
+ local = _download_to_temp(audio_path)
127
+ return local, True
128
+ except Exception as exc: # pragma: no cover
129
+ raise HTTPException(status_code=400, detail=f"Failed to download audio: {exc}") from exc
130
+ return audio_path, False
131
 
 
132
 
133
+ # ---------------------------------------------------------------------------
134
+ # Audio (de)serialisation helpers
135
+ # ---------------------------------------------------------------------------
136
 
137
+ def _pack_ogg(buf: BytesIO, data: np.ndarray, rate: int):
138
+ with sf.SoundFile(buf, mode="w", samplerate=rate, channels=1, format="ogg") as f:
139
+ f.write(data)
140
+ return buf
 
141
 
142
 
143
+ def _pack_raw(buf: BytesIO, data: np.ndarray, _rate: int):
144
+ buf.write(data.tobytes())
145
+ return buf
146
 
147
 
148
+ def _pack_wav(buf: BytesIO, data: np.ndarray, rate: int):
149
+ sf.write(buf, data, rate, format="wav")
150
+ return buf
 
151
 
152
 
153
+ def _pack_aac(buf: BytesIO, data: np.ndarray, rate: int):
154
+ proc = subprocess.Popen(
155
  [
156
  "ffmpeg",
157
  "-f",
158
+ "s16le",
159
  "-ar",
160
+ str(rate),
161
  "-ac",
162
+ "1",
163
  "-i",
164
+ "pipe:0",
165
  "-c:a",
166
+ "aac",
167
  "-b:a",
168
+ "192k",
169
+ "-vn",
170
  "-f",
171
+ "adts",
172
+ "pipe:1",
173
  ],
174
  stdin=subprocess.PIPE,
175
  stdout=subprocess.PIPE,
176
  stderr=subprocess.PIPE,
177
  )
178
+ out, _ = proc.communicate(input=data.tobytes())
179
+ buf.write(out)
180
+ return buf
 
181
 
182
 
183
+ def _pack_audio(buf: BytesIO, data: np.ndarray, rate: int, media_type: str):
184
+ dispatch = {
185
+ "ogg": _pack_ogg,
186
+ "aac": _pack_aac,
187
+ "wav": _pack_wav,
188
+ "raw": _pack_raw,
189
+ }
190
+ buf = dispatch.get(media_type, _pack_raw)(buf, data, rate)
191
+ buf.seek(0)
192
+ return buf
 
193
 
 
 
194
 
195
+ # ---------------------------------------------------------------------------
196
+ # Schemas
197
+ # ---------------------------------------------------------------------------
198
+
199
+ class TTSRequest(BaseModel):
200
+ text: str | None = None
201
+ text_lang: str | None = None
202
+ ref_audio_path: str | None = None
203
+ aux_ref_audio_paths: List[str] | None = None
204
+ prompt_lang: str | None = None
205
+ prompt_text: str = ""
206
+ top_k: int = 5
207
+ top_p: float = 1.0
208
+ temperature: float = 1.0
209
+ text_split_method: str = "cut5"
210
+ batch_size: int = 1
211
+ batch_threshold: float = 0.75
212
+ split_bucket: bool = True
213
+ speed_factor: float = 1.0
214
+ fragment_interval: float = 0.3
215
+ seed: int = -1
216
+ media_type: str = "wav"
217
+ streaming_mode: bool = False
218
+ parallel_infer: bool = True
219
+ repetition_penalty: float = 1.35
220
+ sample_steps: int = 32
221
+ super_sampling: bool = False
222
 
 
223
 
224
+ # ---------------------------------------------------------------------------
225
+ # Validation helpers
226
+ # ---------------------------------------------------------------------------
227
+
228
+ def _validate_request(req: dict):
229
+ if not req.get("text"):
230
+ return "text is required"
231
+ if not req.get("text_lang"):
232
+ return "text_lang is required"
233
+ if req["text_lang"].lower() not in tts_config.languages:
234
+ return f"text_lang {req['text_lang']} not supported"
235
+ if not req.get("prompt_lang"):
236
+ return "prompt_lang is required"
237
+ if req["prompt_lang"].lower() not in tts_config.languages:
238
+ return f"prompt_lang {req['prompt_lang']} not supported"
239
+ if not req.get("ref_audio_path"):
240
+ return "ref_audio_path is required"
241
+ mt = req.get("media_type", "wav")
242
+ if mt not in {"wav", "raw", "ogg", "aac"}:
243
+ return f"media_type {mt} not supported"
244
+ if (not req.get("streaming_mode") and mt == "ogg"):
245
+ return "ogg is only supported in streaming mode"
246
+ if req.get("text_split_method", "cut5") not in cut_method_names:
247
+ return f"text_split_method {req['text_split_method']} not supported"
248
  return None
249
 
250
 
251
+ # ---------------------------------------------------------------------------
252
+ # Core handler
253
+ # ---------------------------------------------------------------------------
254
+
255
+ async def _tts_handle(req: dict):
256
+ error = _validate_request(req)
257
+ if error:
258
+ return JSONResponse(status_code=400, content={"message": error})
 
 
259
 
260
  streaming_mode = req.get("streaming_mode", False)
 
261
  media_type = req.get("media_type", "wav")
262
 
263
+ temp_files: List[str] = []
 
264
  try:
265
+ # --- resolve & download audio paths ----------------------------------
266
+ ref_path, is_tmp = process_audio_path(req["ref_audio_path"])
267
+ req["ref_audio_path"] = ref_path
268
+ if is_tmp:
269
+ temp_files.append(ref_path)
270
+
271
+ if req.get("aux_ref_audio_paths"):
272
+ resolved: List[str] = []
273
+ for p in req["aux_ref_audio_paths"]:
274
+ lp, tmp = process_audio_path(p)
275
+ resolved.append(lp)
276
+ if tmp:
277
+ temp_files.append(lp)
278
+ req["aux_ref_audio_paths"] = resolved
279
+
280
+ # --- run inference ----------------------------------------------------
281
+ generator = TTS_PIPELINE.run(req)
282
 
283
  if streaming_mode:
284
+ async def _gen(gen: Generator, _media_type: str):
285
+ first = True
286
+ try:
287
+ for sr, chunk in gen:
288
+ if first and _media_type == "wav":
289
+ # Prepend a WAV header so clients can play immediately.
290
+ header = _wave_header_chunk(sample_rate=sr)
291
+ yield header
292
+ _media_type = "raw"
293
+ first = False
294
+ yield _pack_audio(BytesIO(), chunk, sr, _media_type).getvalue()
295
+ finally:
296
+ _cleanup(temp_files)
297
+ return StreamingResponse(_gen(generator, media_type), media_type=f"audio/{media_type}")
298
+
299
+ # non‑streaming --------------------------------------------------------
300
+ sr, data = next(generator)
301
+ payload = _pack_audio(BytesIO(), data, sr, media_type).getvalue()
302
+ resp = Response(payload, media_type=f"audio/{media_type}")
303
+ _cleanup(temp_files)
304
+ return resp
305
+
306
+ except Exception as exc: # noqa: BLE001
307
+ _cleanup(temp_files)
308
+ return JSONResponse(status_code=400, content={"message": "tts failed", "Exception": str(exc)})
309
+
310
+
311
+ # ---------------------------------------------------------------------------
312
+ # Cleanup helpers
313
+ # ---------------------------------------------------------------------------
314
+
315
+ def _cleanup(temp_files: List[str]):
316
+ for fp in temp_files:
317
+ try:
318
+ os.remove(fp)
319
+ # print(f"[cleanup] removed {fp}")
320
+ except FileNotFoundError:
321
+ pass
322
+ except Exception as exc: # pragma: no cover
323
+ print(f"[cleanup-warning] {exc}")
324
+ _empty_cuda_cache()
325
+
326
+
327
+ # ---------------------------------------------------------------------------
328
+ # WAV header helper (for streaming WAV)
329
+ # ---------------------------------------------------------------------------
330
+
331
+ import wave # placed here to keep top import section tidy
332
+
333
+ def _wave_header_chunk(frame: bytes = b"", *, channels: int = 1, width: int = 2, sample_rate: int = 32_000):
334
+ buf = BytesIO()
335
+ with wave.open(buf, "wb") as wav:
336
+ wav.setnchannels(channels)
337
+ wav.setsampwidth(width)
338
+ wav.setframerate(sample_rate)
339
+ wav.writeframes(frame)
340
+ buf.seek(0)
341
+ return buf.read()
342
+
343
+
344
+ # ---------------------------------------------------------------------------
345
+ # Endpoints
346
+ # ---------------------------------------------------------------------------
347
 
348
  @APP.get("/tts")
349
+ async def tts_get(request: Request):
+     # FastAPI cannot populate **kwargs from the query string, so read the raw
+     # query parameters and let TTSRequest coerce types and fill in defaults.
+     raw = dict(request.query_params)
+     if "aux_ref_audio_paths" in raw:
+         raw["aux_ref_audio_paths"] = request.query_params.getlist("aux_ref_audio_paths")
+     query = TTSRequest(**raw).dict()
+     # Normalise language codes to lower case where applicable
+     for k in ("text_lang", "prompt_lang"):
+         if query.get(k):
+             query[k] = query[k].lower()
+     return await _tts_handle(query)
 
 
355
 
356
 
357
  @APP.post("/tts")
358
+ async def tts_post(request: TTSRequest):
359
+ payload = request.dict()
360
+ if payload.get("text_lang"):
361
+ payload["text_lang"] = payload["text_lang"].lower()
362
+ if payload.get("prompt_lang"):
363
+ payload["prompt_lang"] = payload["prompt_lang"].lower()
364
+ return await _tts_handle(payload)
365
 
366
 
367
+ @APP.get("/control")
368
+ async def control(command: str | None = None):
369
+ if not command:
370
+ raise HTTPException(status_code=400, detail="command is required")
371
+ if command == "restart":
372
+ os.execl(sys.executable, sys.executable, *sys.argv)
373
+ elif command == "exit":
374
+ os.kill(os.getpid(), signal.SIGTERM)
375
+ else:
376
+ raise HTTPException(status_code=400, detail="unsupported command")
377
+ return {"message": "ok"}
378
 
 
379
 
380
+ @APP.get("/set_refer_audio")
381
+ async def set_refer_audio(refer_audio_path: str | None = None):
382
+ if not refer_audio_path:
383
+ return JSONResponse(status_code=400, content={"message": "refer_audio_path is required"})
 
384
 
385
+ temp_file = None
386
+ try:
387
+ local_path, is_tmp = process_audio_path(refer_audio_path)
388
+ temp_file = local_path if is_tmp else None
389
+ TTS_PIPELINE.set_ref_audio(local_path)
390
+ return {"message": "success"}
391
+ finally:
392
+ if temp_file:
393
+ try:
394
+ os.remove(temp_file)
395
+ except FileNotFoundError:
396
+ pass
397
+ _empty_cuda_cache()
398
 
399
 
400
  @APP.get("/set_gpt_weights")
401
+ async def set_gpt_weights(weights_path: str | None = None):
402
+ if not weights_path:
403
+ return JSONResponse(status_code=400, content={"message": "gpt weight path is required"})
404
  try:
405
+ TTS_PIPELINE.init_t2s_weights(weights_path)
406
+ return {"message": "success"}
407
+ except Exception as exc: # noqa: BLE001
408
+ return JSONResponse(status_code=400, content={"message": str(exc)})
 
 
 
409
 
410
 
411
  @APP.get("/set_sovits_weights")
412
+ async def set_sovits_weights(weights_path: str | None = None):
413
+ if not weights_path:
414
+ return JSONResponse(status_code=400, content={"message": "sovits weight path is required"})
415
  try:
416
+ TTS_PIPELINE.init_vits_weights(weights_path)
417
+ return {"message": "success"}
418
+ except Exception as exc: # noqa: BLE001
419
+ return JSONResponse(status_code=400, content={"message": str(exc)})
420
+
 
421
 
422
+ # ---------------------------------------------------------------------------
423
+ # Main entry point
424
+ # ---------------------------------------------------------------------------
425
 
426
  if __name__ == "__main__":
427
  try:
428
+ uvicorn.run(app=APP, host=HOST, port=PORT, workers=1)
429
+ except Exception: # pragma: no cover
 
 
430
  traceback.print_exc()
431
  os.kill(os.getpid(), signal.SIGTERM)
432
+ sys.exit(0)
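
In streaming mode the `/tts` handler prepends a WAV header to the first chunk, so a client can simply write the chunks to disk in the order they arrive. A minimal client sketch against the updated service (assumes the server is running locally on the default port; the reference audio path and the texts are placeholders):

```python
import requests

payload = {
    "text": "This sentence is synthesized and streamed chunk by chunk.",
    "text_lang": "en",
    "ref_audio_path": "reference.wav",  # placeholder: a file the server can read
    "prompt_text": "Transcript of the reference audio.",
    "prompt_lang": "en",
    "media_type": "wav",
    "streaming_mode": True,
}

# The first chunk carries the RIFF/WAV header, so concatenating the chunks in
# order yields a playable file while synthesis is still in progress.
with requests.post("http://127.0.0.1:9880/tts", json=payload, stream=True, timeout=600) as resp:
    resp.raise_for_status()
    with open("streamed.wav", "wb") as out:
        for chunk in resp.iter_content(chunk_size=8192):
            if chunk:
                out.write(chunk)
```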