"""Dog-bark voice chat demo.

A fastrtc audio stream transcribes the user's speech and replies with a
pre-recorded bark from the selected breed — a long bark for prompts over
ten words, a short one otherwise. A FastAPI app serves the web UI, the
bark audio files, and an SSE endpoint with the chat transcript.
"""

import json
import os
from pathlib import Path
from typing import Literal

import gradio as gr
import httpx
import numpy as np
import soundfile as sf
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse, HTMLResponse, StreamingResponse
from fastrtc import (
    AdditionalOutputs,
    ReplyOnPause,
    Stream,
    get_stt_model,
)
from gradio.utils import get_space
from pydantic import BaseModel

load_dotenv()

curr_dir = Path(__file__).parent

stt_model = get_stt_model()

conversations = {}

# NOTE(review): "chiahuahua" is a misspelling of "chihuahua", but it is part of
# the public interface (audio file names, dropdown choices, request schema) —
# renaming it would break existing clients and on-disk assets, so it stays.
BREEDS = ("chiahuahua", "dachshund", "golden-retriever")

# Pre-load every bark clip once at startup. sf.read returns (data, samplerate).
barks = {
    f"{breed}-{length}": sf.read(curr_dir / f"{breed}-{length}.mp3")
    for breed in BREEDS
    for length in ("short", "long")
}
for name, (audio_data, sample_rate) in barks.items():
    # Downmix stereo to mono by averaging channels; store as
    # (sample_rate, float32 samples), the tuple shape fastrtc emits.
    if audio_data.ndim > 1 and audio_data.shape[1] > 1:
        audio_data = np.mean(audio_data, axis=1)
    barks[name] = (sample_rate, audio_data.astype(np.float32))


def response(
    audio: tuple[int, np.ndarray],
    breed: Literal["chiahuahua", "dachshund", "golden-retriever"],
):
    """Transcribe the user's speech and reply with a pre-recorded bark.

    Args:
        audio: (sample_rate, samples) tuple from the fastrtc stream.
        breed: Which dog answers; selects the bark audio file.

    Yields:
        The bark audio tuple plus an AdditionalOutputs transcript holding the
        user's text and the URL of the served bark file.
    """
    chat = []
    prompt = stt_model.stt(audio)
    chat.append({"role": "user", "content": prompt})
    # Prompts longer than ten words earn the long bark.
    length = "long" if len(prompt.split(" ")) > 10 else "short"
    file_name = f"{breed}-{length}"
    chat.append({"role": "assistant", "content": f"/files/{file_name}.mp3"})
    yield barks[file_name], AdditionalOutputs(chat)


stream = Stream(
    modality="audio",
    mode="send-receive",
    handler=ReplyOnPause(response),
    additional_outputs_handler=lambda a, b: b,
    additional_inputs=[
        gr.Dropdown(choices=["chiahuahua", "dachshund", "golden-retriever"])
    ],
    additional_outputs=[gr.JSON()],
    rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]},
    # Rate limits only apply when hosted on a HF Space.
    concurrency_limit=20 if get_space() else None,
    time_limit=600 if get_space() else None,
)


class InputData(BaseModel):
    """Body of POST /input_hook: binds a breed choice to a WebRTC session."""

    webrtc_id: str
    breed: Literal["chiahuahua", "dachshund", "golden-retriever"]


client = httpx.AsyncClient()
app = FastAPI()
stream.mount(app)


@app.get("/")
async def index():
    """Serve the web UI with freshly generated Cloudflare TURN credentials."""
    turn_key_id = os.environ.get("TURN_TOKEN_ID")
    turn_key_api_token = os.environ.get("TURN_API_TOKEN")
    ttl = 600  # seconds the TURN credentials stay valid
    response = await client.post(
        f"https://rtc.live.cloudflare.com/v1/turn/keys/{turn_key_id}/credentials/generate-ice-servers",
        headers={
            "Authorization": f"Bearer {turn_key_api_token}",
            "Content-Type": "application/json",
        },
        json={"ttl": ttl},
    )
    if not response.is_success:
        # Surface a proper HTTP 500 instead of an unhandled exception.
        raise HTTPException(
            status_code=500,
            detail=f"Failed to get TURN credentials: {response.status_code} {response.text}",
        )
    rtc_config = response.json()
    html_content = (curr_dir / "index.html").read_text()
    html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
    return HTMLResponse(content=html_content, status_code=200)


@app.get("/files/{file_name}")
async def get_file(file_name: str):
    """Serve one of the known bark clips; anything else is a 404."""
    print("file_name", file_name)
    # Path(...).name drops any directory components, so a request such as
    # "a/../../x/chiahuahua-short.mp3" cannot escape curr_dir (path traversal).
    safe_name = Path(file_name).name
    if safe_name.removesuffix(".mp3") not in barks:
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(curr_dir / safe_name)


@app.post("/input_hook")
async def input_hook(body: InputData):
    """Attach the chosen breed to the identified WebRTC stream."""
    stream.set_input(body.webrtc_id, body.breed)
    return {"status": "ok"}


@app.get("/outputs")
def outputs(webrtc_id: str):
    """Stream the chat transcript for a session as server-sent events."""

    async def output_stream():
        async for output in stream.output_stream(webrtc_id):
            messages = output.args[0]
            for msg in messages:
                yield f"event: output\ndata: {json.dumps(msg)}\n\n"

    return StreamingResponse(output_stream(), media_type="text/event-stream")


if __name__ == "__main__":
    # MODE selects how the demo runs: Gradio UI, phone bridge, or plain API.
    if (mode := os.getenv("MODE")) == "UI":
        stream.ui.launch(server_port=7860)
    elif mode == "PHONE":
        stream.fastphone(host="0.0.0.0", port=7860)
    else:
        import uvicorn

        uvicorn.run(app, host="0.0.0.0" if get_space() else "127.0.0.1", port=7860)