import json
import os
from pathlib import Path
from typing import Literal
import gradio as gr
import numpy as np
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse, StreamingResponse, FileResponse
from fastrtc import (
AdditionalOutputs,
ReplyOnPause,
Stream,
get_stt_model,
)
import soundfile as sf
from gradio.utils import get_space
from pydantic import BaseModel
import httpx
load_dotenv()
curr_dir = Path(__file__).parent
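# Local speech-to-text model provided by fastrtc's get_stt_model(); used to transcribe the caller's audio.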
stt_model = get_stt_model()
conversations = {}
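# Pre-load a short and a long bark clip for each breed; sf.read returns (samples, sample_rate).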
barks = {
"chiahuahua-short": sf.read(curr_dir / "chiahuahua-short.mp3"),
"chiahuahua-long": sf.read(curr_dir / "chiahuahua-long.mp3"),
"dachshund-short": sf.read(curr_dir / "dachshund-short.mp3"),
"dachshund-long": sf.read(curr_dir / "dachshund-long.mp3"),
"golden-retriever-short": sf.read(curr_dir / "golden-retriever-short.mp3"),
"golden-retriever-long": sf.read(curr_dir / "golden-retriever-long.mp3"),
}
for k, v in barks.items():
    # sf.read returns (data, samplerate); average the channels to mono if the clip is stereo,
    # otherwise use the data as is.
    audio_data = v[0]
    if len(audio_data.shape) > 1 and audio_data.shape[1] > 1:
        audio_data = np.mean(audio_data, axis=1)
    # Store as (sample_rate, samples), the (int, np.ndarray) tuple the handler yields back to fastrtc.
    barks[k] = (v[1], audio_data.astype(np.float32))
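# ReplyOnPause handler: transcribe the caller's speech, then reply with a pre-recorded bark.
# Utterances longer than 10 words get the "long" clip, shorter ones the "short" clip, and the
# running chat transcript is emitted via AdditionalOutputs so the frontend can render it.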
def response(
audio: tuple[int, np.ndarray],
breed: Literal["chiahuahua", "dachshund", "golden-retriever"],
):
response = []
prompt = stt_model.stt(audio)
response.append({"role": "user", "content": prompt})
length = "long" if len(prompt.split(" ")) > 10 else "short"
file_name = f"{breed}-{length}"
response.append(
{
"role": "assistant",
"content": f"/files/{file_name}.mp3",
}
)
audio_ = barks[file_name]
yield audio_, AdditionalOutputs(response)
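# Audio send-receive stream. ReplyOnPause invokes the handler whenever the caller pauses;
# the breed dropdown is forwarded as an additional input and the transcript JSON is exposed
# as an additional output.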
stream = Stream(
modality="audio",
mode="send-receive",
handler=ReplyOnPause(response),
additional_outputs_handler=lambda a, b: b,
additional_inputs=[
gr.Dropdown(choices=["chiahuahua", "dachshund", "golden-retriever"])
],
additional_outputs=[gr.JSON()],
rtc_configuration=None,
concurrency_limit=20 if get_space() else None,
time_limit=600 if get_space() else None,
)
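# Payload the frontend posts to /input_hook to tie the selected breed to a WebRTC connection.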
class InputData(BaseModel):
webrtc_id: str
breed: Literal["chiahuahua", "dachshund", "golden-retriever"]
client = httpx.AsyncClient()
app = FastAPI()
stream.mount(app)
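# Index page: request short-lived TURN credentials from Cloudflare and inject them into
# index.html so the browser can negotiate the WebRTC connection.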
@app.get("/")
async def _():
turn_key_id = os.environ.get("TURN_TOKEN_ID")
turn_key_api_token = os.environ.get("TURN_API_TOKEN")
ttl = 600
response = await client.post(
f"https://rtc.live.cloudflare.com/v1/turn/keys/{turn_key_id}/credentials/generate-ice-servers",
headers={
"Authorization": f"Bearer {turn_key_api_token}",
"Content-Type": "application/json",
},
json={"ttl": ttl},
)
if response.is_success:
rtc_config = response.json()
else:
raise Exception(
f"Failed to get TURN credentials: {response.status_code} {response.text}"
)
html_content = (curr_dir / "index.html").read_text()
html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
return HTMLResponse(content=html_content, status_code=200)
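# Serve the pre-recorded bark clips referenced by the assistant messages in the transcript.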
@app.get("/files/{file_name}")
async def _(file_name: str):
print("file_name", file_name)
if Path(file_name).name.replace(".mp3", "") not in barks:
raise HTTPException(status_code=404, detail="File not found")
return FileResponse(curr_dir / file_name)
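# The frontend posts the webrtc_id and chosen breed here; set_input forwards the breed to the
# handler for that connection.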
@app.post("/input_hook")
async def _(body: InputData):
stream.set_input(body.webrtc_id, body.breed)
return {"status": "ok"}
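# Server-Sent Events endpoint that streams each transcript message (the AdditionalOutputs
# emitted by the handler) to the browser for the given webrtc_id.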
@app.get("/outputs")
def _(webrtc_id: str):
async def output_stream():
async for output in stream.output_stream(webrtc_id):
messages = output.args[0]
for msg in messages:
yield f"event: output\ndata: {json.dumps(msg)}\n\n"
return StreamingResponse(output_stream(), media_type="text/event-stream")
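# Entry points: MODE=UI launches the built-in Gradio UI, MODE=PHONE serves the stream through
# fastrtc's fastphone() for phone-call access, otherwise run the FastAPI app with uvicorn.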
if __name__ == "__main__":
if (mode := os.getenv("MODE")) == "UI":
stream.ui.launch(server_port=7860)
elif mode == "PHONE":
stream.fastphone(host="0.0.0.0", port=7860)
else:
import uvicorn
uvicorn.run(app, host="0.0.0.0" if get_space() else "127.0.0.1", port=7860)