File size: 4,437 Bytes
af5779f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d8d346
 
 
af5779f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import json
import os
from pathlib import Path
from typing import Literal
import gradio as gr
import numpy as np
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse, StreamingResponse, FileResponse
from fastrtc import (
    AdditionalOutputs,
    ReplyOnPause,
    Stream,
    get_stt_model,
)
import soundfile as sf
from gradio.utils import get_space
from pydantic import BaseModel
import httpx

# Load environment variables via python-dotenv before anything below reads
# os.environ (the index route reads TURN_TOKEN_ID / TURN_API_TOKEN).
load_dotenv()


# Directory containing this module; the bark clips and index.html are
# resolved relative to it.
curr_dir = Path(__file__).parent

# Speech-to-text model used by `response` to transcribe incoming audio.
stt_model = get_stt_model()

# NOTE(review): not referenced anywhere else in the visible part of this
# file — possibly a leftover; confirm before removing.
conversations = {}

def _load_bark(clip_name: str) -> tuple[int, np.ndarray]:
    # Read one clip that sits next to this module and return it in the
    # (sample_rate, samples) layout stored in `barks`.
    samples, sample_rate = sf.read(curr_dir / f"{clip_name}.mp3")
    # Clips with more than one channel are down-mixed by averaging the
    # channels; everything else is used unchanged.
    if samples.ndim > 1 and samples.shape[1] > 1:
        samples = samples.mean(axis=1)
    return sample_rate, samples.astype(np.float32)


# All bark clips are loaded once at import time, keyed by "<breed>-<length>".
barks = {
    clip_name: _load_bark(clip_name)
    for clip_name in (
        "chiahuahua-short",
        "chiahuahua-long",
        "dachshund-short",
        "dachshund-long",
        "golden-retriever-short",
        "golden-retriever-long",
    )
}


def response(
    audio: tuple[int, np.ndarray],
    breed: Literal["chiahuahua", "dachshund", "golden-retriever"],
):
    """Transcribe the caller's speech and reply with a bark of the chosen breed.

    The transcript decides the clip length: more than 10 words selects the
    "long" clip, otherwise the "short" one.

    Args:
        audio: Audio tuple passed straight to ``stt_model.stt`` (presumably
            ``(sample_rate, samples)`` — confirm against fastrtc).
        breed: Breed whose pre-loaded clip is played back.

    Yields:
        One ``(bark, AdditionalOutputs(messages))`` pair. ``bark`` is the
        entry from ``barks``; ``messages`` holds the user transcript followed
        by an assistant entry whose content is the clip's ``/files/`` path.

    Raises:
        KeyError: If ``barks`` has no clip named ``"{breed}-{length}"``.
    """
    prompt = stt_model.stt(audio)
    # Count real words: split() with no argument collapses runs of whitespace
    # and ignores leading/trailing blanks, whereas split(" ") produces empty
    # strings that inflate the count and can wrongly select the long clip.
    length = "long" if len(prompt.split()) > 10 else "short"
    file_name = f"{breed}-{length}"
    messages = [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": f"/files/{file_name}.mp3"},
    ]
    yield barks[file_name], AdditionalOutputs(messages)


# Audio send/receive stream whose handler is `response` wrapped in
# ReplyOnPause. The breed dropdown is the extra input passed to the handler;
# the JSON component receives what the handler emits via AdditionalOutputs.
stream = Stream(
    handler=ReplyOnPause(response),
    modality="audio",
    mode="send-receive",
    additional_inputs=[
        gr.Dropdown(choices=["chiahuahua", "dachshund", "golden-retriever"])
    ],
    additional_outputs=[gr.JSON()],
    # Returns its second argument unchanged, ignoring the first.
    additional_outputs_handler=lambda previous, latest: latest,
    rtc_configuration=None,
    # Limits apply only when get_space() is truthy; otherwise they stay unset.
    concurrency_limit=None if not get_space() else 20,
    time_limit=None if not get_space() else 600,
)


# Request body accepted by the POST /input_hook endpoint.
class InputData(BaseModel):
    # Identifier of the WebRTC connection whose input is being set.
    webrtc_id: str
    # Selected breed; limited to the same three values the Dropdown offers.
    breed: Literal["chiahuahua", "dachshund", "golden-retriever"]


# Module-level async HTTP client; the index route uses it to request TURN
# credentials from Cloudflare.
client = httpx.AsyncClient()
app = FastAPI()
# Attach the stream to the FastAPI app (presumably registers fastrtc's own
# endpoints on it — see the fastrtc documentation for the exact routes).
stream.mount(app)


@app.get("/")
async def _():
    """Serve ``index.html`` with freshly generated TURN credentials injected.

    Requests ICE-server credentials from Cloudflare using the
    ``TURN_TOKEN_ID`` / ``TURN_API_TOKEN`` environment variables, then replaces
    the ``__RTC_CONFIGURATION__`` placeholder in ``index.html`` with the JSON
    returned by that call.

    Returns:
        HTMLResponse: The rendered page with status code 200.

    Raises:
        RuntimeError: If either environment variable is missing or empty, or
            if Cloudflare answers with a non-success status.
    """
    turn_key_id = os.environ.get("TURN_TOKEN_ID")
    turn_key_api_token = os.environ.get("TURN_API_TOKEN")
    if not turn_key_id or not turn_key_api_token:
        # Fail fast with a clear message rather than calling the API with
        # "None" interpolated into the URL and the Authorization header.
        raise RuntimeError(
            "TURN_TOKEN_ID and TURN_API_TOKEN must be set to generate TURN credentials"
        )
    ttl = 600  # credential lifetime sent to Cloudflare (presumably seconds)

    turn_response = await client.post(
        f"https://rtc.live.cloudflare.com/v1/turn/keys/{turn_key_id}/credentials/generate-ice-servers",
        headers={
            "Authorization": f"Bearer {turn_key_api_token}",
            "Content-Type": "application/json",
        },
        json={"ttl": ttl},
    )
    if not turn_response.is_success:
        raise RuntimeError(
            f"Failed to get TURN credentials: {turn_response.status_code} {turn_response.text}"
        )
    rtc_config = turn_response.json()

    # Decode explicitly as UTF-8 so the page does not depend on the host's
    # locale default encoding.
    html_content = (curr_dir / "index.html").read_text(encoding="utf-8")
    html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
    return HTMLResponse(content=html_content, status_code=200)


@app.get("/files/{file_name}")
async def _(file_name: str):
    """Serve one of the pre-loaded bark clips by file name.

    Args:
        file_name: Requested name; must be ``"<key>.mp3"`` where ``<key>`` is
            a key of ``barks``.

    Returns:
        FileResponse: The matching ``.mp3`` file located next to this module.

    Raises:
        HTTPException: 404 if the name does not end in ``.mp3`` or its stem
            is not a known bark.
    """
    print("file_name", file_name)
    requested = Path(file_name)
    # Check suffix and stem separately: str.replace(".mp3", "") strips the
    # substring anywhere in the name, so inputs such as "chiahuahua-short" or
    # "chiahuahua.mp3-short" would pass the whitelist yet point at a file that
    # does not exist.
    if requested.suffix != ".mp3" or requested.stem not in barks:
        raise HTTPException(status_code=404, detail="File not found")
    # Build the path from the validated key, never from the raw input.
    return FileResponse(curr_dir / f"{requested.stem}.mp3")


@app.post("/input_hook")
async def _(body: InputData):
    # Hand the breed chosen by the client to the stream, keyed by the
    # connection's webrtc_id, then acknowledge.
    webrtc_id, breed = body.webrtc_id, body.breed
    stream.set_input(webrtc_id, breed)
    return dict(status="ok")


@app.get("/outputs")
def _(webrtc_id: str):
    # Streams the handler's additional outputs for one connection as
    # server-sent events, one "output" event per message.
    async def sse_events():
        async for update in stream.output_stream(webrtc_id):
            # The first positional argument of each update is the message list.
            for message in update.args[0]:
                payload = json.dumps(message)
                yield f"event: output\ndata: {payload}\n\n"

    return StreamingResponse(sse_events(), media_type="text/event-stream")


if __name__ == "__main__":
    import os

    # MODE picks the front end: "UI" launches the stream's Gradio UI, "PHONE"
    # calls fastphone, and anything else serves the FastAPI app with uvicorn.
    mode = os.getenv("MODE")
    port = 7860
    if mode == "UI":
        stream.ui.launch(server_port=port)
    elif mode == "PHONE":
        stream.fastphone(host="0.0.0.0", port=port)
    else:
        import uvicorn

        # Bind to all interfaces only when get_space() is truthy; otherwise
        # stay on loopback.
        host = "0.0.0.0" if get_space() else "127.0.0.1"
        uvicorn.run(app, host=host, port=port)