mgokg commited on
Commit
fc5e5f2
·
verified ·
1 Parent(s): b3eb4a1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +268 -0
app.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2025 Google LLC
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ ## Setup
17
+
18
+ The gradio-webrtc install fails unless you have ffmpeg@6, on mac:
19
+
20
+ ```
21
+ brew uninstall ffmpeg
22
+ brew install ffmpeg@6
23
+ brew link ffmpeg@6
24
+ ```
25
+
26
+ Create a virtual python environment, then install the dependencies for this script:
27
+
28
+ ```
29
+ pip install websockets numpy gradio-webrtc "gradio>=5.9.1"
30
+ ```
31
+
32
+ If installation fails it may be
33
+
34
+ Before running this script, ensure the `GOOGLE_API_KEY` environment
35
+
36
+ ```
37
+ $ export GOOGLE_API_KEY ='add your key here'
38
+ ```
39
+
40
+ You can get an api-key from Google AI Studio (https://aistudio.google.com/apikey)
41
+
42
+ ## Run
43
+
44
+ To run the script:
45
+
46
+ ```
47
+ python gemini_gradio_audio.py
48
+ ```
49
+
50
+ On the gradio page (http://127.0.0.1:7860/) click record, and talk, gemini will reply. But note that interruptions
51
+ don't work.
52
+
53
+ """
54
+
55
+ import os
56
+ import base64
57
+ import json
58
+ import numpy as np
59
+ import gradio as gr
60
+ import websockets.sync.client
61
+ from gradio_webrtc import StreamHandler, WebRTC
62
+
63
+ __version__ = "0.0.3"
64
+
65
+ #KEY_NAME="AIzaSyCWPviRPxj8IMLaijLGbRIsio3dO2rp3rU"
66
+
67
+ # Configuration and Utilities
68
+ class GeminiConfig:
69
+ """Configuration settings for Gemini API."""
70
+ def __init__(self):
71
+ self.api_key = os.getenv(KEY_NAME)
72
+ self.host = "generativelanguage.googleapis.com"
73
+ self.model = "models/gemini-2.0-flash-exp"
74
+ self.ws_url = f"wss://{self.host}/ws/google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContent?key={self.api_key}"
75
+
76
+ class AudioProcessor:
77
+ """Handles encoding and decoding of audio data."""
78
+ @staticmethod
79
+ def encode_audio(data, sample_rate):
80
+ """Encodes audio data to base64."""
81
+ encoded = base64.b64encode(data.tobytes()).decode("UTF-8")
82
+ return {
83
+ "realtimeInput": {
84
+ "mediaChunks": [
85
+ {
86
+ "mimeType": f"audio/pcm;rate={sample_rate}",
87
+ "data": encoded,
88
+ }
89
+ ],
90
+ },
91
+ }
92
+
93
+ @staticmethod
94
+ def process_audio_response(data):
95
+ """Decodes audio data from base64."""
96
+ audio_data = base64.b64decode(data)
97
+ return np.frombuffer(audio_data, dtype=np.int16)
98
+
99
+ # Gemini Interaction Handler
100
+ class GeminiHandler(StreamHandler):
101
+ """Handles streaming interactions with the Gemini API."""
102
+ def __init__(self, expected_layout="mono", output_sample_rate=24000, output_frame_size=480) -> None:
103
+ super().__init__(expected_layout, output_sample_rate, output_frame_size, input_sample_rate=24000)
104
+ self.config = GeminiConfig()
105
+ self.ws = None
106
+ self.all_output_data = None
107
+ self.audio_processor = AudioProcessor()
108
+
109
+ def copy(self):
110
+ """Creates a copy of the GeminiHandler instance."""
111
+ return GeminiHandler(
112
+ expected_layout=self.expected_layout,
113
+ output_sample_rate=self.output_sample_rate,
114
+ output_frame_size=self.output_frame_size,
115
+ )
116
+
117
+ def _initialize_websocket(self):
118
+ """Initializes the WebSocket connection to the Gemini API."""
119
+ try:
120
+ self.ws = websockets.sync.client.connect(self.config.ws_url, timeout=3000)
121
+ initial_request = {"setup": {"model": self.config.model,"tools":[{"google_search": {}}]}}
122
+ self.ws.send(json.dumps(initial_request))
123
+ setup_response = json.loads(self.ws.recv())
124
+ print(f"Setup response: {setup_response}")
125
+ except websockets.exceptions.WebSocketException as e:
126
+ print(f"WebSocket connection failed: {str(e)}")
127
+ self.ws = None
128
+ except Exception as e:
129
+ print(f"Setup failed: {str(e)}")
130
+ self.ws = None
131
+
132
+ def receive(self, frame: tuple[int, np.ndarray]) -> None:
133
+ """Empfängt Audio-/Videodaten, kodiert sie und sendet sie an die Gemini API."""
134
+ try:
135
+ if not self.ws:
136
+ self._initialize_websocket()
137
+ if not self.ws: # Überprüfen, ob die Verbindung erfolgreich ist
138
+ print("WebSocket-Verbindung konnte nicht hergestellt werden.")
139
+ return # Frühzeitiger Rückkehr, wenn die Verbindung fehlschlägt
140
+
141
+ sample_rate, array = frame
142
+ message = {"realtimeInput": {"mediaChunks": []}}
143
+
144
+ if sample_rate > 0 and array is not None:
145
+ array = array.squeeze()
146
+ audio_data = self.audio_processor.encode_audio(array, self.output_sample_rate)
147
+ message["realtimeInput"]["mediaChunks"].append({
148
+ "mimeType": f"audio/pcm;rate={self.output_sample_rate}",
149
+ "data": audio_data["realtimeInput"]["mediaChunks"][0]["data"],
150
+ })
151
+
152
+ if message["realtimeInput"]["mediaChunks"]:
153
+ self.ws.send(json.dumps(message))
154
+ except Exception as e:
155
+ print(f"Fehler beim Empfangen: {str(e)}")
156
+ if self.ws:
157
+ self.ws.close()
158
+ self.ws = None
159
+
160
+
161
+ def _process_server_content(self, content):
162
+ """Processes audio output data from the WebSocket response."""
163
+ for part in content.get("parts", []):
164
+ data = part.get("inlineData", {}).get("data", "")
165
+ if data:
166
+ audio_array = self.audio_processor.process_audio_response(data)
167
+ if self.all_output_data is None:
168
+ self.all_output_data = audio_array
169
+ else:
170
+ self.all_output_data = np.concatenate((self.all_output_data, audio_array))
171
+
172
+ while self.all_output_data.shape[-1] >= self.output_frame_size:
173
+ yield (self.output_sample_rate, self.all_output_data[: self.output_frame_size].reshape(1, -1))
174
+ self.all_output_data = self.all_output_data[self.output_frame_size :]
175
+
176
+ def generator(self):
177
+ """Generates audio output from the WebSocket stream."""
178
+ while True:
179
+ if not self.ws:
180
+ print("WebSocket not connected")
181
+ yield None
182
+ continue
183
+
184
+ try:
185
+ message = self.ws.recv(timeout=30)
186
+ msg = json.loads(message)
187
+ if "serverContent" in msg:
188
+ content = msg["serverContent"].get("modelTurn", {})
189
+ yield from self._process_server_content(content)
190
+ except TimeoutError:
191
+ print("Timeout waiting for server response")
192
+ yield None
193
+ except Exception as e:
194
+ yield None
195
+
196
+ def emit(self) -> tuple[int, np.ndarray] | None:
197
+ """Emits the next audio chunk from the generator."""
198
+ if not self.ws:
199
+ return None
200
+ if not hasattr(self, "_generator"):
201
+ self._generator = self.generator()
202
+ try:
203
+ return next(self._generator)
204
+ except StopIteration:
205
+ self.reset()
206
+ return None
207
+
208
+ def reset(self) -> None:
209
+ """Resets the generator and output data."""
210
+ if hasattr(self, "_generator"):
211
+ delattr(self, "_generator")
212
+ self.all_output_data = None
213
+
214
+ def shutdown(self) -> None:
215
+ """Closes the WebSocket connection."""
216
+ if self.ws:
217
+ self.ws.close()
218
+
219
+ def check_connection(self):
220
+ """Checks if the WebSocket connection is active."""
221
+ try:
222
+ if not self.ws or self.ws.closed:
223
+ self._initialize_websocket()
224
+ return True
225
+ except Exception as e:
226
+ print(f"Connection check failed: {str(e)}")
227
+ return False
228
+
229
+ # Main Gradio Interface
230
+ def registry(
231
+ name: str,
232
+ token: str | None = None,
233
+ **kwargs
234
+ ):
235
+ """Sets up and returns the Gradio interface."""
236
+ api_key = token or os.environ.get(KEY_NAME)
237
+ if not api_key:
238
+ raise ValueError(f"{KEY_NAME} environment variable is not set.")
239
+
240
+ interface = gr.Blocks()
241
+ with interface:
242
+ with gr.Tabs():
243
+ with gr.TabItem("Voice Chat"):
244
+ gr.HTML(
245
+ """
246
+ <div style='text-align: left'>
247
+ <h1>Gemini API Voice Chat</h1>
248
+ </div>
249
+ """
250
+ )
251
+ gemini_handler = GeminiHandler()
252
+ with gr.Row():
253
+ audio = WebRTC(label="Voice Chat", modality="audio", mode="send-receive")
254
+
255
+ audio.stream(
256
+ gemini_handler,
257
+ inputs=[audio],
258
+ outputs=[audio],
259
+ time_limit=600,
260
+ concurrency_limit=10
261
+ )
262
+ return interface
263
+
264
+ # Launch the Gradio interface
265
+ gr.load(
266
+ name='gemini-2.0-flash-exp',
267
+ src=registry,
268
+ ).launch()