# NOTE(review): removed non-Python scrape residue that preceded the module
# (file-size line, commit hash, and a line-number gutter dump).
import io
import os
import sys
import zipfile
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
# On macOS, allow PyTorch to fall back to CPU for operators the MPS backend
# does not implement (must be set before the MPS backend is first used).
if sys.platform == "darwin":
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# Make the repo-local packages (ChatTTS, tools) importable when this script
# is launched from the repository root.
now_dir = os.getcwd()
sys.path.append(now_dir)

from typing import Optional

import ChatTTS
from tools.audio import pcm_arr_to_mp3_view
from tools.logger import get_logger
import torch

from pydantic import BaseModel

# Module-level singletons: one logger and one FastAPI app for the process.
logger = get_logger("Command")
app = FastAPI()
@app.on_event("startup")
async def startup_event():
    """Load the ChatTTS models once, when the server process starts.

    Stores the loaded `ChatTTS.Chat` instance in the module-global `chat`
    used by the request handlers. Exits the process if loading fails.
    """
    global chat
    chat = ChatTTS.Chat(get_logger("ChatTTS"))
    logger.info("Initializing ChatTTS...")
    # Guard clause: bail out of the whole process on a failed model load.
    if not chat.load():
        logger.error("Models load failed.")
        sys.exit(1)
    logger.info("Models loaded successfully.")
class ChatTTSParams(BaseModel):
    """Request body for the /generate_voice endpoint."""

    # Texts to synthesize; one audio clip (one mp3 in the zip) per entry.
    text: list[str]
    # Whether chat.infer should yield audio in streaming chunks.
    stream: bool = False
    # Optional language hint for inference; None lets the model decide.
    lang: Optional[str] = None
    # Skip the text-refinement pass during the main inference call.
    skip_refine_text: bool = False
    # NOTE(review): declared but not acted on in generate_voice — confirm intent.
    refine_text_only: bool = False
    # Use the decoder head when generating audio.
    use_decoder: bool = True
    # Normalize input text (numbers, punctuation, etc.) before synthesis.
    do_text_normalization: bool = True
    # Replace homophones during normalization.
    do_homophone_replacement: bool = False
    # Required, project-defined parameter bundles for refinement and inference.
    params_refine_text: ChatTTS.Chat.RefineTextParams
    params_infer_code: ChatTTS.Chat.InferCodeParams
@app.post("/generate_voice")
async def generate_voice(params: ChatTTSParams):
    """Synthesize speech for each input text and return a zip of MP3 files.

    Optionally seeds the RNG and samples a reproducible random speaker,
    optionally refines the input text, then runs inference and packs one
    `<index>.mp3` per input into an in-memory zip archive.

    Returns:
        StreamingResponse with media type ``application/zip`` and a
        ``Content-Disposition`` attachment header.
    """
    logger.info("Text input: %s", str(params.text))

    # audio seed: a manual seed makes the sampled random speaker reproducible
    if params.params_infer_code.manual_seed is not None:
        torch.manual_seed(params.params_infer_code.manual_seed)
        params.params_infer_code.spk_emb = chat.sample_random_speaker()

    # text seed for text refining
    if params.params_refine_text:
        # Fix: forward the caller-supplied refine parameters to the refine
        # pass — previously they were ignored here, so this branch refined
        # with defaults regardless of what the caller sent.
        text = chat.infer(
            text=params.text,
            skip_refine_text=False,
            refine_text_only=True,
            params_refine_text=params.params_refine_text,
        )
        # Fix: lazy %-formatting, consistent with the other log calls.
        logger.info("Refined text: %s", text)
    else:
        # no text refining
        text = params.text

    logger.info("Use speaker:")
    logger.info(params.params_infer_code.spk_emb)

    logger.info("Start voice inference.")
    wavs = chat.infer(
        text=text,
        stream=params.stream,
        lang=params.lang,
        skip_refine_text=params.skip_refine_text,
        use_decoder=params.use_decoder,
        do_text_normalization=params.do_text_normalization,
        do_homophone_replacement=params.do_homophone_replacement,
        params_infer_code=params.params_infer_code,
        params_refine_text=params.params_refine_text,
    )
    logger.info("Inference completed.")

    # zip all of the audio files together
    buf = io.BytesIO()
    with zipfile.ZipFile(
        buf, "a", compression=zipfile.ZIP_DEFLATED, allowZip64=False
    ) as f:
        for idx, wav in enumerate(wavs):
            f.writestr(f"{idx}.mp3", pcm_arr_to_mp3_view(wav))
    logger.info("Audio generation successful.")

    # rewind before handing the buffer to the streaming response
    buf.seek(0)
    response = StreamingResponse(buf, media_type="application/zip")
    response.headers["Content-Disposition"] = "attachment; filename=audio_files.zip"
    return response
|