mgokg commited on
Commit
eb44df9
·
verified ·
1 Parent(s): c1bc211

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +249 -217
app.py CHANGED
@@ -1,241 +1,262 @@
1
- # -*- coding: utf-8 -*-
2
- # Copyright 2025 Google LLC
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """
16
- ## Setup
17
-
18
- The gradio-webrtc install fails unless you have ffmpeg@6, on mac:
19
-
20
- ```
21
- brew uninstall ffmpeg
22
- brew install ffmpeg@6
23
- brew link ffmpeg@6
24
- ```
25
-
26
- Create a virtual python environment, then install the dependencies for this script:
27
-
28
- ```
29
- pip install websockets numpy gradio-webrtc "gradio>=5.9.1"
30
- ```
31
-
32
- If installation fails it may be due to a missing or incompatible ffmpeg version (see the note above).
33
-
34
- Before running this script, ensure the `GOOGLE_API_KEY` environment
35
-
36
- ```
37
- $ export GOOGLE_API_KEY='add your key here'
38
- ```
39
-
40
- You can get an api-key from Google AI Studio (https://aistudio.google.com/apikey)
41
-
42
- ## Run
43
-
44
- To run the script:
45
-
46
- ```
47
- python gemini_gradio_audio.py
48
- ```
49
-
50
- On the gradio page (http://127.0.0.1:7860/) click record, and talk, gemini will reply. But note that interruptions
51
- don't work.
52
-
53
- """
54
-
55
  import os
56
  import base64
57
  import json
58
  import numpy as np
59
  import gradio as gr
60
- import websockets.sync.client
61
- from gradio_webrtc import StreamHandler, WebRTC
 
 
 
 
 
 
 
 
62
 
63
  __version__ = "0.0.3"
64
 
65
- #KEY_NAME="<REDACTED-API-KEY>"  # SECURITY: a live API key was committed here; revoke it and use environment variables instead.
 
66
 
67
  # Configuration and Utilities
68
  class GeminiConfig:
69
  """Configuration settings for Gemini API."""
 
70
  def __init__(self):
71
- self.api_key = os.environ.get(KEY_NAME)
 
 
72
  self.host = "generativelanguage.googleapis.com"
73
  self.model = "models/gemini-2.0-flash-exp"
74
- self.ws_url = f"wss://{self.host}/ws/google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContent?key={self.api_key}"
 
 
75
 
76
  class AudioProcessor:
77
  """Handles encoding and decoding of audio data."""
 
78
  @staticmethod
79
- def encode_audio(data, sample_rate):
80
  """Encodes audio data to base64."""
 
 
 
81
  encoded = base64.b64encode(data.tobytes()).decode("UTF-8")
82
- return {
83
- "realtimeInput": {
84
- "mediaChunks": [
85
- {
86
- "mimeType": f"audio/pcm;rate={sample_rate}",
87
- "data": encoded,
88
- }
89
- ],
90
- },
91
- }
92
 
93
  @staticmethod
94
- def process_audio_response(data):
95
  """Decodes audio data from base64."""
96
  audio_data = base64.b64decode(data)
97
  return np.frombuffer(audio_data, dtype=np.int16)
98
 
99
- # Gemini Interaction Handler
100
- class GeminiHandler(StreamHandler):
101
- """Handles streaming interactions with the Gemini API."""
102
- def __init__(self, expected_layout="mono", output_sample_rate=24000, output_frame_size=480) -> None:
103
- super().__init__(expected_layout, output_sample_rate, output_frame_size, input_sample_rate=24000)
104
- self.config = GeminiConfig()
105
- self.ws = None
106
- self.all_output_data = None
107
- self.audio_processor = AudioProcessor()
108
 
109
- def copy(self):
110
- """Creates a copy of the GeminiHandler instance."""
111
- return GeminiHandler(
112
- expected_layout=self.expected_layout,
113
- output_sample_rate=self.output_sample_rate,
114
- output_frame_size=self.output_frame_size,
115
- )
116
-
117
- def _initialize_websocket(self):
118
- """Initializes the WebSocket connection to the Gemini API."""
119
- try:
120
- self.ws = websockets.sync.client.connect(self.config.ws_url, timeout=3000)
121
- initial_request = {"setup": {"model": self.config.model,"tools":[{"google_search": {}}]}}
122
- self.ws.send(json.dumps(initial_request))
123
- setup_response = json.loads(self.ws.recv())
124
- print(f"Setup response: {setup_response}")
125
- except websockets.exceptions.WebSocketException as e:
126
- print(f"WebSocket connection failed: {str(e)}")
127
- self.ws = None
128
- except Exception as e:
129
- print(f"Setup failed: {str(e)}")
130
- self.ws = None
131
-
132
- def receive(self, frame: tuple[int, np.ndarray]) -> None:
133
- """Empfängt Audio-/Videodaten, kodiert sie und sendet sie an die Gemini API."""
134
- try:
135
- if not self.ws:
136
- self._initialize_websocket()
137
- if not self.ws: # Überprüfen, ob die Verbindung erfolgreich ist
138
- print("WebSocket-Verbindung konnte nicht hergestellt werden.")
139
- return # Frühzeitiger Rückkehr, wenn die Verbindung fehlschlägt
140
-
141
- sample_rate, array = frame
142
- message = {"realtimeInput": {"mediaChunks": []}}
143
-
144
- if sample_rate > 0 and array is not None:
145
- array = array.squeeze()
146
- audio_data = self.audio_processor.encode_audio(array, self.output_sample_rate)
147
- message["realtimeInput"]["mediaChunks"].append({
148
- "mimeType": f"audio/pcm;rate={self.output_sample_rate}",
149
- "data": audio_data["realtimeInput"]["mediaChunks"][0]["data"],
150
- })
151
-
152
- if message["realtimeInput"]["mediaChunks"]:
153
- self.ws.send(json.dumps(message))
154
- except Exception as e:
155
- print(f"Fehler beim Empfangen: {str(e)}")
156
- if self.ws:
157
- self.ws.close()
158
- self.ws = None
159
-
160
-
161
- def _process_server_content(self, content):
162
- """Processes audio output data from the WebSocket response."""
163
- for part in content.get("parts", []):
164
- data = part.get("inlineData", {}).get("data", "")
165
- if data:
166
- audio_array = self.audio_processor.process_audio_response(data)
167
- if self.all_output_data is None:
168
- self.all_output_data = audio_array
169
- else:
170
- self.all_output_data = np.concatenate((self.all_output_data, audio_array))
171
-
172
- while self.all_output_data.shape[-1] >= self.output_frame_size:
173
- yield (self.output_sample_rate, self.all_output_data[: self.output_frame_size].reshape(1, -1))
174
- self.all_output_data = self.all_output_data[self.output_frame_size :]
175
-
176
- def generator(self):
177
- """Generates audio output from the WebSocket stream."""
178
- while True:
179
- if not self.ws:
180
- print("WebSocket not connected")
181
- yield None
182
- continue
183
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  try:
185
- message = self.ws.recv(timeout=30)
186
- msg = json.loads(message)
187
- if "serverContent" in msg:
188
- content = msg["serverContent"].get("modelTurn", {})
189
- yield from self._process_server_content(content)
190
- except TimeoutError:
191
- print("Timeout waiting for server response")
192
- yield None
193
  except Exception as e:
194
- yield None
195
-
196
- def emit(self) -> tuple[int, np.ndarray] | None:
197
- """Emits the next audio chunk from the generator."""
198
- if not self.ws:
199
- return None
200
- if not hasattr(self, "_generator"):
201
- self._generator = self.generator()
202
- try:
203
- return next(self._generator)
204
- except StopIteration:
205
- self.reset()
206
- return None
207
-
208
- def reset(self) -> None:
209
- """Resets the generator and output data."""
210
- if hasattr(self, "_generator"):
211
- delattr(self, "_generator")
212
- self.all_output_data = None
213
 
214
- def shutdown(self) -> None:
215
- """Closes the WebSocket connection."""
216
- if self.ws:
217
- self.ws.close()
218
-
219
- def check_connection(self):
220
- """Checks if the WebSocket connection is active."""
221
- try:
222
- if not self.ws or self.ws.closed:
223
- self._initialize_websocket()
224
- return True
225
- except Exception as e:
226
- print(f"Connection check failed: {str(e)}")
227
- return False
228
-
229
- # Main Gradio Interface
230
- def registry(
231
- name: str,
232
- token: str | None = None,
233
- **kwargs
234
  ):
235
  """Sets up and returns the Gradio interface."""
236
- api_key = token or os.environ.get(KEY_NAME)
237
- if not api_key:
238
- raise ValueError(f"{KEY_NAME} environment variable is not set.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
 
240
  interface = gr.Blocks()
241
  with interface:
@@ -248,21 +269,32 @@ def registry(
248
  </div>
249
  """
250
  )
251
- gemini_handler = GeminiHandler()
252
  with gr.Row():
253
- audio = WebRTC(label="Voice Chat", modality="audio", mode="send-receive")
254
-
255
- audio.stream(
256
- gemini_handler,
257
- inputs=[audio],
258
- outputs=[audio],
259
- time_limit=600,
260
- concurrency_limit=10
 
 
 
 
 
 
 
 
261
  )
 
262
  return interface
263
 
264
  # Launch the Gradio interface
265
- gr.load(
266
- name='gemini-2.0-flash-exp',
267
- src=registry,
268
- ).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import base64
3
  import json
4
  import numpy as np
5
  import gradio as gr
6
+ # import websockets.sync.client # No longer needed with FastRTC
7
+ from fastrtc import (
8
+ PeerConnection,
9
+ DataChannel,
10
+ MediaStreamTrack,
11
+ AudioFrame,
12
+ VideoFrame,
13
+ ) # Import FastRTC components
14
+ from aiortc.contrib.media import MediaPlayer, MediaRelay
15
+ import asyncio
16
 
17
  __version__ = "0.0.3"
18
 
19
+ # KEY_NAME = "<REDACTED-API-KEY>"  # SECURITY: a live API key was committed here; revoke it. Best practice: keep API keys out of the main code, use environment variables.
20
+
21
 
22
  # Configuration and Utilities
23
  class GeminiConfig:
24
  """Configuration settings for Gemini API."""
25
+
26
  def __init__(self):
27
+ self.api_key = os.environ.get("GEMINI_API_KEY") # Use a more descriptive name
28
+ if not self.api_key:
29
+ raise ValueError("GEMINI_API_KEY environment variable is not set.")
30
  self.host = "generativelanguage.googleapis.com"
31
  self.model = "models/gemini-2.0-flash-exp"
32
+ # FastRTC doesn't use WebSockets directly in the same way. We'll handle the API calls differently.
33
+ self.base_url = f"https://{self.host}/v1alpha/{self.model}:streamGenerateContent?key={self.api_key}"
34
+
35
 
36
  class AudioProcessor:
37
  """Handles encoding and decoding of audio data."""
38
+
39
  @staticmethod
40
+ def encode_audio(data: np.ndarray, sample_rate: int) -> str:
41
  """Encodes audio data to base64."""
42
+ # Ensure data is in the correct format (int16)
43
+ if data.dtype != np.int16:
44
+ data = data.astype(np.int16)
45
  encoded = base64.b64encode(data.tobytes()).decode("UTF-8")
46
+ return encoded
 
 
 
 
 
 
 
 
 
47
 
48
  @staticmethod
49
+ def process_audio_response(data: str) -> np.ndarray:
50
  """Decodes audio data from base64."""
51
  audio_data = base64.b64decode(data)
52
  return np.frombuffer(audio_data, dtype=np.int16)
53
 
 
 
 
 
 
 
 
 
 
54
 
55
+ # We don't need a StreamHandler in the same way with FastRTC. We'll handle streaming directly.
56
class GeminiHandler:
    """Handles interactions with the Gemini API over a WebRTC peer connection.

    Receives microphone audio from the browser via an aiortc/FastRTC
    PeerConnection, forwards it to the Gemini REST streaming endpoint,
    and buffers the audio that comes back so it can be re-emitted in
    fixed-size frames.

    NOTE(review): several API usages below look inconsistent with the
    public aiortc API and should be verified against a running setup;
    they are flagged inline.
    """

    def __init__(self, output_sample_rate=24000, output_frame_size=480):
        # GeminiConfig() raises ValueError if the API key env var is unset,
        # so constructing this handler requires a configured environment.
        self.config = GeminiConfig()
        self.audio_processor = AudioProcessor()
        self.output_sample_rate = output_sample_rate
        self.output_frame_size = output_frame_size
        # Accumulated int16 samples received from Gemini; None until first chunk.
        self.all_output_data = None
        self.pc = None  # PeerConnection
        self.dc = None  # DataChannel
        self.audio_track = None
        self._audio_buffer = []  # NOTE(review): never used elsewhere in this class
        self.relay = MediaRelay()

    async def _send_audio_to_gemini(self, encoded_audio: str):
        """Sends one base64 audio chunk to the Gemini API and streams the response.

        Posts a streamGenerateContent request and, for every inlineData part in
        the chunked JSON response, hands the audio payload to
        _process_server_audio.
        """
        headers = {"Content-Type": "application/json"}
        payload = {
            "contents": [
                {
                    "parts": [
                        {
                            "text": "Respond to the audio with audio."
                        },  # Initial prompt, can be adjusted
                        {"inline_data": {"mime_type": "audio/pcm;rate=24000", "data": encoded_audio}},
                    ]
                }
            ]
        }
        # Use aiohttp for asynchronous HTTP requests.
        # NOTE(review): imported lazily so the module loads without aiohttp installed,
        # but aiohttp is a hard runtime dependency of this method.
        import aiohttp

        async with aiohttp.ClientSession() as session:
            async with session.post(
                self.config.base_url, headers=headers, data=json.dumps(payload)
            ) as response:
                if response.status != 200:
                    print(f"Error: Gemini API returned status {response.status}")
                    print(await response.text())
                    return

                async for line in response.content:
                    try:
                        line = line.strip()
                        if not line:
                            continue
                        # Responses are chunked, often with multiple JSON objects per chunk. Handle that.
                        # NOTE(review): assumes each JSON object fits on one line; a JSON
                        # object split across chunks will fail to parse — confirm against
                        # the actual response framing of streamGenerateContent.
                        for chunk in line.decode("utf-8").split("\n"):
                            if not chunk.strip():
                                continue
                            try:
                                data = json.loads(chunk)
                            except json.JSONDecodeError:
                                print(f"JSONDecodeError: {chunk}")
                                continue

                            if "candidates" in data:
                                for candidate in data["candidates"]:
                                    for part in candidate.get("content", {}).get("parts", []):
                                        if "inlineData" in part:
                                            audio_data = part["inlineData"].get("data", "")
                                            if audio_data:
                                                await self._process_server_audio(audio_data)

                    except Exception as e:
                        print(f"Error processing response chunk: {e}")

    async def _process_server_audio(self, audio_data: str):
        """Processes and buffers audio data received from the server.

        Decodes the base64 payload, appends it to the running sample buffer,
        and emits complete frames of output_frame_size samples.
        """
        audio_array = self.audio_processor.process_audio_response(audio_data)
        if self.all_output_data is None:
            self.all_output_data = audio_array
        else:
            self.all_output_data = np.concatenate((self.all_output_data, audio_array))

        while self.all_output_data.shape[-1] >= self.output_frame_size:
            # NOTE(review): av.AudioFrame (used by aiortc) is not constructed with
            # these keyword arguments (samples/sample_rate/layout/data) — confirm
            # which AudioFrame class fastrtc exports and its constructor signature.
            frame = AudioFrame(
                samples=self.output_frame_size,
                sample_rate=self.output_sample_rate,
                layout="mono",  # mono channel
                data=self.all_output_data[: self.output_frame_size].tobytes()
            )
            self.all_output_data = self.all_output_data[self.output_frame_size:]
            # NOTE(review): self.audio_track is the REMOTE track stored in on_track;
            # remote tracks in aiortc have no emit() method — verify this call path.
            if self.audio_track:
                await self.audio_track.emit(frame)


    async def on_track(self, track):
        """Handles incoming media tracks from the peer connection."""
        print(f"Track received: {track.kind}")
        if track.kind == "audio":
            self.audio_track = track  # Store the audio track

            # NOTE(review): aiortc tracks are consumed with `await track.recv()` in a
            # loop; they do not fire a "frame" event — confirm this decorator pattern
            # is supported by the track class actually in use.
            @track.on("frame")
            async def on_frame(frame):
                # Process received audio frames
                if isinstance(frame, AudioFrame):
                    try:
                        # Convert the frame data to a NumPy array
                        # (assumes 16-bit mono PCM — TODO confirm frame format)
                        audio_data = np.frombuffer(frame.data, dtype=np.int16)
                        # Encode the audio and send it to Gemini
                        encoded_audio = self.audio_processor.encode_audio(
                            audio_data, frame.sample_rate
                        )  # Pass sample rate
                        await self._send_audio_to_gemini(encoded_audio)
                    except Exception as e:
                        print(f"Error processing audio frame: {e}")

    async def on_datachannel(self, channel):
        """Handles data channel events (not used in this example, but good practice)."""
        self.dc = channel
        print("Data channel created")

        @channel.on("message")
        async def on_message(message):
            print(f"Received message: {message}")

    async def connect(self):
        """Establishes the PeerConnection and returns the local SDP offer."""
        # NOTE(review): aiortc exposes RTCPeerConnection, not PeerConnection —
        # confirm that fastrtc actually exports a class by this name.
        self.pc = PeerConnection()
        self.pc.on("track", self.on_track)
        self.pc.on("datachannel", self.on_datachannel)

        # Create a local audio track to send data.
        # NOTE(review): "avfoundation" is a macOS-only capture backend; this will
        # fail on Linux servers (e.g. Hugging Face Spaces) — confirm deployment target.
        self.local_audio_player = MediaPlayer("default", format="avfoundation", options={"channels": "1", "sample_rate": str(self.output_sample_rate)})
        self.local_audio = self.relay.subscribe(self.local_audio_player.audio)
        self.pc.addTrack(self.local_audio)

        # Add a data channel (optional, but good practice)
        self.dc = self.pc.createDataChannel("data")

        # Create an offer and set local description
        offer = await self.pc.createOffer()
        await self.pc.setLocalDescription(offer)
        print("PeerConnection established")
        return self.pc.localDescription

    async def set_remote_description(self, sdp, type):
        """Sets the remote description; returns an answer if the remote sent an offer."""
        from aiortc import RTCSessionDescription

        await self.pc.setRemoteDescription(RTCSessionDescription(sdp=sdp, type=type))
        print("Remote description set")

        if self.pc.remoteDescription.type == "offer":
            answer = await self.pc.createAnswer()
            await self.pc.setLocalDescription(answer)
            return self.pc.localDescription

    async def add_ice_candidate(self, candidate, sdpMid, sdpMLineIndex):
        """Adds an ICE candidate to the peer connection."""
        from aiortc import RTCIceCandidate

        if candidate:
            try:
                # NOTE(review): aiortc's RTCIceCandidate constructor takes parsed fields
                # (component, foundation, ip, port, ...), not a raw candidate string —
                # this likely needs aiortc.sdp.candidate_from_sdp; verify.
                ice_candidate = RTCIceCandidate(
                    candidate=candidate, sdpMid=sdpMid, sdpMLineIndex=sdpMLineIndex
                )
                await self.pc.addIceCandidate(ice_candidate)
                print("ICE candidate added")
            except Exception as e:
                print(f"Error adding ICE candidate: {e}")

    def shutdown(self):
        """Closes the PeerConnection."""
        if self.pc:
            # NOTE(review): create_task requires a running event loop; calling
            # shutdown() from synchronous teardown code will raise RuntimeError.
            asyncio.create_task(self.pc.close())  # Close in the background
            self.pc = None
            print("PeerConnection closed")
226
+
227
+
228
+ # Gradio Interface
229
+ async def registry(
230
+ name: str,
231
+ token: str | None = None,
232
+ **kwargs,
 
 
 
 
 
 
 
233
  ):
234
  """Sets up and returns the Gradio interface."""
235
+ gemini_handler = GeminiHandler()
236
+
237
    async def connect_webrtc(sdp, type, candidates):
        """Connects to the WebRTC client and handles ICE candidates.

        Async generator used as a Gradio event handler: performs the
        offer/answer signalling dance with the enclosing GeminiHandler and
        yields JSON strings describing the signalling state back to the UI.

        NOTE(review): the parameter `type` shadows the builtin; kept as-is
        for signature compatibility.
        NOTE(review): the click handler wires inputs=[webrtc] (a single
        component value) to this three-parameter function — confirm Gradio
        unpacks the dict into (sdp, type, candidates) as intended.
        """
        # First call: no PeerConnection yet, so create one and yield our offer.
        if gemini_handler.pc is None:
            local_description = await gemini_handler.connect()
            if local_description:
                yield json.dumps(
                    {
                        "sdp": local_description.sdp,
                        "type": local_description.type,
                        "candidates": [],
                    }
                )  # Return initial SDP
        # If the client supplied a remote description, apply it; an answer is
        # only produced when the remote description was an offer.
        if sdp and type:
            answer = await gemini_handler.set_remote_description(sdp, type)
            if answer:
                yield json.dumps({"sdp": answer.sdp, "type": answer.type, "candidates": []})

        # Trickle in any ICE candidates the client has gathered so far.
        for candidate in candidates:
            if candidate and candidate.get("candidate"):
                await gemini_handler.add_ice_candidate(
                    candidate["candidate"], candidate.get("sdpMid"), candidate.get("sdpMLineIndex")
                )
        yield json.dumps({"sdp": "", "type": "", "candidates": []})  # Signal completion
260
 
261
  interface = gr.Blocks()
262
  with interface:
 
269
  </div>
270
  """
271
  )
 
272
  with gr.Row():
273
+ webrtc_out = gr.JSON(label="WebRTC JSON")
274
+
275
+ # Use the built-in WebRTC component, but without automatic streaming.
276
+ webrtc = gr.WebRTC(
277
+ value={"sdp": "", "type": "", "candidates": []},
278
+ interactive=True,
279
+ label="Voice Chat",
280
+ )
281
+
282
+ connect_button = gr.Button("Connect")
283
+ connect_button.click(
284
+ connect_webrtc,
285
+ inputs=[
286
+ webrtc
287
+ ], # Pass the WebRTC component's value (SDP, type, candidates)
288
+ outputs=[webrtc_out], # show the webrtc connection data
289
  )
290
+
291
  return interface
292
 
293
  # Launch the Gradio interface
294
async def main():
    """Builds the Gradio interface and launches it.

    Awaits the async `registry` factory, enables request queuing, then
    starts the Gradio server.
    """
    interface = await registry(name="gemini-2.0-flash-exp")
    interface.queue()  # Enable queuing for better concurrency
    # BUG FIX: gradio's Blocks.launch() is a synchronous method, not a
    # coroutine — `await interface.launch()` raises TypeError at runtime.
    interface.launch()

if __name__ == "__main__":
    asyncio.run(main())