from fastapi import FastAPI, HTTPException, Query
from fastapi.responses import JSONResponse, StreamingResponse

import base64
import io
import time

import nltk
import numpy as np
import torch
import torchaudio

from app.inference import inference, LFinference

# Tokenizer models used for sentence segmentation in the long-form endpoint.
nltk.download('punkt')
nltk.download('punkt_tab')

app = FastAPI()

# Run inference on the GPU when available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


@app.get("/")
async def read_root():
    """Health check: confirms the service environment is up."""
    return {"details": "Environment is running OK!"}

@app.post("/synthesize/") |
|
async def synthesize( |
|
text: str, |
|
return_base64: bool = True, |
|
|
|
diffusion_steps: int = Query(5, ge=5, le=200), |
|
embedding_scale: float = Query(1.0, ge=1.0, le=5.0) |
|
|
|
): |
|
try: |
|
start = time.time() |
|
noise = torch.randn(1, 1, 256).to(device) |
|
wav = inference(text, noise, diffusion_steps=diffusion_steps, embedding_scale=embedding_scale) |
|
rtf = (time.time() - start) / (len(wav) / 24000) |
|
|
|
if return_base64: |
|
audio_buffer = io.BytesIO() |
|
torchaudio.save(audio_buffer, torch.tensor(wav).unsqueeze(0), 24000, format="wav") |
|
audio_buffer.seek(0) |
|
|
|
audio_base64 = base64.b64encode(audio_buffer.read()).decode('utf-8') |
|
|
|
return JSONResponse(content={"RTF": rtf, "audio_base64": audio_base64}) |
|
else: |
|
return JSONResponse(content={"RTF": rtf, "audio": wav.tolist()}) |
|
|
|
except Exception as e: |
|
raise HTTPException(status_code=500, detail=str(e)) |
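
# Example call (a sketch, not part of the app). `text`, `return_base64`,
# `diffusion_steps`, and `embedding_scale` are all query parameters on this
# endpoint; the host and port below are assumptions about a local deployment:
#
#   curl -X POST "http://localhost:8000/synthesize/?text=Hello%20world&diffusion_steps=10"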


@app.post("/synthesize_longform_streaming/")
async def synthesize_longform(
    passage: str,
    return_base64: bool = False,
    alpha: float = Query(0.7, ge=0.0, le=1.0),
    diffusion_steps: int = Query(10, ge=5, le=200),
    embedding_scale: float = Query(1.5, ge=1.0, le=5.0)
):
    """Synthesize a multi-sentence passage, carrying style state across sentences."""
    try:
        # Segment the passage into sentences with NLTK's punkt models
        # (downloaded at startup), which also handle '!' and '?' boundaries.
        sentences = nltk.sent_tokenize(passage)
        wavs = []
        s_prev = None  # style state carried from the previous sentence

        start = time.time()

        for text in sentences:
            noise = torch.randn(1, 1, 256).to(device)
            wav, s_prev = LFinference(text, s_prev, noise, alpha=alpha,
                                      diffusion_steps=diffusion_steps,
                                      embedding_scale=embedding_scale)
            wavs.append(wav)

        final_wav = np.concatenate(wavs)
        rtf = (time.time() - start) / (len(final_wav) / 24000)

        audio_buffer = io.BytesIO()
        torchaudio.save(audio_buffer, torch.tensor(final_wav).unsqueeze(0), 24000, format="wav")
        audio_buffer.seek(0)

        if return_base64:
            audio_base64 = base64.b64encode(audio_buffer.read()).decode('utf-8')
            return JSONResponse(content={"RTF": rtf, "audio_base64": audio_base64})
        else:
            return StreamingResponse(audio_buffer, media_type="audio/wav")

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
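
# Example client (a sketch using the third-party `requests` package; the URL is
# an assumption about a local deployment). With return_base64 left at False the
# endpoint streams a WAV body, which can be written straight to disk:
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/synthesize_longform_streaming/",
#       params={"passage": "First sentence. Second sentence.", "alpha": 0.5},
#   )
#   with open("longform.wav", "wb") as f:
#       f.write(resp.content)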


@app.post("/synthesize_with_emotion/")
async def synthesize_with_emotion(
    texts: dict,
    return_base64: bool = True,
    diffusion_steps: int = Query(100, ge=5, le=200),
    embedding_scale: float = Query(5.0, ge=1.0, le=5.0)
):
    """Synthesize one clip per {emotion: text} pair in the JSON request body."""
    try:
        results = []

        for emotion, text in texts.items():
            noise = torch.randn(1, 1, 256).to(device)
            wav = inference(text, noise, diffusion_steps=diffusion_steps,
                            embedding_scale=embedding_scale)

            if return_base64:
                audio_buffer = io.BytesIO()
                torchaudio.save(audio_buffer, torch.tensor(wav).unsqueeze(0), 24000, format="wav")
                audio_buffer.seek(0)
                audio_base64 = base64.b64encode(audio_buffer.read()).decode('utf-8')
                results.append({
                    "emotion": emotion,
                    "audio_base64": audio_base64
                })
            else:
                results.append({
                    "emotion": emotion,
                    "audio": wav.tolist()
                })

        return JSONResponse(content={"results": results})

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
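
# Example request body (a sketch; the emotion labels are free-form keys chosen
# by the caller, and the endpoint simply echoes each key back in its results):
#
#   {"happy": "What a wonderful day!", "sad": "I really miss those days."}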


@app.post("/synthesize_streaming_audio/")
async def synthesize_streaming_audio(
    text: str,
    return_base64: bool = False,
    diffusion_steps: int = Query(5, ge=5, le=200),
    embedding_scale: float = Query(1.0, ge=1.0, le=5.0)
):
    """Synthesize a single utterance and stream it back as a WAV file."""
    try:
        start = time.time()
        noise = torch.randn(1, 1, 256).to(device)
        wav = inference(text, noise, diffusion_steps=diffusion_steps,
                        embedding_scale=embedding_scale)
        rtf = (time.time() - start) / (len(wav) / 24000)

        audio_buffer = io.BytesIO()
        torchaudio.save(audio_buffer, torch.tensor(wav).unsqueeze(0), 24000, format="wav")
        audio_buffer.seek(0)

        if return_base64:
            audio_base64 = base64.b64encode(audio_buffer.read()).decode('utf-8')
            return JSONResponse(content={"RTF": rtf, "audio_base64": audio_base64})
        else:
            # The streamed response carries only the audio; the RTF is reported
            # only on the JSON path above.
            return StreamingResponse(audio_buffer, media_type="audio/wav")

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
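
# To serve the app locally (a sketch: the module path "app.main:app" assumes
# this file lives at app/main.py, which this file alone does not confirm):
#
#   uvicorn app.main:app --host 0.0.0.0 --port 8000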