# ChatTTS-Forge/modules/api/impl/xtts_v2_api.py
# zhzluke96
# update
# bed01bd
import logging
from typing import Optional

from fastapi import HTTPException, Query, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from modules.api.Api import APIManager
from modules.api.impl.handler.TTSHandler import TTSHandler
from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
from modules.api.impl.model.enhancer_model import EnhancerConfig
from modules.speaker import speaker_mgr
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class XTTS_V2_Settings:
    """Mutable bag of runtime settings for the XTTS-v2 compatible API.

    The first group of attributes mirrors the XTTS-v2 settings payload; the
    second group holds extra options that XTTS-v2 does not define but this
    system needs. Every attribute starts at the default assigned below.
    """

    def __init__(self):
        # --- Settings that exist in the XTTS-v2 payload -------------------
        self.stream_chunk_size: int = 100
        self.temperature: float = 0.3
        self.speed: float = 1
        # TODO: these two parameters are not used at the moment... but the
        # GPT stage could use them, so consider wiring them in.
        self.length_penalty: float = 0.5
        self.repetition_penalty: float = 1.0
        self.top_p: float = 0.7
        self.top_k: int = 20
        self.enable_text_splitting: bool = True

        # --- Extra settings: not part of xtts_v2, but needed by this system
        self.batch_size: int = 4
        self.eos: str = "[uv_break]"
        self.infer_seed: int = 42
        self.use_decoder: bool = True
        self.prompt1: str = ""
        self.prompt2: str = ""
        self.prefix: str = ""
        self.spliter_threshold: int = 100
        self.style: str = ""
class TTSSettingsRequest(BaseModel):
    # Request body accepted by the settings-update endpoint.
    #
    # The first eight fields have no default and are therefore required.
    # The remaining fields are optional extras; they default to None, which
    # means "not provided". They are annotated Optional[...] so that the
    # declared type matches the None default and an explicit JSON null is
    # accepted as well.

    # This stream_chunk value is currently used as spliter_threshold.
    stream_chunk_size: int
    temperature: float
    speed: float
    length_penalty: float
    repetition_penalty: float
    top_p: float
    top_k: int
    enable_text_splitting: bool

    # Optional extras (None = leave the current value untouched).
    batch_size: Optional[int] = None
    eos: Optional[str] = None
    infer_seed: Optional[int] = None
    use_decoder: Optional[bool] = None
    prompt1: Optional[str] = None
    prompt2: Optional[str] = None
    prefix: Optional[str] = None
    spliter_threshold: Optional[int] = None
    style: Optional[str] = None
class SynthesisRequest(BaseModel):
    # Request body for the text-to-audio endpoint; all three fields are
    # required because none declares a default.

    # Text to synthesize.
    text: str
    # Despite the name, this value is used as the speaker id by the endpoint.
    speaker_wav: str
    # Language value sent by the client.
    language: str
def setup(app: APIManager):
    """Register the XTTS-v2 compatible endpoints on ``app``.

    Routes registered:

    * ``GET  /v1/xtts_v2/speakers``         -- list the known speakers.
    * ``POST /v1/xtts_v2/tts_to_audio``     -- synthesize text, respond with MP3.
    * ``GET  /v1/xtts_v2/tts_stream``       -- synthesize text, stream MP3 chunks.
    * ``POST /v1/xtts_v2/set_tts_settings`` -- update the shared settings.

    All routes share the single ``XTTS_V2_Settings`` instance created here,
    so a settings update affects every synthesis request handled afterwards.
    """
    XTTSV2 = XTTS_V2_Settings()

    # Fields checked with ``< 0`` before the settings are applied.
    # NOTE: zero passes this check even though the error message says
    # "greater than 0"; the accepted range is kept as it was.
    non_negative_fields = (
        "temperature",
        "speed",
        "length_penalty",
        "repetition_penalty",
        "top_p",
        "top_k",
    )
    # Required request fields that are copied to the settings verbatim.
    required_fields = ("stream_chunk_size",) + non_negative_fields + (
        "enable_text_splitting",
    )
    # Optional fields where a falsy value (None or 0) means "not provided":
    # zero is not a usable value for either of them.
    optional_nonzero_fields = ("batch_size", "spliter_threshold")
    # Optional fields where only None means "not provided"; falsy values
    # such as False, 0 or "" are legitimate settings and must be applied.
    optional_fields = (
        "eos",
        "infer_seed",
        "use_decoder",
        "prompt1",
        "prompt2",
        "prefix",
        "style",
    )

    def _build_handler(text: str, voice_id: str) -> TTSHandler:
        """Create a ``TTSHandler`` for ``text`` from the current settings.

        Args:
            text: Text to synthesize.
            voice_id: Speaker identifier sent by the client.

        Returns:
            A handler configured with the speaker and the shared settings.

        Raises:
            HTTPException: 400 when no speaker matches ``voice_id``.
        """
        # Look the speaker up by id first, then fall back to ``get_speaker``
        # (presumably a lookup by name -- confirm against speaker_mgr).
        spk = speaker_mgr.get_speaker_by_id(voice_id) or speaker_mgr.get_speaker(
            voice_id
        )
        if spk is None:
            raise HTTPException(status_code=400, detail="Invalid speaker id")

        return TTSHandler(
            text_content=text,
            spk=spk,
            tts_config=ChatTTSConfig(
                style=XTTSV2.style,
                temperature=XTTSV2.temperature,
                top_k=XTTSV2.top_k,
                top_p=XTTSV2.top_p,
                prefix=XTTSV2.prefix,
                prompt1=XTTSV2.prompt1,
                prompt2=XTTSV2.prompt2,
            ),
            infer_config=InferConfig(
                batch_size=XTTSV2.batch_size,
                spliter_threshold=XTTSV2.spliter_threshold,
                eos=XTTSV2.eos,
                seed=XTTSV2.infer_seed,
            ),
            adjust_config=AdjustConfig(
                speed_rate=XTTSV2.speed,
            ),
            # TODO: support enhancer; the default configuration is used for now.
            enhancer_config=EnhancerConfig(),
        )

    @app.get("/v1/xtts_v2/speakers")
    async def speakers():
        """Return name, voice id and an empty preview URL for each speaker."""
        return [
            {
                "name": spk.name,
                "voice_id": spk.id,
                # TODO: maybe put a "/v1/tts" endpoint address here
                "preview_url": "",
            }
            for spk in speaker_mgr.list_speakers()
        ]

    @app.post("/v1/xtts_v2/tts_to_audio", response_class=StreamingResponse)
    async def tts_to_audio(request: SynthesisRequest):
        """Synthesize ``request.text`` and respond with MP3 audio.

        Raises:
            HTTPException: 400 when the speaker cannot be found.
        """
        # ``speaker_wav`` actually carries the speaker id.
        handler = _build_handler(request.text, request.speaker_wav)
        buffer = handler.enqueue_to_buffer(AudioFormat.mp3)
        return StreamingResponse(buffer, media_type="audio/mpeg")

    @app.get("/v1/xtts_v2/tts_stream")
    async def tts_stream(
        request: Request,
        text: str = Query(),
        speaker_wav: str = Query(),
        language: str = Query(),
    ):
        """Synthesize ``text`` and stream the MP3 chunks as they are produced.

        ``language`` is a required query parameter but is not used here.

        Raises:
            HTTPException: 400 when the speaker cannot be found.
        """
        # ``speaker_wav`` actually carries the speaker id.
        handler = _build_handler(text, speaker_wav)

        async def generator():
            for chunk in handler.enqueue_to_stream(AudioFormat.mp3):
                # Stop sending chunks once the client has gone away.
                if await request.is_disconnected():
                    break
                yield chunk

        return StreamingResponse(generator(), media_type="audio/mpeg")

    @app.post("/v1/xtts_v2/set_tts_settings")
    async def set_tts_settings(request: TTSSettingsRequest):
        """Validate ``request`` and copy its values into the shared settings.

        Required fields are always applied. Optional fields are applied only
        when provided; for all of them except ``batch_size`` and
        ``spliter_threshold`` a falsy value (``False``, ``0``, ``""``) counts
        as provided, so those settings can be switched off or cleared again.

        Raises:
            HTTPException: 400 when a value fails validation, 500 when
                applying the settings fails unexpectedly.
        """
        try:
            if request.stream_chunk_size < 50:
                raise HTTPException(
                    status_code=400,
                    detail="stream_chunk_size must be at least 50",
                )
            for name in non_negative_fields:
                if getattr(request, name) < 0:
                    raise HTTPException(
                        status_code=400, detail=f"{name} must be greater than 0"
                    )

            for name in required_fields:
                setattr(XTTSV2, name, getattr(request, name))
            # stream_chunk_size doubles as the splitter threshold unless an
            # explicit spliter_threshold is supplied below.
            XTTSV2.spliter_threshold = request.stream_chunk_size

            # TODO: checker
            for name in optional_nonzero_fields:
                value = getattr(request, name)
                if value:
                    setattr(XTTSV2, name, value)
            for name in optional_fields:
                value = getattr(request, name)
                if value is not None:
                    setattr(XTTSV2, name, value)

            return {"message": "Settings successfully applied"}
        except HTTPException:
            # Validation errors already carry the right status code.
            raise
        except Exception as e:
            logger.exception(e)
            raise HTTPException(status_code=500, detail=str(e)) from e