mgokg committed on
Commit
734afbb
·
verified ·
1 Parent(s): 57775e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +217 -62
app.py CHANGED
@@ -15,10 +15,14 @@ from fastrtc import (
15
  get_cloudflare_turn_credentials_async,
16
  wait_for_item,
17
  )
18
- from google import genai
 
19
  from gradio.utils import get_space
20
  from PIL import Image
21
 
 
 
 
22
  load_dotenv()
23
 
24
 
@@ -39,6 +43,42 @@ def encode_image(data: np.ndarray) -> dict:
39
  return {"mime_type": "image/jpeg", "data": base64_str}
40
 
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  class GeminiHandler(AsyncAudioVideoStreamHandler):
43
  def __init__(
44
  self,
@@ -54,41 +94,127 @@ class GeminiHandler(AsyncAudioVideoStreamHandler):
54
  self.last_frame_time = 0
55
  self.quit = asyncio.Event()
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  def copy(self) -> "GeminiHandler":
58
  return GeminiHandler()
59
 
60
  async def start_up(self):
61
- client = genai.Client(
62
- api_key=os.getenv("GEMINI_API_KEY"), http_options={"api_version": "v1alpha"}
63
- )
64
- config = {"response_modalities": ["AUDIO"]}
65
- async with client.aio.live.connect(
66
- model="gemini-2.0-flash-exp",
67
- config=config, # type: ignore
68
- ) as session:
69
- self.session = session
70
- while not self.quit.is_set():
71
- turn = self.session.receive()
72
- try:
73
- async for response in turn:
74
- if data := response.data:
75
- audio = np.frombuffer(data, dtype=np.int16).reshape(1, -1)
76
- self.audio_queue.put_nowait(audio)
77
- except websockets.exceptions.ConnectionClosedOK:
78
- print("connection closed")
79
- break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
  async def video_receive(self, frame: np.ndarray):
82
  self.video_queue.put_nowait(frame)
83
 
84
- if self.session:
85
- # send image every 1 second
86
- print(time.time() - self.last_frame_time)
87
- if time.time() - self.last_frame_time > 1:
88
  self.last_frame_time = time.time()
89
- await self.session.send(input=encode_image(frame))
90
- if self.latest_args[1] is not None:
91
- await self.session.send(input=encode_image(self.latest_args[1]))
 
 
 
 
 
92
 
93
  async def video_emit(self):
94
  frame = await wait_for_item(self.video_queue, 0.01)
@@ -97,40 +223,50 @@ class GeminiHandler(AsyncAudioVideoStreamHandler):
97
  else:
98
  return np.zeros((100, 100, 3), dtype=np.uint8)
99
 
100
- async def receive(self, frame: tuple[int, np.ndarray]) -> None:
101
  _, array = frame
102
  array = array.squeeze()
103
  audio_message = encode_audio(array)
104
- if self.session:
105
- await self.session.send(input=audio_message)
 
 
 
 
106
 
107
- async def emit(self):
108
  array = await wait_for_item(self.audio_queue, 0.01)
109
  if array is not None:
110
  return (self.output_sample_rate, array)
111
- return array
112
 
113
  async def shutdown(self) -> None:
114
- if self.session:
115
- self.quit.set()
116
- await self.session.close()
117
- self.quit.clear()
 
 
 
 
 
118
 
119
 
120
- stream = Stream(
 
121
  handler=GeminiHandler(),
122
  modality="audio-video",
123
  mode="send-receive",
124
  rtc_configuration=get_cloudflare_turn_credentials_async,
125
  time_limit=180 if get_space() else None,
126
  additional_inputs=[
127
- gr.Image(label="Image", type="numpy", sources=["upload", "clipboard"])
128
  ],
129
  ui_args={
130
  "icon": "https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
131
  "pulse_color": "rgb(255, 255, 255)",
132
  "icon_button_color": "rgb(255, 255, 255)",
133
- "title": "Gemini Audio Video Chat",
134
  },
135
  )
136
 
@@ -138,22 +274,28 @@ css = """
138
  #video-source {max-width: 500px !important; max-height: 500px !important;}
139
  """
140
 
141
- with gr.Blocks(css=css) as demo:
142
  gr.HTML(
143
  """
144
  <div>
145
  <center>
146
- <h1>Gen AI Voice Chat</h1>
147
- <p>real-time audio + video streaming</p>
 
148
  </center>
149
  </div>
150
  """
151
  )
 
 
 
 
 
152
  with gr.Row() as row:
153
  with gr.Column():
154
  webrtc = WebRTC(
155
  label="Video Chat",
156
- modality="audio-video",
157
  mode="send-receive",
158
  elem_id="video-source",
159
  rtc_configuration=get_cloudflare_turn_credentials_async,
@@ -161,26 +303,39 @@ with gr.Blocks(css=css) as demo:
161
  pulse_color="rgb(255, 255, 255)",
162
  icon_button_color="rgb(255, 255, 255)",
163
  )
164
- #with gr.Column():
165
- #image_input = gr.Image(
166
- #label="Image", type="numpy", sources=["upload", "clipboard"]
167
- #)
168
-
169
- webrtc.stream(
170
- GeminiHandler(),
171
- inputs=[webrtc],
172
- outputs=[webrtc],
173
- time_limit=180 if get_space() else None,
174
- concurrency_limit=2 if get_space() else None,
175
- )
176
 
177
- stream.ui = demo
 
 
 
 
 
 
 
 
 
 
 
178
 
 
 
 
 
179
 
180
  if __name__ == "__main__":
181
- if (mode := os.getenv("MODE")) == "UI":
182
- stream.ui.launch(server_port=7860)
183
- elif mode == "PHONE":
184
- raise ValueError("Phone mode not supported for this demo")
185
- else:
186
- stream.ui.launch(server_port=7860)
 
 
 
 
 
 
 
 
 
 
15
  get_cloudflare_turn_credentials_async,
16
  wait_for_item,
17
  )
18
+ from google import genai # Assuming this is from google-generativeai or compatible
19
+ from google.generativeai.types import Tool, FunctionDeclaration, ToolConfig, Part # For Tool calling
20
  from gradio.utils import get_space
21
  from PIL import Image
22
 
23
+ # For Google Search
24
+ from googlesearch import search as google_search_engine
25
+
26
  load_dotenv()
27
 
28
 
 
43
  return {"mime_type": "image/jpeg", "data": base64_str}
44
 
45
 
46
def perform_google_search(query: str, num_results: int = 3, *, search_fn=None) -> dict:
    """Run a web search and package the resulting links for a Gemini function response.

    Args:
        query: The search query.
        num_results: Maximum number of result URLs to include.
        search_fn: Optional search callable taking the query as its first
            argument and returning an iterable of URL strings. Defaults to
            the module-level ``google_search_engine``.

    Returns:
        ``{"summary": ...}`` listing the links (or saying none were found),
        or ``{"error": ...}`` if the search failed. This function never
        raises; every failure is reported through the ``error`` key.
    """
    from itertools import islice

    print(f"Performing Google search for: '{query}'...")
    try:
        engine = google_search_engine if search_fn is None else search_fn

        # Two different distributions install a `googlesearch` module with
        # incompatible keyword arguments: one takes `num_results`, the other
        # takes `stop`/`pause`. Passing both sets at once is rejected by
        # either, so try one signature and fall back to the other.
        try:
            hits = engine(query, num_results=num_results)
        except TypeError:
            hits = engine(query, stop=num_results, pause=1.0)

        # The engine may yield more than requested; cap it lazily.
        search_results_links = list(islice(hits, max(num_results, 0)))

        if not search_results_links:
            return {"summary": "No direct results found on the web. You could try rephrasing your search."}

        # Prepare a summary for Gemini
        summary_text = "Found the following links based on your search:\n" + "\n".join(search_results_links)
        print(f"Search results: {summary_text}")
        return {"summary": summary_text}

    except Exception as e:
        # Deliberately broad: the result is handed back to the model as a
        # tool response, so a failure must become data rather than a crash.
        print(f"Google search error: {e}")
        if "429" in str(e):  # Handle rate limiting
            return {"error": "The search service is temporarily busy (rate limited). Please try again in a moment."}
        return {"error": f"An error occurred during the search: {str(e)}"}
80
+
81
+
82
  class GeminiHandler(AsyncAudioVideoStreamHandler):
83
  def __init__(
84
  self,
 
94
  self.last_frame_time = 0
95
  self.quit = asyncio.Event()
96
 
97
+ # Define the Google Search tool for Gemini
98
+ self.google_search_tool_declaration = FunctionDeclaration(
99
+ name="perform_google_search_for_user", # Name Gemini will use
100
+ description="Performs a Google search for a given query and returns a summary of the top results. Use this for general web searches.",
101
+ parameters={
102
+ "type": "OBJECT",
103
+ "properties": {
104
+ "query": {"type": "STRING", "description": "The search query to look up on Google."}
105
+ },
106
+ "required": ["query"],
107
+ },
108
+ )
109
+ self.gemini_tools = [Tool(function_declarations=[self.google_search_tool_declaration])]
110
+
111
+
112
  def copy(self) -> "GeminiHandler":
113
  return GeminiHandler()
114
 
115
  async def start_up(self):
116
+ # Ensure GEMINI_API_KEY is set in your .env file or environment
117
+ api_key = os.getenv("GEMINI_API_KEY")
118
+ if not api_key:
119
+ raise ValueError("GEMINI_API_KEY not found in environment variables.")
120
+
121
+ # Using google.generativeai's standard client setup if `from google import genai` provides it
122
+ # If `genai.Client` is from a different library, this part might need adjustment.
123
+ # Assuming `genai.Client` is akin to `google.generativeai.GenerativeServiceAsyncClient`
124
+ # or a wrapper that `google-generativeai` provides.
125
+
126
+ # The original code uses `genai.Client(...)`. Let's try to adapt that.
127
+ # Removing `http_options` to use default (likely v1beta for tools)
128
+ # If `v1alpha` is strictly required by your `genai.Client`, tool support might be limited.
129
+ client = genai.Client(api_key=api_key)
130
+
131
+ # Configure Gemini for audio response and tool usage
132
+ config = {
133
+ "response_modalities": ["AUDIO"],
134
+ "tool_config": ToolConfig(function_declarations=[self.google_search_tool_declaration])
135
+ }
136
+
137
+ # Using a model known for good tool use and speed
138
+ model_name = "gemini-1.5-flash-latest"
139
+ print(f"Connecting to Gemini model: {model_name} with search tool enabled.")
140
+
141
+ try:
142
+ async with client.aio.live.connect(
143
+ model=model_name,
144
+ config=config,
145
+ ) as session:
146
+ self.session = session
147
+ print("Gemini session started successfully. You can now speak or ask to search.")
148
+ while not self.quit.is_set():
149
+ current_turn = self.session.receive() # This gets a LiveTurn object
150
+
151
+ try:
152
+ # First, process any incoming audio chunks from Gemini for this turn
153
+ async for response_chunk in current_turn: # Iterates over LiveResponseChunk
154
+ if response_chunk.data: # This is audio data from Gemini
155
+ audio = np.frombuffer(response_chunk.data, dtype=np.int16).reshape(1, -1)
156
+ self.audio_queue.put_nowait(audio)
157
+
158
+ # After processing all chunks, check if Gemini requested a tool call in this turn
159
+ if current_turn.tool_code and current_turn.tool_code.function_call:
160
+ fc = current_turn.tool_code.function_call
161
+ tool_name = fc.name
162
+ tool_args = fc.args
163
+
164
+ if tool_name == "perform_google_search_for_user":
165
+ query = tool_args.get("query")
166
+ if not query:
167
+ print("Error: 'query' argument missing for search tool.")
168
+ tool_response_part = Part.from_function_response(
169
+ name=tool_name,
170
+ response={"error": "Missing 'query' argument for search."}
171
+ )
172
+ else:
173
+ print(f"Gemini requested search: '{query}'")
174
+ # Run the blocking search function in a separate thread
175
+ search_result_dict = await asyncio.to_thread(perform_google_search, query)
176
+
177
+ tool_response_part = Part.from_function_response(
178
+ name=tool_name,
179
+ response=search_result_dict # Pass the dict directly
180
+ )
181
+ print(f"Sending search tool response to Gemini: {search_result_dict}")
182
+ await self.session.send(input=[tool_response_part])
183
+ else:
184
+ print(f"Error: Gemini requested unknown tool: {tool_name}")
185
+ tool_error_response = Part.from_function_response(
186
+ name=tool_name,
187
+ response={"error": f"Tool '{tool_name}' is not implemented."}
188
+ )
189
+ await self.session.send(input=[tool_error_response])
190
+
191
+ except websockets.exceptions.ConnectionClosedOK:
192
+ print("WebSocket connection closed by Gemini server.")
193
+ break
194
+ except Exception as e:
195
+ print(f"Error processing a turn from Gemini: {e}")
196
+ # Decide if to break or continue based on error severity
197
+ break # Safest to break on unexpected errors in the loop
198
+ except Exception as e:
199
+ print(f"Failed to connect or maintain Gemini session: {e}")
200
+ # Handle connection errors, API key issues, etc.
201
+ finally:
202
+ self.quit.set() # Ensure shutdown is triggered if loop exits
203
 
204
  async def video_receive(self, frame: np.ndarray):
205
  self.video_queue.put_nowait(frame)
206
 
207
+ if self.session and not self.session.closed:
208
+ if time.time() - self.last_frame_time > 1: # Send video frame every 1 second
 
 
209
  self.last_frame_time = time.time()
210
+ try:
211
+ await self.session.send(input=encode_image(frame))
212
+ # latest_args[0] is webrtc component, latest_args[1] is image_input
213
+ if self.latest_args and len(self.latest_args) > 1 and self.latest_args[1] is not None:
214
+ await self.session.send(input=encode_image(self.latest_args[1]))
215
+ except Exception as e:
216
+ print(f"Error sending video/image to Gemini: {e}")
217
+
218
 
219
  async def video_emit(self):
220
  frame = await wait_for_item(self.video_queue, 0.01)
 
223
  else:
224
  return np.zeros((100, 100, 3), dtype=np.uint8)
225
 
226
async def receive(self, frame: tuple[int, np.ndarray]) -> None:
    """Forward one chunk of user audio to the open Gemini session.

    Args:
        frame: ``(sample_rate, samples)``. The sample rate is ignored here;
            the samples are squeezed and encoded with ``encode_audio``.

    Nothing is sent when there is no open session. Send failures are
    logged, never raised.
    """
    _, array = frame
    array = array.squeeze()
    audio_message = encode_audio(array)

    session = self.session
    # The session object is not guaranteed to expose a `closed` attribute,
    # so read it defensively instead of failing on every audio chunk.
    if not session or getattr(session, "closed", False):
        return

    try:
        await session.send(input=audio_message)
    except Exception as e:
        print(f"Error sending audio to Gemini: {e}")
235
+
236
 
237
async def emit(self):
    """Return the next queued Gemini audio chunk, or ``None`` if none is ready.

    Waits on ``self.audio_queue`` via ``wait_for_item`` with a 0.01 timeout
    argument. A chunk is returned as ``(self.output_sample_rate, chunk)``.
    """
    chunk = await wait_for_item(self.audio_queue, 0.01)
    return None if chunk is None else (self.output_sample_rate, chunk)
242
 
243
  async def shutdown(self) -> None:
244
+ print("Shutting down GeminiHandler...")
245
+ self.quit.set()
246
+ if self.session and not self.session.closed:
247
+ try:
248
+ await self.session.close()
249
+ print("Gemini session closed.")
250
+ except Exception as e:
251
+ print(f"Error closing Gemini session: {e}")
252
+ self.session = None
253
 
254
 
255
+ # --- Gradio UI (largely unchanged) ---
256
# fastrtc Stream configured with the same handler type and image input as the
# Blocks UI defined further down.
# NOTE(review): the earlier `stream.ui = demo` wiring was removed and the app
# is now launched through `demo`, so nothing visible here reads `stream` any
# more. Confirm whether this object is still needed.
stream = Stream(
    handler=GeminiHandler(),
    modality="audio-video",
    mode="send-receive",
    rtc_configuration=get_cloudflare_turn_credentials_async,
    # 180 when get_space() is truthy, otherwise None is passed.
    time_limit=180 if get_space() else None,
    additional_inputs=[
        gr.Image(label="Optional Image Input", type="numpy", sources=["upload", "clipboard"])
    ],
    ui_args={
        "icon": "https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
        "pulse_color": "rgb(255, 255, 255)",
        "icon_button_color": "rgb(255, 255, 255)",
        "title": "Gemini Audio Video Chat + Search",
    },
)
272
 
 
274
  #video-source {max-width: 500px !important; max-height: 500px !important;}
275
  """
276
 
277
+ with gr.Blocks(css=css, title="Gemini AV Chat + Search") as demo:
278
  gr.HTML(
279
  """
280
  <div>
281
  <center>
282
+ <h1>Gen AI Voice Chat with Google Search</h1>
283
+ <p>Real-time audio + video streaming, with integrated Google Search via Gemini.</p>
284
+ <p><small>Note: Search uses web scraping, which may be rate-limited or unreliable for heavy use. For production, use official APIs.</small></p>
285
  </center>
286
  </div>
287
  """
288
  )
289
+ # Additional input for an image (as in original `stream` object)
290
+ # This needs to be passed to the handler if you want it to be used.
291
+ # The `webrtc.stream` inputs must match what the handler expects in `self.latest_args`
292
+ image_input_component = gr.Image(label="Optional Image Input", type="numpy", sources=["upload", "clipboard"])
293
+
294
  with gr.Row() as row:
295
  with gr.Column():
296
  webrtc = WebRTC(
297
  label="Video Chat",
298
+ modality="audio-video", # This component handles both audio and video from user
299
  mode="send-receive",
300
  elem_id="video-source",
301
  rtc_configuration=get_cloudflare_turn_credentials_async,
 
303
  pulse_color="rgb(255, 255, 255)",
304
  icon_button_color="rgb(255, 255, 255)",
305
  )
306
+ # The image_input_component is now defined above the row for clarity
 
 
 
 
 
 
 
 
 
 
 
307
 
308
+ # The WebRTC component itself is the primary input for audio/video.
309
+ # The additional image input needs to be correctly wired.
310
+ # `webrtc.stream` will pass its inputs to the handler's `self.latest_args`.
311
+ # If webrtc is input[0] and image_input_component is input[1], then
312
+ # in GeminiHandler, self.latest_args[0] is webrtc data, self.latest_args[1] is image_input_component data.
313
# This statement belongs inside the enclosing `with gr.Blocks(...)` body.
# The handler is passed as the first (function) argument: `stream` has no
# `handler_class` keyword, so naming it that way fails when the UI is built.
webrtc.stream(
    GeminiHandler(),
    # The WebRTC component comes first; the static image is the extra input
    # the handler reads from `self.latest_args[1]`.
    inputs=[webrtc, image_input_component],
    outputs=[webrtc],
    time_limit=180 if get_space() else None,
    concurrency_limit=2 if get_space() else None,
)
320
 
321
+ # The `stream.ui = demo` line might be for an older way of launching.
322
+ # For gr.Blocks, `demo.launch()` is standard.
323
+ # If `fastrtc.Stream` is meant to wrap `gr.Blocks`, its usage might differ.
324
+ # Assuming standard Gradio launch:
325
 
326
  if __name__ == "__main__":
327
+ if os.getenv("GEMINI_API_KEY") is None:
328
+ print("WARNING: GEMINI_API_KEY environment variable not set. The application may not work.")
329
+
330
+ # The original code had `stream.ui.launch()`. If `stream` is a `fastrtc.Stream` object,
331
+ # and it's meant to manage the Gradio app, then that's correct.
332
+ # If `demo` is the primary Gradio interface, then `demo.launch()` is used.
333
+ # Let's stick to the original pattern if `fastrtc.Stream` requires it.
334
+ # stream.ui = demo # This assignment might be specific to how fastrtc integrates
335
+
336
+ # The original structure seems to use fastrtc.Stream to build the UI, then replace its UI with gr.Blocks.
337
+ # This is a bit unusual. Let's check if fastrtc.Stream is used by WebRTC.
338
+ # The WebRTC.stream call implies it handles its own streaming logic with the handler.
339
+ # So `demo` should be the main UI.
340
+
341
+ demo.launch(server_port=7860, debug=True) # Added debug=True for development