from fastapi import FastAPI, HTTPException, Query
from fastapi.responses import StreamingResponse, JSONResponse
import torch
import time
import nltk
import io
import base64
import torchaudio
import numpy as np
from app.inference import inference, LFinference, compute_style

# Sentence tokenizer models used for long-form splitting
nltk.download('punkt')
nltk.download('punkt_tab')

app = FastAPI()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
@app.get("/")  # route path assumed
async def read_root():
    # Simple health check
    return {"details": "Environment is running OK!"}
# Single-utterance synthesis. Decorator and route path are assumed from the
# function name.
@app.post("/synthesize")
async def synthesize(
    text: str,
    return_base64: bool = True,
    ###################################################
    diffusion_steps: int = Query(5, ge=5, le=200),
    embedding_scale: float = Query(1.0, ge=1.0, le=5.0)
    ###################################################
):
    try:
        start = time.time()
        # Latent noise that seeds the style diffusion sampler
        noise = torch.randn(1, 1, 256).to(device)
        wav = inference(text, noise, diffusion_steps=diffusion_steps, embedding_scale=embedding_scale)
        # Real-time factor: wall-clock time divided by audio duration at 24 kHz
        rtf = (time.time() - start) / (len(wav) / 24000)
        if return_base64:
            audio_buffer = io.BytesIO()
            torchaudio.save(audio_buffer, torch.tensor(wav).unsqueeze(0), 24000, format="wav")
            audio_buffer.seek(0)
            audio_base64 = base64.b64encode(audio_buffer.read()).decode('utf-8')
            return JSONResponse(content={"RTF": rtf, "audio_base64": audio_base64})
        else:
            return JSONResponse(content={"RTF": rtf, "audio": wav.tolist()})
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
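# Client sketch for the endpoint above (hedged: the base URL, port, and the
# `requests` package are assumptions, not part of this service):
#
#   import base64, requests
#   resp = requests.post("http://localhost:8000/synthesize",
#                        params={"text": "Hello there.",
#                                "diffusion_steps": 5,
#                                "embedding_scale": 1.0})
#   with open("out.wav", "wb") as f:
#       f.write(base64.b64decode(resp.json()["audio_base64"]))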
# Long-form synthesis: the passage is split into sentences and each sentence
# is synthesized while carrying the previous style vector forward for
# continuity. Decorator and route path are assumed.
@app.post("/synthesize_longform")
async def synthesize_longform(
    passage: str,
    return_base64: bool = False,
    ###################################################
    alpha: float = Query(0.7, ge=0.0, le=1.0),
    diffusion_steps: int = Query(10, ge=5, le=200),
    embedding_scale: float = Query(1.5, ge=1.0, le=5.0)
    ###################################################
):
    try:
        # Split into sentences with nltk (punkt models are downloaded above);
        # sent_tokenize keeps terminal punctuation intact
        sentences = nltk.sent_tokenize(passage)
        wavs = []
        s_prev = None  # style vector carried across sentences
        start = time.time()
        for text in sentences:
            if text.strip() == "":
                continue
            noise = torch.randn(1, 1, 256).to(device)  # diffusion seed noise
            wav, s_prev = LFinference(text, s_prev, noise, alpha=alpha,
                                      diffusion_steps=diffusion_steps,
                                      embedding_scale=embedding_scale)
            wavs.append(wav)
        final_wav = np.concatenate(wavs)  # join sentence waveforms
        rtf = (time.time() - start) / (len(final_wav) / 24000)
        audio_buffer = io.BytesIO()
        torchaudio.save(audio_buffer, torch.tensor(final_wav).unsqueeze(0), 24000, format="wav")
        audio_buffer.seek(0)
        if return_base64:
            audio_base64 = base64.b64encode(audio_buffer.read()).decode('utf-8')
            return JSONResponse(content={"RTF": rtf, "audio_base64": audio_base64})
        else:
            return StreamingResponse(audio_buffer, media_type="audio/wav")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
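# Client sketch (hedged: base URL, port, and `requests` are assumptions). The
# default response streams raw WAV bytes, so the body can be written directly:
#
#   import requests
#   resp = requests.post("http://localhost:8000/synthesize_longform",
#                        params={"passage": "First sentence. Second sentence.",
#                                "alpha": 0.7, "diffusion_steps": 10})
#   with open("longform.wav", "wb") as f:
#       f.write(resp.content)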
# Batch synthesis keyed by emotion label: the JSON request body maps each
# emotion name to the text to synthesize. Decorator and route path are assumed.
@app.post("/synthesize_with_emotion")
async def synthesize_with_emotion(
    texts: dict,
    return_base64: bool = True,
    ###################################################
    diffusion_steps: int = Query(100, ge=5, le=200),
    embedding_scale: float = Query(5.0, ge=1.0, le=5.0)
    ###################################################
):
    try:
        results = []
        for emotion, text in texts.items():
            noise = torch.randn(1, 1, 256).to(device)
            wav = inference(text, noise, diffusion_steps=diffusion_steps,
                            embedding_scale=embedding_scale)
            if return_base64:
                audio_buffer = io.BytesIO()
                torchaudio.save(audio_buffer, torch.tensor(wav).unsqueeze(0), 24000, format="wav")
                audio_buffer.seek(0)
                audio_base64 = base64.b64encode(audio_buffer.read()).decode('utf-8')
                results.append({
                    "emotion": emotion,
                    "audio_base64": audio_base64
                })
            else:
                results.append({
                    "emotion": emotion,
                    "audio": wav.tolist()
                })
        return JSONResponse(content={"results": results})
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
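# Client sketch (hedged: URL, port, and the emotion labels below are
# illustrative assumptions; any string keys work):
#
#   import base64, requests
#   resp = requests.post("http://localhost:8000/synthesize_with_emotion",
#                        json={"happy": "What a lovely day!",
#                              "sad": "I miss the old days."})
#   for item in resp.json()["results"]:
#       with open(f"{item['emotion']}.wav", "wb") as f:
#           f.write(base64.b64decode(item["audio_base64"]))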
# Same pipeline as /synthesize, but by default the WAV bytes are streamed back
# directly instead of being embedded in JSON. Decorator and route path are
# assumed.
@app.post("/synthesize_streaming_audio")
async def synthesize_streaming_audio(
    text: str,
    return_base64: bool = False,
    ###################################################
    diffusion_steps: int = Query(5, ge=5, le=200),
    embedding_scale: float = Query(1.0, ge=1.0, le=5.0)
    ###################################################
):
    try:
        start = time.time()
        noise = torch.randn(1, 1, 256).to(device)
        wav = inference(text, noise, diffusion_steps=diffusion_steps, embedding_scale=embedding_scale)
        rtf = (time.time() - start) / (len(wav) / 24000)
        audio_buffer = io.BytesIO()
        torchaudio.save(audio_buffer, torch.tensor(wav).unsqueeze(0), 24000, format="wav")
        audio_buffer.seek(0)
        if return_base64:
            audio_base64 = base64.b64encode(audio_buffer.read()).decode('utf-8')
            return JSONResponse(content={"RTF": rtf, "audio_base64": audio_base64})
        else:
            return StreamingResponse(audio_buffer, media_type="audio/wav")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
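# Local entry point sketch (assumptions: uvicorn is installed and port 8000 is
# free; a hosting platform may launch the app differently):
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)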