import json
import os
from pathlib import Path
from typing import Literal

import gradio as gr
import httpx
import numpy as np
import soundfile as sf
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse, HTMLResponse, StreamingResponse
from fastrtc import (
    AdditionalOutputs,
    ReplyOnPause,
    Stream,
    get_stt_model,
)
from gradio.utils import get_space
from pydantic import BaseModel

# Load TURN credentials and other settings from .env
load_dotenv()
curr_dir = Path(__file__).parent

# Local speech-to-text model used to transcribe the caller
stt_model = get_stt_model()
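
# Preload every breed's short and long bark clip once at startup so the
# handler never has to read from disk mid-conversation.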
barks = {
    "chiahuahua-short": sf.read(curr_dir / "chiahuahua-short.mp3"),
    "chiahuahua-long": sf.read(curr_dir / "chiahuahua-long.mp3"),
    "dachshund-short": sf.read(curr_dir / "dachshund-short.mp3"),
    "dachshund-long": sf.read(curr_dir / "dachshund-long.mp3"),
    "golden-retriever-short": sf.read(curr_dir / "golden-retriever-short.mp3"),
    "golden-retriever-long": sf.read(curr_dir / "golden-retriever-long.mp3"),
}

# soundfile returns (data, samplerate); fastrtc expects (samplerate, data)
for k, v in barks.items():
    audio_data = v[0]
    # Downmix stereo to mono by averaging the channels
    if audio_data.ndim > 1 and audio_data.shape[1] > 1:
        audio_data = np.mean(audio_data, axis=1)
    barks[k] = (v[1], audio_data.astype(np.float32))
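
# ReplyOnPause invokes this handler with the caller's audio each time they
# stop speaking; whatever it yields is streamed back over the same connection.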
def response(
    audio: tuple[int, np.ndarray],
    breed: Literal["chiahuahua", "dachshund", "golden-retriever"],
):
    """Transcribe the user's speech and answer with a bark from the chosen breed."""
    messages = []
    prompt = stt_model.stt(audio)
    messages.append({"role": "user", "content": prompt})
    # Questions longer than 10 words get the long bark
    length = "long" if len(prompt.split(" ")) > 10 else "short"
    file_name = f"{breed}-{length}"
    messages.append(
        {
            "role": "assistant",
            "content": f"/files/{file_name}.mp3",
        }
    )
    yield barks[file_name], AdditionalOutputs(messages)
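
# A send-receive audio Stream: ReplyOnPause handles turn-taking, the dropdown
# feeds the breed into the handler, and the JSON component carries the chat
# transcript back out as an additional output.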
stream = Stream(
    modality="audio",
    mode="send-receive",
    handler=ReplyOnPause(response),
    # Keep only the latest transcript emitted by the handler
    additional_outputs_handler=lambda a, b: b,
    additional_inputs=[
        gr.Dropdown(choices=["chiahuahua", "dachshund", "golden-retriever"])
    ],
    additional_outputs=[gr.JSON()],
    rtc_configuration=None,
    concurrency_limit=20 if get_space() else None,
    time_limit=600 if get_space() else None,
)
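
# Body for /input_hook: ties a WebRTC session id to the selected breed.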
class InputData(BaseModel):
    webrtc_id: str
    breed: Literal["chiahuahua", "dachshund", "golden-retriever"]


client = httpx.AsyncClient()
app = FastAPI()
stream.mount(app)
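
# Serve the custom frontend with freshly minted Cloudflare TURN credentials
# so WebRTC can connect from behind NATs and firewalls.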
@app.get("/")
async def _():
turn_key_id = os.environ.get("TURN_TOKEN_ID")
turn_key_api_token = os.environ.get("TURN_API_TOKEN")
ttl = 600
response = await client.post(
f"https://rtc.live.cloudflare.com/v1/turn/keys/{turn_key_id}/credentials/generate-ice-servers",
headers={
"Authorization": f"Bearer {turn_key_api_token}",
"Content-Type": "application/json",
},
json={"ttl": ttl},
)
if response.is_success:
rtc_config = response.json()
else:
raise Exception(
f"Failed to get TURN credentials: {response.status_code} {response.text}"
)
html_content = (curr_dir / "index.html").read_text()
html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
return HTMLResponse(content=html_content, status_code=200)
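
# Serve the bark mp3s referenced by the assistant messages.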
@app.get("/files/{file_name}")
async def _(file_name: str):
print("file_name", file_name)
if Path(file_name).name.replace(".mp3", "") not in barks:
raise HTTPException(status_code=404, detail="File not found")
return FileResponse(curr_dir / file_name)
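
# Called by the frontend when the user picks a breed from the dropdown.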
@app.post("/input_hook")
async def input_hook(body: InputData):
    stream.set_input(body.webrtc_id, body.breed)
    return {"status": "ok"}
@app.get("/outputs")
def _(webrtc_id: str):
async def output_stream():
async for output in stream.output_stream(webrtc_id):
messages = output.args[0]
for msg in messages:
yield f"event: output\ndata: {json.dumps(msg)}\n\n"
return StreamingResponse(output_stream(), media_type="text/event-stream")
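
# MODE picks the entrypoint: "UI" launches the built-in Gradio interface,
# "PHONE" exposes the stream over a temporary phone number via fastphone,
# and anything else serves the custom FastAPI frontend with uvicorn.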
if __name__ == "__main__":
    if (mode := os.getenv("MODE")) == "UI":
        stream.ui.launch(server_port=7860)
    elif mode == "PHONE":
        stream.fastphone(host="0.0.0.0", port=7860)
    else:
        import uvicorn

        uvicorn.run(app, host="0.0.0.0" if get_space() else "127.0.0.1", port=7860)