Spaces:
Running
Running
import json | |
import os | |
from pathlib import Path | |
from typing import Literal | |
import gradio as gr | |
import numpy as np | |
from dotenv import load_dotenv | |
from fastapi import FastAPI, HTTPException | |
from fastapi.responses import HTMLResponse, StreamingResponse, FileResponse | |
from fastrtc import ( | |
AdditionalOutputs, | |
ReplyOnPause, | |
Stream, | |
get_stt_model, | |
) | |
import soundfile as sf | |
from gradio.utils import get_space | |
from pydantic import BaseModel | |
import httpx | |
# Load variables from a .env file into os.environ before any configuration is read.
load_dotenv()
# Directory of this script; the bark .mp3 files and index.html are resolved against it.
curr_dir = Path(__file__).parent
# Speech-to-text model used by response() to transcribe each user utterance.
stt_model = get_stt_model()
# NOTE(review): not referenced anywhere in the visible code — confirm it is still needed.
conversations = dict()
def _load_bark(path):
    """Read an audio file and return ``(sample_rate, samples)`` as float32.

    Audio with more than one channel is downmixed to mono by averaging the
    channels; single-channel audio keeps its shape.
    """
    samples, sample_rate = sf.read(path)
    is_multichannel = samples.ndim > 1 and samples.shape[1] > 1
    if is_multichannel:
        samples = np.mean(samples, axis=1)
    return sample_rate, samples.astype(np.float32)


# Pre-recorded barks keyed "<breed>-<length>", e.g. "dachshund-short".
# Insertion order is breed-major, "short" before "long".
barks = {
    f"{breed}-{length}": _load_bark(curr_dir / f"{breed}-{length}.mp3")
    for breed in ("chiahuahua", "dachshund", "golden-retriever")
    for length in ("short", "long")
}
def response(
    audio: tuple[int, np.ndarray],
    breed: Literal["chiahuahua", "dachshund", "golden-retriever"],
):
    """Transcribe the user's utterance and answer with a pre-recorded bark.

    Args:
        audio: ``(int, ndarray)`` audio pair passed in by ``ReplyOnPause``
            (presumably ``(sample_rate, samples)`` — TODO confirm).
        breed: Which dog's bark to play back.

    Yields:
        A single ``(bark_audio, AdditionalOutputs(messages))`` pair.
        ``messages`` holds the user transcript followed by an assistant entry
        whose content is the ``/files/...mp3`` URL of the chosen bark.
    """
    prompt = stt_model.stt(audio)
    # str.split() with no separator ignores leading, trailing and repeated
    # whitespace, so this is a true word count; split(" ") counted the empty
    # strings between consecutive spaces as words.
    length = "long" if len(prompt.split()) > 10 else "short"
    file_name = f"{breed}-{length}"
    messages = [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": f"/files/{file_name}.mp3"},
    ]
    # NOTE(review): a breed value outside the three literals (e.g. None when
    # nothing is selected in the dropdown) raises KeyError here — confirm the
    # UI always supplies one.
    bark_audio = barks[file_name]
    yield bark_audio, AdditionalOutputs(messages)
# Audio send/receive stream: each pause in the user's speech triggers
# response(), with the dropdown's breed as the extra input and the chat
# messages rendered as JSON.
stream = Stream(
    handler=ReplyOnPause(response),
    modality="audio",
    mode="send-receive",
    # Keeps only the second argument — presumably the newest outputs.
    # TODO confirm fastrtc's argument order.
    additional_outputs_handler=lambda previous, latest: latest,
    additional_inputs=[
        gr.Dropdown(choices=["chiahuahua", "dachshund", "golden-retriever"]),
    ],
    additional_outputs=[gr.JSON()],
    rtc_configuration={
        "iceServers": [
            {"urls": ["stun:stun.l.google.com:19302"]},
        ],
    },
    # Limits apply only when running on a Hugging Face Space.
    concurrency_limit=(20 if get_space() else None),
    time_limit=(600 if get_space() else None),
)
class InputData(BaseModel):
    """Request body pairing a WebRTC connection with the breed it should use."""

    # Identifier of the WebRTC connection whose input is being set.
    webrtc_id: str
    # Breed whose barks response() plays back for that connection.
    breed: Literal["chiahuahua", "dachshund", "golden-retriever"]
# The FastAPI application that the handlers below belong to.
app = FastAPI()
# Shared async HTTP client, used by the index handler to request TURN credentials.
client = httpx.AsyncClient()
# Attach the fastrtc stream to the app (presumably registers its WebRTC routes).
stream.mount(app)
async def _():
    """Render ``index.html`` with freshly generated Cloudflare TURN ICE servers.

    NOTE(review): no route decorator is visible on this handler; it presumably
    serves the landing page — confirm how it is registered on ``app``.

    Reads ``TURN_TOKEN_ID`` and ``TURN_API_TOKEN`` from the environment and
    asks Cloudflare's TURN API for ICE servers with ``ttl=600``. The JSON reply
    replaces the ``__RTC_CONFIGURATION__`` placeholder in the page.

    Returns:
        An ``HTMLResponse`` (status 200) with the rendered page.

    Raises:
        RuntimeError: if either TURN environment variable is unset or empty,
            or if the TURN API answers with a non-success status.
    """
    turn_key_id = os.environ.get("TURN_TOKEN_ID")
    turn_key_api_token = os.environ.get("TURN_API_TOKEN")
    if not turn_key_id or not turn_key_api_token:
        # Fail fast. Otherwise the request goes out with "None" baked into the
        # URL and the bearer token and only fails on the remote side.
        raise RuntimeError(
            "Failed to get TURN credentials: "
            "TURN_TOKEN_ID and TURN_API_TOKEN must be set"
        )
    ttl = 600  # credential lifetime sent to the API (presumably seconds)
    turn_response = await client.post(
        f"https://rtc.live.cloudflare.com/v1/turn/keys/{turn_key_id}/credentials/generate-ice-servers",
        headers={
            "Authorization": f"Bearer {turn_key_api_token}",
            "Content-Type": "application/json",
        },
        json={"ttl": ttl},
    )
    if not turn_response.is_success:
        raise RuntimeError(
            f"Failed to get TURN credentials: {turn_response.status_code} {turn_response.text}"
        )
    rtc_config = turn_response.json()
    # Decode explicitly as UTF-8 rather than the platform's locale encoding.
    html_content = (curr_dir / "index.html").read_text(encoding="utf-8")
    html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
    return HTMLResponse(content=html_content, status_code=200)
async def _(file_name: str):
    """Serve one of the preloaded bark recordings by file name.

    NOTE(review): no route decorator is visible. ``response()`` links to
    ``/files/<name>.mp3``, so this presumably backs that path — confirm.

    Args:
        file_name: Requested file name, e.g. ``"dachshund-short.mp3"``.

    Returns:
        A ``FileResponse`` for ``curr_dir / "<key>.mp3"``.

    Raises:
        HTTPException: 404 unless ``file_name`` is exactly ``"<key>.mp3"``
            for some key in ``barks``.
    """
    key = file_name.removesuffix(".mp3")
    # Require the ".mp3" suffix and a known key. Build the served path from
    # the validated key alone, so directory components in file_name
    # (e.g. "../") never reach the filesystem.
    if key == file_name or key not in barks:
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(curr_dir / f"{key}.mp3")
async def _(body: InputData):
    """Hand a connection's selected breed to the stream via ``set_input``.

    NOTE(review): no route decorator is visible on this handler — confirm
    how it is registered on ``app``.

    Returns:
        ``{"status": "ok"}`` once the input has been passed on.
    """
    connection_id, chosen_breed = body.webrtc_id, body.breed
    stream.set_input(connection_id, chosen_breed)
    return {"status": "ok"}
def _(webrtc_id: str):
    """Relay a connection's chat messages to the client as Server-Sent Events.

    NOTE(review): no route decorator is visible on this handler — confirm
    how it is registered on ``app``.

    For every item produced by ``stream.output_stream(webrtc_id)``, each entry
    of its first positional argument becomes a separate ``output`` event. That
    argument is presumably the message list response() wraps in
    ``AdditionalOutputs``. The event data is the entry serialised as JSON.
    """

    async def sse_events():
        async for item in stream.output_stream(webrtc_id):
            for message in item.args[0]:
                payload = json.dumps(message)
                yield f"event: output\ndata: {payload}\n\n"

    return StreamingResponse(sse_events(), media_type="text/event-stream")
if __name__ == "__main__":
    # MODE=UI launches the Gradio UI; MODE=PHONE uses stream.fastphone; any
    # other value serves the FastAPI app with uvicorn. The module-level `os`
    # import is used here (re-importing it was redundant).
    if (mode := os.getenv("MODE")) == "UI":
        stream.ui.launch(server_port=7860)
    elif mode == "PHONE":
        stream.fastphone(host="0.0.0.0", port=7860)
    else:
        import uvicorn

        # Bind all interfaces only on a Space; stay on localhost otherwise.
        uvicorn.run(app, host="0.0.0.0" if get_space() else "127.0.0.1", port=7860)