freddyaboulton HF Staff commited on
Commit
79dcba2
·
verified ·
1 Parent(s): 8e4bcce

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. README_gradio.md +15 -0
  2. app.py +182 -0
  3. index.html +298 -0
  4. requirements.txt +4 -0
README_gradio.md ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Talk to Gemini (Gradio UI)
3
+ emoji: ♊️
4
+ colorFrom: purple
5
+ colorTo: red
6
+ sdk: gradio
7
+ sdk_version: 5.16.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ short_description: Talk to Gemini (Gradio UI)
12
+ tags: [webrtc, websocket, gradio, secret|TWILIO_ACCOUNT_SID, secret|TWILIO_AUTH_TOKEN, secret|GEMINI_API_KEY]
13
+ ---
14
+
15
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import base64
3
+ import json
4
+ import os
5
+ import pathlib
6
+ from typing import AsyncGenerator, Literal
7
+
8
+ import gradio as gr
9
+ import numpy as np
10
+ from dotenv import load_dotenv
11
+ from fastapi import FastAPI
12
+ from fastapi.responses import HTMLResponse
13
+ from fastrtc import (
14
+ AsyncStreamHandler,
15
+ Stream,
16
+ async_aggregate_bytes_to_16bit,
17
+ get_twilio_turn_credentials,
18
+ )
19
+ from google import genai
20
+ from google.genai.types import (
21
+ LiveConnectConfig,
22
+ PrebuiltVoiceConfig,
23
+ SpeechConfig,
24
+ VoiceConfig,
25
+ )
26
+ from gradio.utils import get_space
27
+ from pydantic import BaseModel
28
+
29
+ current_dir = pathlib.Path(__file__).parent
30
+
31
+ load_dotenv()
32
+
33
+
34
def encode_audio(data: np.ndarray) -> str:
    """Base64-encode raw audio samples for transmission to the server."""
    raw_bytes = data.tobytes()
    return base64.b64encode(raw_bytes).decode("UTF-8")
37
+
38
+
39
class GeminiHandler(AsyncStreamHandler):
    """fastrtc stream handler that proxies audio to/from the Gemini Live API.

    Incoming microphone frames are base64-encoded and pushed onto an input
    queue; a background task streams them to Gemini and queues the decoded
    audio replies, which ``emit`` hands back to the WebRTC track.
    """

    def __init__(
        self,
        expected_layout: Literal["mono"] = "mono",
        output_sample_rate: int = 24000,
        output_frame_size: int = 480,
    ) -> None:
        super().__init__(
            expected_layout,
            output_sample_rate,
            output_frame_size,
            # Microphone input is resampled to 16 kHz before reaching receive().
            input_sample_rate=16000,
        )
        # Base64-encoded microphone chunks awaiting upload to Gemini.
        self.input_queue: asyncio.Queue = asyncio.Queue()
        # Decoded audio arrays from Gemini awaiting playback via emit().
        self.output_queue: asyncio.Queue = asyncio.Queue()
        # Set in shutdown() to stop the upload generator in stream().
        self.quit: asyncio.Event = asyncio.Event()

    def copy(self) -> "GeminiHandler":
        # A fresh instance gets fresh queues/events; presumably fastrtc calls
        # this to clone the handler per connection — TODO confirm.
        return GeminiHandler(
            expected_layout="mono",
            output_sample_rate=self.output_sample_rate,
            output_frame_size=self.output_frame_size,
        )

    async def stream(self) -> AsyncGenerator[bytes, None]:
        """Yield queued microphone chunks until shutdown is requested."""
        while not self.quit.is_set():
            audio = await self.input_queue.get()
            yield audio
        return

    async def connect(
        self, api_key: str | None = None, voice_name: str | None = "Kore"
    ) -> AsyncGenerator[bytes, None]:
        """Connect to the genai server and start the stream.

        Falls back to the GEMINI_API_KEY environment variable when no key is
        supplied; yields the raw audio bytes produced by the model.
        """
        client = genai.Client(
            api_key=api_key or os.getenv("GEMINI_API_KEY"),
            http_options={"api_version": "v1alpha"},
        )
        config = LiveConnectConfig(
            response_modalities=["AUDIO"],  # type: ignore
            speech_config=SpeechConfig(
                voice_config=VoiceConfig(
                    prebuilt_voice_config=PrebuiltVoiceConfig(
                        voice_name=voice_name,
                    )
                )
            ),
        )
        async with client.aio.live.connect(
            model="gemini-2.0-flash-exp", config=config
        ) as session:
            # Pipe our microphone generator up and yield audio replies down.
            async for audio in session.start_stream(
                stream=self.stream(), mime_type="audio/pcm"
            ):
                if audio.data:
                    yield audio.data

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Queue one (sample_rate, samples) microphone frame for upload."""
        _, array = frame
        array = array.squeeze()
        audio_message = encode_audio(array)
        self.input_queue.put_nowait(audio_message)

    async def generator(self) -> None:
        # Presumably latest_args[0] is the webrtc_id and the remainder are the
        # UI inputs (api_key, voice_name) — TODO confirm against fastrtc docs.
        async for audio_response in async_aggregate_bytes_to_16bit(
            self.connect(*self.latest_args[1:])
        ):
            self.output_queue.put_nowait(audio_response)

    async def emit(self) -> tuple[int, np.ndarray]:
        """Return the next chunk of model audio for playback.

        Lazily starts the Gemini streaming task the first time the client's
        inputs (API key, voice) become available.
        """
        if not self.args_set.is_set():
            await self.wait_for_args()
            asyncio.create_task(self.generator())

        array = await self.output_queue.get()
        return (self.output_sample_rate, array)

    def shutdown(self) -> None:
        """Signal the upload loop to stop and reset flags for reuse."""
        self.quit.set()
        self.args_set.clear()
        # NOTE(review): quit is cleared immediately after being set, and
        # stream() only checks it between queue items — confirm this reliably
        # terminates the upload loop rather than leaving it awaiting input.
        self.quit.clear()
122
+
123
+
124
# Wire the handler into a fastrtc Stream. On Spaces, Twilio TURN credentials
# are fetched for NAT traversal and concurrency is capped; locally neither
# applies. The two additional inputs (API key, voice) are forwarded to
# GeminiHandler.connect via latest_args.
stream = Stream(
    modality="audio",
    mode="send-receive",
    handler=GeminiHandler(),
    rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
    concurrency_limit=20 if get_space() else None,
    additional_inputs=[
        gr.Textbox(label="API Key", type="password", value=os.getenv("GEMINI_API_KEY")),
        gr.Dropdown(
            label="Voice",
            choices=[
                "Puck",
                "Charon",
                "Kore",
                "Fenrir",
                "Aoede",
            ],
            value="Puck",
        ),
    ],
)
145
+
146
+
147
class InputData(BaseModel):
    """Payload posted by the browser to /input_hook."""

    # Random id generated client-side that identifies the WebRTC connection.
    webrtc_id: str
    # Prebuilt Gemini voice name selected in the UI (e.g. "Puck", "Kore").
    voice_name: str
    # Gemini API key entered by the user.
    api_key: str
151
+
152
+
153
app = FastAPI()

# Mount the fastrtc endpoints (e.g. /webrtc/offer) onto the FastAPI app.
stream.mount(app)
156
+
157
+
158
+ @app.post("/input_hook")
159
+ async def _(body: InputData):
160
+ stream.set_input(body.webrtc_id, body.api_key, body.voice_name)
161
+ return {"status": "ok"}
162
+
163
+
164
+ @app.get("/")
165
+ async def index():
166
+ rtc_config = get_twilio_turn_credentials() if get_space() else None
167
+ html_content = (current_dir / "index.html").read_text()
168
+ html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
169
+ return HTMLResponse(content=html_content)
170
+
171
+
172
+ if __name__ == "__main__":
173
+ import os
174
+
175
+ if (mode := os.getenv("MODE")) == "UI":
176
+ stream.ui.launch(server_port=7860, server_name="0.0.0.0")
177
+ elif mode == "PHONE":
178
+ stream.fastphone(host="0.0.0.0", port=7860)
179
+ else:
180
+ import uvicorn
181
+
182
+ uvicorn.run(app, host="0.0.0.0", port=7860)
index.html ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+
4
+ <head>
5
+ <meta charset="UTF-8">
6
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
7
+ <title>Gemini Voice Chat</title>
8
+ <style>
9
+ :root {
10
+ --color-accent: #6366f1;
11
+ --color-background: #0f172a;
12
+ --color-surface: #1e293b;
13
+ --color-text: #e2e8f0;
14
+ --boxSize: 8px;
15
+ --gutter: 4px;
16
+ }
17
+
18
+ body {
19
+ margin: 0;
20
+ padding: 0;
21
+ background-color: var(--color-background);
22
+ color: var(--color-text);
23
+ font-family: system-ui, -apple-system, sans-serif;
24
+ min-height: 100vh;
25
+ display: flex;
26
+ flex-direction: column;
27
+ align-items: center;
28
+ justify-content: center;
29
+ }
30
+
31
+ .container {
32
+ width: 90%;
33
+ max-width: 800px;
34
+ background-color: var(--color-surface);
35
+ padding: 2rem;
36
+ border-radius: 1rem;
37
+ box-shadow: 0 25px 50px -12px rgba(0, 0, 0, 0.25);
38
+ }
39
+
40
+ .wave-container {
41
+ position: relative;
42
+ display: flex;
43
+ min-height: 100px;
44
+ max-height: 128px;
45
+ justify-content: center;
46
+ align-items: center;
47
+ margin: 2rem 0;
48
+ }
49
+
50
+ .box-container {
51
+ display: flex;
52
+ justify-content: space-between;
53
+ height: 64px;
54
+ width: 100%;
55
+ }
56
+
57
+ .box {
58
+ height: 100%;
59
+ width: var(--boxSize);
60
+ background: var(--color-accent);
61
+ border-radius: 8px;
62
+ transition: transform 0.05s ease;
63
+ }
64
+
65
+ .controls {
66
+ display: grid;
67
+ gap: 1rem;
68
+ margin-bottom: 2rem;
69
+ }
70
+
71
+ .input-group {
72
+ display: flex;
73
+ flex-direction: column;
74
+ gap: 0.5rem;
75
+ }
76
+
77
+ label {
78
+ font-size: 0.875rem;
79
+ font-weight: 500;
80
+ }
81
+
82
+ input,
83
+ select {
84
+ padding: 0.75rem;
85
+ border-radius: 0.5rem;
86
+ border: 1px solid rgba(255, 255, 255, 0.1);
87
+ background-color: var(--color-background);
88
+ color: var(--color-text);
89
+ font-size: 1rem;
90
+ }
91
+
92
+ button {
93
+ padding: 1rem 2rem;
94
+ border-radius: 0.5rem;
95
+ border: none;
96
+ background-color: var(--color-accent);
97
+ color: white;
98
+ font-weight: 600;
99
+ cursor: pointer;
100
+ transition: all 0.2s ease;
101
+ }
102
+
103
+ button:hover {
104
+ opacity: 0.9;
105
+ transform: translateY(-1px);
106
+ }
107
+ </style>
108
+ </head>
109
+
110
+
111
+ <body>
112
+ <div style="text-align: center">
113
+ <h1>Gemini Voice Chat</h1>
114
+ <p>Speak with Gemini using real-time audio streaming</p>
115
+ <p>
116
+ Get a Gemini API key
117
+ <a href="https://ai.google.dev/gemini-api/docs/api-key">here</a>
118
+ </p>
119
+ </div>
120
+ <div class="container">
121
+ <div class="controls">
122
+ <div class="input-group">
123
+ <label for="api-key">API Key</label>
124
+ <input type="password" id="api-key" placeholder="Enter your API key">
125
+ </div>
126
+ <div class="input-group">
127
+ <label for="voice">Voice</label>
128
+ <select id="voice">
129
+ <option value="Puck">Puck</option>
130
+ <option value="Charon">Charon</option>
131
+ <option value="Kore">Kore</option>
132
+ <option value="Fenrir">Fenrir</option>
133
+ <option value="Aoede">Aoede</option>
134
+ </select>
135
+ </div>
136
+ </div>
137
+
138
+ <div class="wave-container">
139
+ <div class="box-container">
140
+ <!-- Boxes will be dynamically added here -->
141
+ </div>
142
+ </div>
143
+
144
+ <button id="start-button">Start Recording</button>
145
+ </div>
146
+
147
+ <audio id="audio-output"></audio>
148
+
149
+ <script>
150
// Connection and visualization state, assigned during setupWebRTC().
let peerConnection;
let audioContext;
let dataChannel;
let isRecording = false;
let webrtc_id; // random id identifying this connection to the server

// Cached DOM references.
const startButton = document.getElementById('start-button');
const apiKeyInput = document.getElementById('api-key');
const voiceSelect = document.getElementById('voice');
const audioOutput = document.getElementById('audio-output');
const boxContainer = document.querySelector('.box-container');

// Build the equalizer bars used to visualize the output audio level.
const numBars = 32;
for (let i = 0; i < numBars; i++) {
    const box = document.createElement('div');
    box.className = 'box';
    boxContainer.appendChild(box);
}
168
+
169
// Establish a WebRTC session with the server: capture the microphone,
// negotiate SDP via /webrtc/offer, and wire up playback plus audio-level
// visualization for the track the server sends back.
async function setupWebRTC() {
    const config = __RTC_CONFIGURATION__; // injected server-side (TURN creds or null)
    peerConnection = new RTCPeerConnection(config);
    webrtc_id = Math.random().toString(36).substring(7);

    try {
        const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
        stream.getTracks().forEach(track => peerConnection.addTrack(track, stream));

        // NOTE(review): analyser and dataArray are implicit globals, and this
        // first AudioContext is never connected to a source before being
        // replaced in the 'track' handler below — confirm it is needed.
        audioContext = new AudioContext();
        analyser = audioContext.createAnalyser();
        analyser.fftSize = 64;
        analyser.smoothingTimeConstant = 0.8;
        dataArray = new Uint8Array(analyser.frequencyBinCount);

        // Handle incoming audio
        peerConnection.addEventListener('track', (evt) => {
            if (audioOutput && audioOutput.srcObject !== evt.streams[0]) {
                audioOutput.srcObject = evt.streams[0];
                audioOutput.play();

                // Set up audio visualization on the output stream
                audioContext = new AudioContext();
                analyser = audioContext.createAnalyser();
                const source = audioContext.createMediaStreamSource(evt.streams[0]);
                source.connect(analyser);
                analyser.fftSize = 2048;
                dataArray = new Uint8Array(analyser.frequencyBinCount);
                updateVisualization();
            }
        });

        // Create data channel for messages
        dataChannel = peerConnection.createDataChannel('text');
        dataChannel.onmessage = handleMessage;

        // Create and send offer
        const offer = await peerConnection.createOffer();
        await peerConnection.setLocalDescription(offer);

        // Wait for ICE gathering to finish so the offer carries all candidates.
        await new Promise((resolve) => {
            if (peerConnection.iceGatheringState === "complete") {
                resolve();
            } else {
                const checkState = () => {
                    if (peerConnection.iceGatheringState === "complete") {
                        peerConnection.removeEventListener("icegatheringstatechange", checkState);
                        resolve();
                    }
                };
                peerConnection.addEventListener("icegatheringstatechange", checkState);
            }
        });

        // POST the offer to the fastrtc endpoint and apply the answer.
        const response = await fetch('/webrtc/offer', {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify({
                sdp: peerConnection.localDescription.sdp,
                type: peerConnection.localDescription.type,
                webrtc_id: webrtc_id,
            })
        });

        const serverResponse = await response.json();
        await peerConnection.setRemoteDescription(serverResponse);
    } catch (err) {
        console.error('Error setting up WebRTC:', err);
    }
}
239
+
240
// Handle control messages arriving on the WebRTC data channel. When the
// server requests inputs, forward the API key and voice selection (keyed
// by this connection's webrtc_id) to /input_hook.
function handleMessage(event) {
    const message = JSON.parse(event.data);
    if (message.type !== "send_input") {
        return;
    }
    const payload = {
        webrtc_id: webrtc_id,
        api_key: apiKeyInput.value,
        voice_name: voiceSelect.value
    };
    fetch('/input_hook', {
        method: 'POST',
        headers: {
            'Content-Type': 'application/json',
        },
        body: JSON.stringify(payload)
    });
}
256
+
257
// Animate the equalizer bars from the analyser's current frequency data,
// rescheduling itself every frame via requestAnimationFrame.
function updateVisualization() {
    if (!analyser) return;

    analyser.getByteFrequencyData(dataArray);
    const bars = document.querySelectorAll('.box');

    bars.forEach((bar, i) => {
        const scale = (dataArray[i] / 255) * 2;
        bar.style.transform = `scaleY(${Math.max(0.1, scale)})`;
    });

    animationId = requestAnimationFrame(updateVisualization);
}
270
+
271
// Tear down the WebRTC session: stop the microphone, close the peer
// connection, halt the visualization loop, and release the AudioContext.
function stopWebRTC() {
    if (peerConnection) {
        // Fix: stop outgoing (microphone) tracks so the browser releases
        // the device and turns off the recording indicator — closing the
        // connection alone leaves the captured tracks live.
        peerConnection.getSenders().forEach((sender) => {
            if (sender.track) {
                sender.track.stop();
            }
        });
        peerConnection.close();
    }
    if (animationId) {
        cancelAnimationFrame(animationId);
    }
    if (audioContext) {
        audioContext.close();
    }
}
282
+
283
// Toggle recording on button press, keeping the label and style in sync.
startButton.addEventListener('click', () => {
    if (isRecording) {
        stopWebRTC();
        startButton.textContent = 'Start Recording';
        startButton.classList.remove('recording');
    } else {
        setupWebRTC();
        startButton.textContent = 'Stop Recording';
        startButton.classList.add('recording');
    }
    isRecording = !isRecording;
});
295
+ </script>
296
+ </body>
297
+
298
+ </html>
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ fastrtc
2
+ python-dotenv
3
+ google-genai
4
+ twilio