# Hugging Face Space by freddyaboulton (commit afd6773, uploaded via huggingface_hub).
import json
import os
from pathlib import Path
from typing import Literal
import gradio as gr
import numpy as np
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse, StreamingResponse, FileResponse
from fastrtc import (
AdditionalOutputs,
ReplyOnPause,
Stream,
get_stt_model,
)
import soundfile as sf
from gradio.utils import get_space
from pydantic import BaseModel
import httpx
load_dotenv()  # pull TURN_TOKEN_ID / TURN_API_TOKEN etc. from a local .env file
# Directory containing this script; the mp3 clips and index.html live alongside it.
curr_dir = Path(__file__).parent
# Speech-to-text model used to transcribe the caller's audio.
stt_model = get_stt_model()
# NOTE(review): appears unused anywhere in this file — confirm before removing.
conversations = {}
def _read_mono(path):
    """Load an audio file and return a (sample_rate, float32 mono samples) tuple."""
    samples, rate = sf.read(path)
    # Average the channels when the clip is stereo; mono clips pass through.
    if samples.ndim > 1 and samples.shape[1] > 1:
        samples = samples.mean(axis=1)
    return rate, samples.astype(np.float32)


# Pre-loaded bark clips keyed by "<breed>-<length>", in the
# (sample_rate, samples) layout fastrtc expects for audio replies.
barks = {
    clip: _read_mono(curr_dir / f"{clip}.mp3")
    for clip in (
        "chiahuahua-short",
        "chiahuahua-long",
        "dachshund-short",
        "dachshund-long",
        "golden-retriever-short",
        "golden-retriever-long",
    )
}
def response(
    audio: tuple[int, np.ndarray],
    breed: Literal["chiahuahua", "dachshund", "golden-retriever"],
):
    """Transcribe the caller's speech and reply with a breed-appropriate bark.

    Args:
        audio: (sample_rate, samples) tuple from the WebRTC stream.
        breed: dog breed selected in the UI dropdown.

    Yields:
        A ((sample_rate, samples), AdditionalOutputs) pair: the bark clip to
        play plus the two-message chat history (user transcript + bark URL).
    """
    # Renamed from `response`, which shadowed this function's own name.
    messages = []
    prompt = stt_model.stt(audio)
    messages.append({"role": "user", "content": prompt})
    # Utterances longer than 10 words get the long bark clip.
    length = "long" if len(prompt.split(" ")) > 10 else "short"
    file_name = f"{breed}-{length}"
    messages.append(
        {
            "role": "assistant",
            # URL served by the /files/{file_name} route below.
            "content": f"/files/{file_name}.mp3",
        }
    )
    audio_ = barks[file_name]
    yield audio_, AdditionalOutputs(messages)
# WebRTC audio stream: whenever the caller pauses, `response` produces a bark
# clip plus the chat-history JSON shown in the UI.
stream = Stream(
    modality="audio",
    mode="send-receive",
    handler=ReplyOnPause(response),
    # Keep only the newest additional output (the message list from `response`).
    additional_outputs_handler=lambda a, b: b,
    additional_inputs=[
        gr.Dropdown(choices=["chiahuahua", "dachshund", "golden-retriever"])
    ],
    additional_outputs=[gr.JSON()],
    # Public Google STUN server; the "/" route separately fetches TURN creds.
    rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]},
    # Resource caps only when running on Hugging Face Spaces.
    concurrency_limit=20 if get_space() else None,
    time_limit=600 if get_space() else None,
)
class InputData(BaseModel):
    """Payload for POST /input_hook: binds a breed choice to a WebRTC session."""

    webrtc_id: str  # identifies the active WebRTC connection on the stream
    breed: Literal["chiahuahua", "dachshund", "golden-retriever"]
# Shared async HTTP client for the Cloudflare TURN credential API.
client = httpx.AsyncClient()
app = FastAPI()
# Mount the fastrtc signalling/WebRTC endpoints onto the FastAPI app.
stream.mount(app)
@app.get("/")
async def _():
turn_key_id = os.environ.get("TURN_TOKEN_ID")
turn_key_api_token = os.environ.get("TURN_API_TOKEN")
ttl = 600
response = await client.post(
f"https://rtc.live.cloudflare.com/v1/turn/keys/{turn_key_id}/credentials/generate-ice-servers",
headers={
"Authorization": f"Bearer {turn_key_api_token}",
"Content-Type": "application/json",
},
json={"ttl": ttl},
)
if response.is_success:
rtc_config = response.json()
else:
raise Exception(
f"Failed to get TURN credentials: {response.status_code} {response.text}"
)
html_content = (curr_dir / "index.html").read_text()
html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
return HTMLResponse(content=html_content, status_code=200)
@app.get("/files/{file_name}")
async def _(file_name: str):
print("file_name", file_name)
if Path(file_name).name.replace(".mp3", "") not in barks:
raise HTTPException(status_code=404, detail="File not found")
return FileResponse(curr_dir / file_name)
@app.post("/input_hook")
async def _(body: InputData):
stream.set_input(body.webrtc_id, body.breed)
return {"status": "ok"}
@app.get("/outputs")
def _(webrtc_id: str):
async def output_stream():
async for output in stream.output_stream(webrtc_id):
messages = output.args[0]
for msg in messages:
yield f"event: output\ndata: {json.dumps(msg)}\n\n"
return StreamingResponse(output_stream(), media_type="text/event-stream")
if __name__ == "__main__":
import os
if (mode := os.getenv("MODE")) == "UI":
stream.ui.launch(server_port=7860)
elif mode == "PHONE":
stream.fastphone(host="0.0.0.0", port=7860)
else:
import uvicorn
uvicorn.run(app, host="0.0.0.0" if get_space() else "127.0.0.1", port=7860)