mgokg committed
Commit cfcda0e · verified · 1 Parent(s): c3cd47d

Update app.py

Files changed (1):
  app.py +62 -218
app.py CHANGED
@@ -15,15 +15,10 @@ from fastrtc import (
     get_cloudflare_turn_credentials_async,
     wait_for_item,
 )
-from google import genai  # Assuming this is from google-generativeai or compatible
-from google.genai import types  # import Tool, FunctionDeclaration, ToolConfig, Part
-# from google.generativeai.types import Tool, FunctionDeclaration, ToolConfig, Part  # For tool calling
+from google import genai
 from gradio.utils import get_space
 from PIL import Image
 
-# For Google Search
-from googlesearch import search as google_search_engine
-
 load_dotenv()
 
 
@@ -44,42 +39,6 @@ def encode_image(data: np.ndarray) -> dict:
     return {"mime_type": "image/jpeg", "data": base64_str}
 
 
-def perform_google_search(query: str, num_results: int = 3) -> dict:
-    """
-    Performs a Google search and returns a summary of the results.
-
-    Args:
-        query: The search query.
-        num_results: The number of results to fetch.
-
-    Returns:
-        A dictionary suitable for Gemini's function response, containing
-        either a 'summary' of results or an 'error' message.
-    """
-    print(f"Performing Google search for: '{query}'...")
-    try:
-        search_results_links = []
-        # googlesearch returns a generator, so collect up to num_results links
-        count = 0
-        for url in google_search_engine(query, num_results=num_results, stop=num_results, pause=1.0):
-            search_results_links.append(url)
-            count += 1
-            if count >= num_results:
-                break
-
-        if not search_results_links:
-            return {"summary": "No direct results found on the web. You could try rephrasing your search."}
-
-        # Prepare a summary for Gemini
-        summary_text = "Found the following links based on your search:\n" + "\n".join(search_results_links)
-        print(f"Search results: {summary_text}")
-        return {"summary": summary_text}
-
-    except Exception as e:
-        print(f"Google search error: {e}")
-        if "HTTP Error 429" in str(e) or "429" in str(e):  # Handle rate limiting
-            return {"error": "The search service is temporarily busy (rate limited). Please try again in a moment."}
-        return {"error": f"An error occurred during the search: {str(e)}"}
-
-
 class GeminiHandler(AsyncAudioVideoStreamHandler):
     def __init__(
         self,
@@ -95,127 +54,41 @@ class GeminiHandler(AsyncAudioVideoStreamHandler):
         self.last_frame_time = 0
         self.quit = asyncio.Event()
 
-        # Define the Google Search tool for Gemini
-        self.google_search_tool_declaration = FunctionDeclaration(
-            name="perform_google_search_for_user",  # Name Gemini will use
-            description="Performs a Google search for a given query and returns a summary of the top results. Use this for general web searches.",
-            parameters={
-                "type": "OBJECT",
-                "properties": {
-                    "query": {"type": "STRING", "description": "The search query to look up on Google."}
-                },
-                "required": ["query"],
-            },
-        )
-        self.gemini_tools = [Tool(function_declarations=[self.google_search_tool_declaration])]
-
-
     def copy(self) -> "GeminiHandler":
         return GeminiHandler()
 
     async def start_up(self):
-        # Ensure GEMINI_API_KEY is set in your .env file or environment
-        api_key = os.getenv("GEMINI_API_KEY")
-        if not api_key:
-            raise ValueError("GEMINI_API_KEY not found in environment variables.")
-
-        # Using google.generativeai's standard client setup if `from google import genai` provides it.
-        # If `genai.Client` is from a different library, this part might need adjustment.
-        # Assuming `genai.Client` is akin to `google.generativeai.GenerativeServiceAsyncClient`,
-        # or a wrapper that `google-generativeai` provides.
-
-        # The original code uses `genai.Client(...)`; adapt that here.
-        # `http_options` is removed to use the default API version (likely v1beta, for tools).
-        # If `v1alpha` is strictly required by your `genai.Client`, tool support might be limited.
-        client = genai.Client(api_key=api_key)
-
-        # Configure Gemini for audio responses and tool usage
-        config = {
-            "response_modalities": ["AUDIO"],
-            "tool_config": ToolConfig(function_declarations=[self.google_search_tool_declaration]),
-        }
-
-        # Use a model known for good tool use and speed
-        model_name = "gemini-1.5-flash-latest"
-        print(f"Connecting to Gemini model: {model_name} with search tool enabled.")
-
-        try:
-            async with client.aio.live.connect(
-                model=model_name,
-                config=config,
-            ) as session:
-                self.session = session
-                print("Gemini session started successfully. You can now speak or ask to search.")
-                while not self.quit.is_set():
-                    current_turn = self.session.receive()  # This gets a LiveTurn object
-
-                    try:
-                        # First, process any incoming audio chunks from Gemini for this turn
-                        async for response_chunk in current_turn:  # Iterates over LiveResponseChunk
-                            if response_chunk.data:  # Audio data from Gemini
-                                audio = np.frombuffer(response_chunk.data, dtype=np.int16).reshape(1, -1)
-                                self.audio_queue.put_nowait(audio)
-
-                        # After processing all chunks, check whether Gemini requested a tool call in this turn
-                        if current_turn.tool_code and current_turn.tool_code.function_call:
-                            fc = current_turn.tool_code.function_call
-                            tool_name = fc.name
-                            tool_args = fc.args
-
-                            if tool_name == "perform_google_search_for_user":
-                                query = tool_args.get("query")
-                                if not query:
-                                    print("Error: 'query' argument missing for search tool.")
-                                    tool_response_part = Part.from_function_response(
-                                        name=tool_name,
-                                        response={"error": "Missing 'query' argument for search."},
-                                    )
-                                else:
-                                    print(f"Gemini requested search: '{query}'")
-                                    # Run the blocking search function in a separate thread
-                                    search_result_dict = await asyncio.to_thread(perform_google_search, query)
-
-                                    tool_response_part = Part.from_function_response(
-                                        name=tool_name,
-                                        response=search_result_dict,  # Pass the dict directly
-                                    )
-                                print(f"Sending search tool response to Gemini: {search_result_dict}")
-                                await self.session.send(input=[tool_response_part])
-                            else:
-                                print(f"Error: Gemini requested unknown tool: {tool_name}")
-                                tool_error_response = Part.from_function_response(
-                                    name=tool_name,
-                                    response={"error": f"Tool '{tool_name}' is not implemented."},
-                                )
-                                await self.session.send(input=[tool_error_response])
-
-                    except websockets.exceptions.ConnectionClosedOK:
-                        print("WebSocket connection closed by Gemini server.")
-                        break
-                    except Exception as e:
-                        print(f"Error processing a turn from Gemini: {e}")
-                        # Safest to break on unexpected errors in the loop
-                        break
-        except Exception as e:
-            print(f"Failed to connect to or maintain the Gemini session: {e}")
-            # Handle connection errors, API key issues, etc.
-        finally:
-            self.quit.set()  # Ensure shutdown is triggered if the loop exits
+        client = genai.Client(
+            api_key=os.getenv("GEMINI_API_KEY"), http_options={"api_version": "v1alpha"}
+        )
+        config = {"response_modalities": ["AUDIO"]}
+        async with client.aio.live.connect(
+            model="gemini-2.0-flash-exp",
+            config=config,  # type: ignore
+        ) as session:
+            self.session = session
+            while not self.quit.is_set():
+                turn = self.session.receive()
+                try:
+                    async for response in turn:
+                        if data := response.data:
+                            audio = np.frombuffer(data, dtype=np.int16).reshape(1, -1)
+                            self.audio_queue.put_nowait(audio)
+                except websockets.exceptions.ConnectionClosedOK:
+                    print("connection closed")
+                    break
 
     async def video_receive(self, frame: np.ndarray):
         self.video_queue.put_nowait(frame)
 
-        if self.session and not self.session.closed:
-            if time.time() - self.last_frame_time > 1:  # Send a video frame every 1 second
+        if self.session:
+            # send image every 1 second
+            print(time.time() - self.last_frame_time)
+            if time.time() - self.last_frame_time > 1:
                 self.last_frame_time = time.time()
-                try:
-                    await self.session.send(input=encode_image(frame))
-                    # latest_args[0] is the webrtc component, latest_args[1] is the image input
-                    if self.latest_args and len(self.latest_args) > 1 and self.latest_args[1] is not None:
-                        await self.session.send(input=encode_image(self.latest_args[1]))
-                except Exception as e:
-                    print(f"Error sending video/image to Gemini: {e}")
-
+                await self.session.send(input=encode_image(frame))
+                if self.latest_args[1] is not None:
+                    await self.session.send(input=encode_image(self.latest_args[1]))
 
     async def video_emit(self):
         frame = await wait_for_item(self.video_queue, 0.01)
@@ -224,50 +97,40 @@ class GeminiHandler(AsyncAudioVideoStreamHandler):
         else:
             return np.zeros((100, 100, 3), dtype=np.uint8)
 
-    async def receive(self, frame: tuple[int, np.ndarray]) -> None:  # Receives audio from the user
+    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        _, array = frame
        array = array.squeeze()
        audio_message = encode_audio(array)
-        if self.session and not self.session.closed:
-            try:
-                await self.session.send(input=audio_message)
-            except Exception as e:
-                print(f"Error sending audio to Gemini: {e}")
-
+        if self.session:
+            await self.session.send(input=audio_message)
 
-    async def emit(self):  # Emits audio from Gemini to the user
+    async def emit(self):
        array = await wait_for_item(self.audio_queue, 0.01)
        if array is not None:
            return (self.output_sample_rate, array)
-        return None  # Return None if there is no audio, as fastrtc expects
+        return array
 
     async def shutdown(self) -> None:
-        print("Shutting down GeminiHandler...")
-        self.quit.set()
-        if self.session and not self.session.closed:
-            try:
-                await self.session.close()
-                print("Gemini session closed.")
-            except Exception as e:
-                print(f"Error closing Gemini session: {e}")
-        self.session = None
+        if self.session:
+            self.quit.set()
+            await self.session.close()
+            self.quit.clear()
 
 
-# --- Gradio UI (largely unchanged) ---
-stream = Stream(  # This Stream object is for the deprecated gr.Interface way
+stream = Stream(
     handler=GeminiHandler(),
     modality="audio-video",
     mode="send-receive",
     rtc_configuration=get_cloudflare_turn_credentials_async,
     time_limit=180 if get_space() else None,
     additional_inputs=[
-        gr.Image(label="Optional Image Input", type="numpy", sources=["upload", "clipboard"])
+        gr.Image(label="Image", type="numpy", sources=["upload", "clipboard"])
     ],
     ui_args={
         "icon": "https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
         "pulse_color": "rgb(255, 255, 255)",
         "icon_button_color": "rgb(255, 255, 255)",
-        "title": "Gemini Audio Video Chat + Search",
+        "title": "Gemini Audio Video Chat",
     },
 )
 
@@ -275,28 +138,22 @@ css = """
 #video-source {max-width: 500px !important; max-height: 500px !important;}
 """
 
-with gr.Blocks(css=css, title="Gemini AV Chat + Search") as demo:
+with gr.Blocks(css=css) as demo:
     gr.HTML(
         """
         <div>
            <center>
-                <h1>Gen AI Voice Chat with Google Search</h1>
-                <p>Real-time audio + video streaming, with integrated Google Search via Gemini.</p>
-                <p><small>Note: Search uses web scraping, which may be rate-limited or unreliable for heavy use. For production, use official APIs.</small></p>
+                <h1>Gen AI Voice Chat</h1>
+                <p>real-time audio + video streaming</p>
            </center>
        </div>
        """
    )
-    # Additional input for an image (as in the original `stream` object).
-    # This needs to be passed to the handler if you want it to be used.
-    # The `webrtc.stream` inputs must match what the handler expects in `self.latest_args`.
-    image_input_component = gr.Image(label="Optional Image Input", type="numpy", sources=["upload", "clipboard"])
-
    with gr.Row() as row:
        with gr.Column():
            webrtc = WebRTC(
                label="Video Chat",
-                modality="audio-video",  # This component handles both audio and video from the user
+                modality="audio-video",
                mode="send-receive",
                elem_id="video-source",
                rtc_configuration=get_cloudflare_turn_credentials_async,
@@ -304,39 +161,26 @@ with gr.Blocks(css=css, title="Gemini AV Chat + Search") as demo:
                pulse_color="rgb(255, 255, 255)",
                icon_button_color="rgb(255, 255, 255)",
            )
-        # The image_input_component is now defined above the row for clarity
+        # with gr.Column():
+        #     image_input = gr.Image(
+        #         label="Image", type="numpy", sources=["upload", "clipboard"]
+        #     )
+
-    # The WebRTC component itself is the primary input for audio/video.
-    # The additional image input needs to be correctly wired.
-    # `webrtc.stream` will pass its inputs to the handler's `self.latest_args`.
-    # If webrtc is input[0] and image_input_component is input[1], then in GeminiHandler,
-    # self.latest_args[0] is the webrtc data and self.latest_args[1] is the image_input_component data.
-    webrtc.stream(
-        handler_class=GeminiHandler,  # Pass the class, not an instance, for gr.Blocks
-        inputs=[webrtc, image_input_component],  # webrtc carries audio/video; image_input_component is the static image
-        outputs=[webrtc],  # webrtc carries the audio/video output from Gemini
-        time_limit=180 if get_space() else None,
-        concurrency_limit=2 if get_space() else None,
-    )
+    webrtc.stream(
+        GeminiHandler(),
+        inputs=[webrtc],
+        outputs=[webrtc],
+        time_limit=180 if get_space() else None,
+        concurrency_limit=2 if get_space() else None,
+    )
 
-    # The `stream.ui = demo` line might be for an older way of launching.
-    # For gr.Blocks, `demo.launch()` is standard.
-    # If `fastrtc.Stream` is meant to wrap `gr.Blocks`, its usage might differ.
-    # Assuming a standard Gradio launch:
+stream.ui = demo
 
 
 if __name__ == "__main__":
-    if os.getenv("GEMINI_API_KEY") is None:
-        print("WARNING: GEMINI_API_KEY environment variable not set. The application may not work.")
-
-    # The original code had `stream.ui.launch()`. If `stream` is a `fastrtc.Stream` object,
-    # and it's meant to manage the Gradio app, then that's correct.
-    # If `demo` is the primary Gradio interface, then `demo.launch()` is used.
-    # Let's stick to the original pattern if `fastrtc.Stream` requires it.
-    # stream.ui = demo  # This assignment might be specific to how fastrtc integrates
-
-    # The original structure seems to use fastrtc.Stream to build the UI, then replace its UI with gr.Blocks.
-    # This is a bit unusual. Let's check whether fastrtc.Stream is used by WebRTC.
-    # The WebRTC.stream call implies it handles its own streaming logic with the handler.
-    # So `demo` should be the main UI.
-
-    demo.launch(server_port=7860, debug=True)  # debug=True for development
+    if (mode := os.getenv("MODE")) == "UI":
+        stream.ui.launch(server_port=7860)
+    elif mode == "PHONE":
+        raise ValueError("Phone mode not supported for this demo")
+    else:
+        stream.ui.launch(server_port=7860)
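
One detail of the new `start_up` receive loop worth calling out: each binary chunk from the Live API is reinterpreted as a (1, N) int16 frame before being queued for playback. A standalone sketch of that conversion (a minimal illustration, not part of the commit; `pcm_chunk_to_frame` is a name invented here, and it assumes the chunks are raw 16-bit PCM, as the handler's `dtype=np.int16` implies):

    import numpy as np

    def pcm_chunk_to_frame(data: bytes) -> np.ndarray:
        # Reinterpret the raw bytes as 16-bit samples and add a leading
        # channel axis, matching the (1, N) shape put on audio_queue.
        return np.frombuffer(data, dtype=np.int16).reshape(1, -1)

    # Four bytes of PCM -> one mono frame holding two samples.
    frame = pcm_chunk_to_frame(b"\x00\x01\x02\x03")
    assert frame.shape == (1, 2)

`emit` then hands each queued frame back to fastrtc as `(self.output_sample_rate, array)`, so the sample rate travels alongside the samples rather than being baked into the frame.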