# Spaces: Running on T4
import json | |
import logging | |
import shlex | |
import subprocess | |
import tempfile | |
import warnings | |
from pathlib import Path | |
from typing import Optional | |
import fastapi | |
import fastapi.middleware.cors | |
import tyro | |
import uvicorn | |
from attr import dataclass | |
from fastapi import Request | |
from fastapi.responses import Response | |
from fam.llm.fast_inference import TTS | |
from fam.llm.utils import check_audio_file | |
# Logger named after this module; the request handler reports failures through it.
logger = logging.getLogger(name=__name__)

## Setup FastAPI server.
# ASGI application: CORS middleware is added to it and it is served by uvicorn
# in the ``__main__`` block at the bottom of this file.
app = fastapi.FastAPI()
class ServingConfig: | |
huggingface_repo_id: str = "metavoiceio/metavoice-1B-v0.1" | |
"""Absolute path to the model directory.""" | |
temperature: float = 1.0 | |
"""Temperature for sampling applied to both models.""" | |
seed: int = 1337 | |
"""Random seed for sampling.""" | |
port: int = 58003 | |
# Singleton
class _GlobalState:
    """Process-wide holder for server state shared with the request handlers.

    Only the attribute annotations are declared here; no values are assigned,
    so reading an attribute before the ``__main__`` block sets it raises
    ``AttributeError``.
    """

    # Parsed command-line configuration (assigned from ``tyro.cli``).
    config: ServingConfig
    # Loaded text-to-speech model whose ``synthesise`` the handler calls.
    tts: TTS


# The single shared instance; module code reads and writes its attributes.
GlobalState = _GlobalState()
@dataclass
class TTSRequest:
    """Parameters of a single synthesis request.

    The handler builds it as ``TTSRequest(**payload)`` from the JSON in the
    ``X-Payload`` header. That call relies on the decorator-generated keyword
    ``__init__``; a plain class would reject every keyword argument.
    """

    # Text to synthesise.
    text: str
    # Server-side speaker reference file; when set, the request body is ignored.
    speaker_ref_path: Optional[str] = None
    # Passed as ``guidance_scale`` to ``synthesise``. The handler requires it to
    # be None when there is no speaker reference at all, hence Optional.
    guidance: Optional[float] = 3.0
    # Passed as ``top_p`` to ``synthesise``.
    top_p: float = 0.95
    # NOTE(review): accepted here but not forwarded to ``synthesise`` by the
    # handler; confirm whether it should be.
    top_k: Optional[int] = None
async def health_check():
    """Liveness probe: report that the server process is up.

    Returns:
        The JSON-serialisable mapping ``{"status": "ok"}``.

    NOTE(review): no route decorator is visible here; confirm this handler is
    registered on ``app`` (e.g. with ``@app.get(...)``).
    """
    return dict(status="ok")
async def text_to_speech(req: Request) -> Response:
    """Synthesise speech for one request and return it as a WAV response.

    Request format:
        * ``X-Payload`` header: a JSON object whose keys are ``TTSRequest``
          fields (``text`` is required).
        * Body (optional): raw speaker-reference audio in any format ffmpeg can
          decode. It is only used when ``speaker_ref_path`` is not given.

    Returns:
        * ``audio/wav`` bytes on success.
        * A 400 response when the payload is missing or malformed, or when
          ``guidance`` is set although there is no speaker reference.
        * A 500 response for any other failure (details are logged).

    NOTE(review): no route decorator is visible here; confirm this handler is
    registered on ``app`` (e.g. with ``@app.post(...)``).
    NOTE(review): ``synthesise`` is a synchronous call made directly inside this
    coroutine, so the event loop (including ``health_check``) is blocked while
    it runs. Offloading it to a thread first needs confirmation that ``TTS``
    tolerates concurrent use.
    """
    audiodata = await req.body()
    payload = None
    wav_out_path = None

    # Malformed input is a client error: answer 400 here instead of letting it
    # fall through to the generic 500 handler below.
    try:
        payload = json.loads(req.headers["X-Payload"])
        tts_req = TTSRequest(**payload)
    except (KeyError, TypeError, ValueError):
        # KeyError: header missing. ValueError: invalid JSON (JSONDecodeError).
        # TypeError: payload is not an object, or has unknown/missing fields.
        logger.warning("Rejecting malformed TTS request payload: %r", payload)
        return Response(
            content="Invalid request: X-Payload must be a JSON object of TTSRequest fields",
            status_code=400,
        )

    try:
        with tempfile.NamedTemporaryFile(suffix=".wav") as wav_tmp:
            if tts_req.speaker_ref_path is None:
                wav_path = _convert_audiodata_to_wav_path(audiodata, wav_tmp)
                # An empty body yields no file (None); only validate real audio.
                if wav_path is not None:
                    check_audio_file(wav_path)
            else:
                # TODO: fix
                # SECURITY NOTE(review): this path comes straight from the client
                # and is passed to the model unchecked, so any file readable by
                # the server process can be referenced. Restrict it to an
                # allow-listed directory or drop the option.
                wav_path = tts_req.speaker_ref_path

            if wav_path is None:
                warnings.warn("Running without speaker reference")
                # Explicit check rather than ``assert``: asserts are stripped
                # under ``python -O``, and a failing assert surfaced as a 500.
                if tts_req.guidance is not None:
                    return Response(
                        content="Invalid request: guidance must be null when no speaker reference is provided",
                        status_code=400,
                    )

            # NOTE(review): ``top_k`` from the request is not forwarded here;
            # confirm whether ``synthesise`` accepts it.
            wav_out_path = GlobalState.tts.synthesise(
                text=tts_req.text,
                spk_ref_path=wav_path,
                top_p=tts_req.top_p,
                guidance_scale=tts_req.guidance,
            )
            with open(wav_out_path, "rb") as f:
                return Response(content=f.read(), media_type="audio/wav")
    except Exception:
        # Top-level boundary of the request: log with traceback, answer 500.
        logger.exception("Error processing request %s", payload)
        return Response(
            content="Something went wrong. Please try again in a few mins or contact us on Discord",
            status_code=500,
        )
    finally:
        # Remove the file path returned by ``synthesise`` whether or not the
        # response was built successfully.
        if wav_out_path is not None:
            Path(wav_out_path).unlink(missing_ok=True)
def _convert_audiodata_to_wav_path(audiodata, wav_tmp): | |
with tempfile.NamedTemporaryFile() as unknown_format_tmp: | |
if unknown_format_tmp.write(audiodata) == 0: | |
return None | |
unknown_format_tmp.flush() | |
subprocess.check_output( | |
# arbitrary 2 minute cutoff | |
shlex.split(f"ffmpeg -t 120 -y -i {unknown_format_tmp.name} -f wav {wav_tmp.name}") | |
) | |
return wav_tmp.name | |
if __name__ == "__main__":
    # Set every logger that already exists (including those of imported
    # libraries) and the root logger to INFO.
    # The loop variable must not be called ``logger``: at module level that
    # would rebind the module logger used by ``text_to_speech`` to whichever
    # logger happened to come last. Iterate a snapshot because ``getLogger``
    # may replace placeholder entries in ``loggerDict``.
    for _logger_name in list(logging.root.manager.loggerDict):
        logging.getLogger(_logger_name).setLevel(logging.INFO)
    logging.root.setLevel(logging.INFO)

    GlobalState.config = tyro.cli(ServingConfig)
    # NOTE(review): only ``seed`` is passed on; ``huggingface_repo_id`` and
    # ``temperature`` from the config are not used in this file. Confirm
    # whether ``TTS`` should receive them.
    GlobalState.tts = TTS(seed=GlobalState.config.seed)

    # NOTE(review): "*" already admits every origin, which makes the two
    # localhost entries redundant. Combined with allow_credentials=True this
    # policy is very permissive: browsers refuse a wildcard Allow-Origin on
    # credentialed requests, and Starlette's docs ask for explicit origins
    # when credentials are allowed. Confirm the intended policy.
    app.add_middleware(
        fastapi.middleware.cors.CORSMiddleware,
        allow_origins=["*", f"http://localhost:{GlobalState.config.port}", "http://localhost:3000"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    # Bind on all interfaces so the server is reachable from outside the host/container.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=GlobalState.config.port,
        log_level="info",
    )