seawolf2357 committed on
Commit
2bc8e7a
·
verified ·
1 Parent(s): d1ea807

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +141 -0
app.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import base64
3
+ import json
4
+ from pathlib import Path
5
+
6
+ import gradio as gr
7
+ import numpy as np
8
+ import openai
9
+ from dotenv import load_dotenv
10
+ from fastapi import FastAPI
11
+ from fastapi.responses import HTMLResponse, StreamingResponse
12
+ from fastrtc import (
13
+ AdditionalOutputs,
14
+ AsyncStreamHandler,
15
+ Stream,
16
+ get_twilio_turn_credentials,
17
+ wait_for_item,
18
+ )
19
+ from gradio.utils import get_space
20
+ from openai.types.beta.realtime import ResponseAudioTranscriptDoneEvent
21
+
22
+ load_dotenv()
23
+
24
+ cur_dir = Path(__file__).parent
25
+
26
+ SAMPLE_RATE = 24000
27
+
28
+
29
+ class OpenAIHandler(AsyncStreamHandler):
30
+ def __init__(
31
+ self,
32
+ ) -> None:
33
+ super().__init__(
34
+ expected_layout="mono",
35
+ output_sample_rate=SAMPLE_RATE,
36
+ output_frame_size=480,
37
+ input_sample_rate=SAMPLE_RATE,
38
+ )
39
+ self.connection = None
40
+ self.output_queue = asyncio.Queue()
41
+
42
+ def copy(self):
43
+ return OpenAIHandler()
44
+
45
+ async def start_up(
46
+ self,
47
+ ):
48
+ """Connect to realtime API. Run forever in separate thread to keep connection open."""
49
+ self.client = openai.AsyncOpenAI()
50
+ async with self.client.beta.realtime.connect(
51
+ model="gpt-4o-mini-realtime-preview-2024-12-17"
52
+ ) as conn:
53
+ await conn.session.update(
54
+ session={"turn_detection": {"type": "server_vad"}}
55
+ )
56
+ self.connection = conn
57
+ async for event in self.connection:
58
+ if event.type == "response.audio_transcript.done":
59
+ await self.output_queue.put(AdditionalOutputs(event))
60
+ if event.type == "response.audio.delta":
61
+ await self.output_queue.put(
62
+ (
63
+ self.output_sample_rate,
64
+ np.frombuffer(
65
+ base64.b64decode(event.delta), dtype=np.int16
66
+ ).reshape(1, -1),
67
+ ),
68
+ )
69
+
70
+ async def receive(self, frame: tuple[int, np.ndarray]) -> None:
71
+ if not self.connection:
72
+ return
73
+ _, array = frame
74
+ array = array.squeeze()
75
+ audio_message = base64.b64encode(array.tobytes()).decode("utf-8")
76
+ await self.connection.input_audio_buffer.append(audio=audio_message) # type: ignore
77
+
78
+ async def emit(self) -> tuple[int, np.ndarray] | AdditionalOutputs | None:
79
+ return await wait_for_item(self.output_queue)
80
+
81
+ async def shutdown(self) -> None:
82
+ if self.connection:
83
+ await self.connection.close()
84
+ self.connection = None
85
+
86
+
87
+ def update_chatbot(chatbot: list[dict], response: ResponseAudioTranscriptDoneEvent):
88
+ chatbot.append({"role": "assistant", "content": response.transcript})
89
+ return chatbot
90
+
91
+
92
+ chatbot = gr.Chatbot(type="messages")
93
+ latest_message = gr.Textbox(type="text", visible=False)
94
+ stream = Stream(
95
+ OpenAIHandler(),
96
+ mode="send-receive",
97
+ modality="audio",
98
+ additional_inputs=[chatbot],
99
+ additional_outputs=[chatbot],
100
+ additional_outputs_handler=update_chatbot,
101
+ rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
102
+ concurrency_limit=5 if get_space() else None,
103
+ time_limit=90 if get_space() else None,
104
+ )
105
+
106
+ app = FastAPI()
107
+
108
+ stream.mount(app)
109
+
110
+
111
+ @app.get("/")
112
+ async def _():
113
+ rtc_config = get_twilio_turn_credentials() if get_space() else None
114
+ html_content = (cur_dir / "index.html").read_text()
115
+ html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
116
+ return HTMLResponse(content=html_content)
117
+
118
+
119
+ @app.get("/outputs")
120
+ def _(webrtc_id: str):
121
+ async def output_stream():
122
+ import json
123
+
124
+ async for output in stream.output_stream(webrtc_id):
125
+ s = json.dumps({"role": "assistant", "content": output.args[0].transcript})
126
+ yield f"event: output\ndata: {s}\n\n"
127
+
128
+ return StreamingResponse(output_stream(), media_type="text/event-stream")
129
+
130
+
131
+ if __name__ == "__main__":
132
+ import os
133
+
134
+ if (mode := os.getenv("MODE")) == "UI":
135
+ stream.ui.launch(server_port=7860)
136
+ elif mode == "PHONE":
137
+ stream.fastphone(host="0.0.0.0", port=7860)
138
+ else:
139
+ import uvicorn
140
+
141
+ uvicorn.run(app, host="0.0.0.0", port=7860)