mgokg committed on
Commit
4380ffd
·
verified ·
1 Parent(s): 1208a72

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +182 -29
app.py CHANGED
@@ -3,6 +3,7 @@ import base64
3
  import os
4
  import time
5
  from io import BytesIO
 
6
 
7
  import gradio as gr
8
  import numpy as np
@@ -16,11 +17,72 @@ from fastrtc import (
16
  wait_for_item,
17
  )
18
  from google import genai
 
 
 
 
19
  from gradio.utils import get_space
20
  from PIL import Image
 
21
 
22
  load_dotenv()
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  def encode_audio(data: np.ndarray) -> dict:
26
  """Encode Audio data to send to the server"""
@@ -53,68 +115,148 @@ class GeminiHandler(AsyncAudioVideoStreamHandler):
53
  self.session = None
54
  self.last_frame_time = 0
55
  self.quit = asyncio.Event()
 
56
 
57
  def copy(self) -> "GeminiHandler":
58
  return GeminiHandler()
59
 
60
  async def start_up(self):
61
- client = genai.Client(
62
  api_key=os.getenv("GEMINI_API_KEY"), http_options={"api_version": "v1alpha"}
63
  )
64
- config = {"response_modalities": ["AUDIO"]}
65
- async with client.aio.live.connect(
66
- model="gemini-2.0-flash-exp",
67
- config=config, # type: ignore
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  ) as session:
69
  self.session = session
 
70
  while not self.quit.is_set():
71
  turn = self.session.receive()
72
  try:
73
- async for response in turn:
74
- if data := response.data:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  audio = np.frombuffer(data, dtype=np.int16).reshape(1, -1)
76
- self.audio_queue.put_nowait(audio)
 
 
 
77
  except websockets.exceptions.ConnectionClosedOK:
78
- print("connection closed")
79
  break
 
 
 
 
 
 
80
 
81
  async def video_receive(self, frame: np.ndarray):
82
  self.video_queue.put_nowait(frame)
83
 
84
  if self.session:
85
- # send image every 1 second
86
- print(time.time() - self.last_frame_time)
87
- if time.time() - self.last_frame_time > 1:
88
- self.last_frame_time = time.time()
 
 
 
 
89
  await self.session.send(input=encode_image(frame))
90
- if self.latest_args[1] is not None:
91
- await self.session.send(input=encode_image(self.latest_args[1]))
 
 
 
 
 
 
 
92
 
93
  async def video_emit(self):
94
  frame = await wait_for_item(self.video_queue, 0.01)
95
  if frame is not None:
96
  return frame
97
  else:
98
- return np.zeros((100, 100, 3), dtype=np.uint8)
99
 
100
  async def receive(self, frame: tuple[int, np.ndarray]) -> None:
 
101
  _, array = frame
102
  array = array.squeeze()
103
- audio_message = encode_audio(array)
 
104
  if self.session:
105
- await self.session.send(input=audio_message)
 
 
 
 
 
 
 
 
 
 
 
106
 
107
  async def emit(self):
 
108
  array = await wait_for_item(self.audio_queue, 0.01)
109
  if array is not None:
110
  return (self.output_sample_rate, array)
111
- return array
112
 
113
  async def shutdown(self) -> None:
 
114
  if self.session:
115
  self.quit.set()
116
- await self.session.close()
117
- self.quit.clear()
 
 
 
 
 
 
118
 
119
 
120
  stream = Stream(
@@ -130,7 +272,7 @@ stream = Stream(
130
  "icon": "https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
131
  "pulse_color": "rgb(255, 255, 255)",
132
  "icon_button_color": "rgb(255, 255, 255)",
133
- "title": "Gemini Audio Video Chat",
134
  },
135
  )
136
 
@@ -146,10 +288,11 @@ with gr.Blocks(css=css) as demo:
146
  <img src="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png" style="width: 100px; height: 100px;">
147
  </div>
148
  <div>
149
- <h1>Gen AI SDK Voice Chat</h1>
150
- <p>Speak with Gemini using real-time audio + video streaming</p>
 
151
  <p>Powered by <a href="https://gradio.app/">Gradio</a> and <a href="https://freddyaboulton.github.io/gradio-webrtc/">WebRTC</a>⚡️</p>
152
- <p>Get an API Key <a href="https://support.google.com/googleapi/answer/6158862?hl=en">here</a></p>
153
  </div>
154
  </div>
155
  """
@@ -168,12 +311,15 @@ with gr.Blocks(css=css) as demo:
168
  )
169
  with gr.Column():
170
  image_input = gr.Image(
171
- label="Image", type="numpy", sources=["upload", "clipboard"]
172
  )
173
 
 
 
 
174
  webrtc.stream(
175
- GeminiHandler(),
176
- inputs=[webrtc, image_input],
177
  outputs=[webrtc],
178
  time_limit=180 if get_space() else None,
179
  concurrency_limit=2 if get_space() else None,
@@ -183,9 +329,16 @@ stream.ui = demo
183
 
184
 
185
  if __name__ == "__main__":
 
 
 
 
 
186
  if (mode := os.getenv("MODE")) == "UI":
187
  stream.ui.launch(server_port=7860)
188
  elif mode == "PHONE":
189
  raise ValueError("Phone mode not supported for this demo")
190
  else:
191
- stream.ui.launch(server_port=7860)
 
 
 
3
  import os
4
  import time
5
  from io import BytesIO
6
+ import functools # Added for to_thread
7
 
8
  import gradio as gr
9
  import numpy as np
 
17
  wait_for_item,
18
  )
19
  from google import genai
20
+ # Ensure genai.protos is accessible for Tool, FunctionDeclaration etc.
21
+ # If not, you might need from google.generativeai.types import Tool, FunctionDeclaration, Schema, Part, Content
22
+ # However, with live.connect using v1alpha, direct proto usage is often needed.
23
+ from google.generativeai import protos # Explicitly import protos
24
  from gradio.utils import get_space
25
  from PIL import Image
26
+ from googleapiclient.discovery import build # Added for Google Search
27
 
28
  load_dotenv()
29
 
30
+ # --- Environment Variables for Google Search ---
31
+ GOOGLE_SEARCH_API_KEY = os.getenv("GOOGLE_SEARCH_API_KEY")
32
+ GOOGLE_CSE_ID = os.getenv("GOOGLE_CSE_ID")
33
+
34
+ # --- Google Search Function ---
35
def _format_search_items(items) -> str:
    """Render Custom Search result items as the plain-text summary sent to Gemini."""
    entries = []
    for item in items:
        title = item.get("title", "N/A")
        link = item.get("link", "N/A")
        # Snippets can contain hard line breaks; flatten them so each result
        # stays on a single "Snippet:" line.
        snippet = item.get("snippet", "N/A").replace("\n", " ")
        entries.append(f"Title: {title}\nLink: {link}\nSnippet: {snippet}\n---")
    return "\n".join(entries)


def _run_google_search(query: str, num_results: int) -> dict:
    """Blocking helper: build the Custom Search client and execute one query."""
    service = build("customsearch", "v1", developerKey=GOOGLE_SEARCH_API_KEY)
    return service.cse().list(q=query, cx=GOOGLE_CSE_ID, num=num_results).execute()


async def perform_google_search_async(query: str, num_results: int = 3) -> str:
    """Run a Google Custom Search without blocking the event loop.

    Args:
        query: Search terms forwarded to the Custom Search API.
        num_results: Desired number of hits. Clamped to 1..10, the range the
            Custom Search API accepts for ``num``.

    Returns:
        A human-readable block with one "Title / Link / Snippet" entry per
        hit, or a short explanatory sentence when search is not configured,
        nothing was found, or the request failed. The function never raises:
        the text is handed back to the model as a tool result.
    """
    if not GOOGLE_SEARCH_API_KEY or not GOOGLE_CSE_ID:
        print("Google Search API key or CSE ID not configured.")
        return "Search functionality is not configured."
    try:
        bounded = max(1, min(10, int(num_results)))
        # The whole client construction + request runs in a worker thread;
        # previously only ``.execute`` was off-loaded while ``build(...)`` ran
        # on the event loop.
        res = await asyncio.to_thread(_run_google_search, query, bounded)
        items = res.get("items")
        if items:
            return _format_search_items(items)
        return "No search results found."
    except Exception as e:
        # Deliberate catch-all at the tool boundary: report the failure as
        # text instead of tearing down the live session.
        print(f"Error during Google Search: {e}")
        return f"An error occurred while searching: {str(e)}"
66
+
67
+ # --- Define the Google Search Tool for Gemini ---
68
+ # Using genai.protos directly as LiveSession client might expect raw protos
69
# Function-calling declaration advertised to the model: a single
# "perform_google_search" function taking a required string ``query`` and an
# optional ``num_results`` count.
# NOTE(review): these are google.generativeai protos, while the live session
# in this file is opened through the google.genai client - confirm that client
# accepts this object type in its ``tools`` list.
google_search_tool = protos.Tool(
    function_declarations=[
        protos.FunctionDeclaration(
            name="perform_google_search",
            description="Performs a Google search for a given query and returns a summary of the top results. Use this for general web searches or finding specific information online.",
            parameters=protos.Schema(
                type=protos.Type.OBJECT,
                properties={
                    "query": protos.Schema(
                        type=protos.Type.STRING,
                        description="The search query to use for Google Search.",
                    ),
                    # A result count is a whole number, so declare it as
                    # INTEGER rather than NUMBER (which also admits fractions).
                    "num_results": protos.Schema(
                        type=protos.Type.INTEGER,
                        description="Optional. Number of search results to return (default is 3).",
                    ),
                },
                required=["query"],
            ),
        )
    ]
)
85
+
86
 
87
  def encode_audio(data: np.ndarray) -> dict:
88
  """Encode Audio data to send to the server"""
 
115
  self.session = None
116
  self.last_frame_time = 0
117
  self.quit = asyncio.Event()
118
+ self.client = None # Store client
119
 
120
  def copy(self) -> "GeminiHandler":
      """Return a brand-new handler built with the default constructor.

      No state (session, client, queues) is carried over from this instance.
      NOTE(review): presumably the streaming framework calls this once per
      incoming connection - confirm against the fastrtc handler contract.
      """
      return GeminiHandler()
122
 
123
  async def start_up(self):
124
+ self.client = genai.Client( # Use self.client
125
  api_key=os.getenv("GEMINI_API_KEY"), http_options={"api_version": "v1alpha"}
126
  )
127
+
128
+ # Configure Gemini to use the search tool
129
+ # Note: For v1alpha live client, config might be a dict or protos.StreamingConfig
130
+ # protos.ToolConfig and protos.FunctionCallingConfig might be needed for more control
131
+ # e.g. tool_config=protos.ToolConfig(function_calling_config=protos.FunctionCallingConfig(mode=protos.FunctionCallingConfig.Mode.ANY))
132
+
133
+ streaming_config = protos.StreamingConfig(
134
+ response_modalities=[protos.ResponseModality.AUDIO], # Use enum
135
+ tools=[google_search_tool]
136
+ )
137
+ # If you need to force tool usage or set mode:
138
+ # streaming_config.tool_config.CopyFrom(protos.ToolConfig(
139
+ # function_calling_config=protos.FunctionCallingConfig(mode=protos.FunctionCallingConfig.Mode.ANY)
140
+ # ))
141
+
142
+
143
+ async with self.client.aio.live.connect(
144
+ model="gemini-2.0-flash-exp", # Or "gemini-1.5-flash-latest" which is known to support tools well
145
+ config=streaming_config,
146
  ) as session:
147
  self.session = session
148
+ print("Gemini session started.")
149
  while not self.quit.is_set():
150
  turn = self.session.receive()
151
  try:
152
+ async for response_proto in turn: # response_proto is protos.Response
153
+ # Check for function calls
154
+ if response_proto.function_call and response_proto.function_call.name:
155
+ fc = response_proto.function_call
156
+ if fc.name == "perform_google_search":
157
+ query = fc.args["query"]
158
+ num_results = fc.args.get("num_results", 3) # Get optional num_results
159
+ print(f"Gemini requested Google search for: '{query}' with {num_results} results.")
160
+
161
+ search_results_text = await perform_google_search_async(query, int(num_results))
162
+ print(f"Search results (first 200 chars): {search_results_text[:200]}...")
163
+
164
+ # Send search results back to Gemini
165
+ function_response_proto = protos.FunctionResponse(
166
+ name="perform_google_search",
167
+ response={"result": search_results_text} # Response must be a dict/struct
168
+ )
169
+ input_proto = protos.Input(function_response=function_response_proto)
170
+ await self.session.send(input=input_proto)
171
+ print("Sent search results back to Gemini.")
172
+
173
+ # Handle audio data
174
+ elif response_proto.audio_output and response_proto.audio_output.data:
175
+ data = response_proto.audio_output.data
176
  audio = np.frombuffer(data, dtype=np.int16).reshape(1, -1)
177
+ self.audio_queue.put_nowait(audio)
178
+
179
+ # You could also handle response_proto.text_output if needed
180
+
181
  except websockets.exceptions.ConnectionClosedOK:
182
+ print("Gemini session connection closed normally.")
183
  break
184
+ except Exception as e:
185
+ print(f"Error in Gemini session receive loop: {e}")
186
+ # Consider how to handle errors, e.g., break or log and continue
187
+ break # For now, break on error
188
+ print("Exited Gemini session receive loop.")
189
+
190
 
191
  async def video_receive(self, frame: np.ndarray):
192
  self.video_queue.put_nowait(frame)
193
 
194
  if self.session:
195
+ current_time = time.time()
196
+ if current_time - self.last_frame_time > 1: # Send image every 1 second
197
+ self.last_frame_time = current_time
198
+ # The original code sends a dict. For v1alpha, it might need to be wrapped in protos.Input
199
+ # For simplicity, keeping as dict and assuming SDK handles it.
200
+ # If issues, wrap: image_part = protos.Part(inline_data=protos.Blob(mime_type="image/jpeg", data=...))
201
+ # input_proto = protos.Input(parts=[image_part])
202
+ # await self.session.send(input=input_proto)
203
  await self.session.send(input=encode_image(frame))
204
+
205
+ # Handle additional image input from Gradio UI
206
+ if self.latest_args and len(self.latest_args) > 1 and self.latest_args[1] is not None:
207
+ # Assuming self.latest_args[1] is the numpy array from the gr.Image input
208
+ uploaded_image_data = self.latest_args[1]
209
+ await self.session.send(input=encode_image(uploaded_image_data))
210
+ # To avoid resending, you might want to clear it after sending
211
+ # self.latest_args[1] = None # Or handle state more robustly
212
+
213
 
214
  async def video_emit(self):
215
  frame = await wait_for_item(self.video_queue, 0.01)
216
  if frame is not None:
217
  return frame
218
  else:
219
+ return np.zeros((100, 100, 3), dtype=np.uint8) # Default blank frame
220
 
221
  async def receive(self, frame: tuple[int, np.ndarray]) -> None:
222
+ # Audio from user's microphone
223
  _, array = frame
224
  array = array.squeeze()
225
+ audio_message_dict = encode_audio(array) # This is a dict
226
+
227
  if self.session:
228
+ # For v1alpha, input should be protos.Input.
229
+ # The SDK might convert the dict, but explicit is safer.
230
+ audio_data_bytes = base64.b64decode(audio_message_dict["data"])
231
+ audio_part = protos.Part(
232
+ audio_input=protos.AudioData(
233
+ audio=audio_data_bytes,
234
+ # sample_rate_hertz=self.input_sample_rate # If API needs it
235
+ )
236
+ )
237
+ input_proto = protos.Input(parts=[audio_part])
238
+ await self.session.send(input=input_proto)
239
+
240
 
241
  async def emit(self):
242
+ # Audio to user's speakers (from Gemini)
243
  array = await wait_for_item(self.audio_queue, 0.01)
244
  if array is not None:
245
  return (self.output_sample_rate, array)
246
+ return None # Return None if no audio, Gradio handles it
247
 
248
  async def shutdown(self) -> None:
249
+ print("Shutting down GeminiHandler...")
250
  if self.session:
251
  self.quit.set()
252
+ try:
253
+ await self.session.close()
254
+ print("Gemini session closed.")
255
+ except Exception as e:
256
+ print(f"Error closing Gemini session: {e}")
257
+ self.quit.clear()
258
+ # Clean up client if necessary, though it's managed by 'async with' in start_up
259
+ self.client = None
260
 
261
 
262
  stream = Stream(
 
272
  "icon": "https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
273
  "pulse_color": "rgb(255, 255, 255)",
274
  "icon_button_color": "rgb(255, 255, 255)",
275
+ "title": "Gemini Audio Video Chat with Search", # Updated title
276
  },
277
  )
278
 
 
288
  <img src="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png" style="width: 100px; height: 100px;">
289
  </div>
290
  <div>
291
+ <h1>Gen AI SDK Voice Chat with Google Search</h1>
292
+ <p>Speak with Gemini using real-time audio + video streaming, now with Google Search capability!</p>
293
+ <p>Try saying: "Search for the weather in London" or "Google the latest AI news."</p>
294
  <p>Powered by <a href="https://gradio.app/">Gradio</a> and <a href="https://freddyaboulton.github.io/gradio-webrtc/">WebRTC</a>⚡️</p>
295
+ <p>Get a Gemini API Key <a href="https://aistudio.google.com/app/apikey">here</a>. You'll also need a Google Search API Key and CSE ID.</p>
296
  </div>
297
  </div>
298
  """
 
311
  )
312
  with gr.Column():
313
  image_input = gr.Image(
314
+ label="Image (optional, sent with video frames)", type="numpy", sources=["upload", "clipboard"]
315
  )
316
 
317
+ # The WebRTC.stream method will pass these inputs to the handler's methods.
318
+ # The handler's __init__ or other methods might need to store/access `image_input` if needed beyond `latest_args`.
319
+ # The `latest_args` in `video_receive` comes from the `inputs` list here.
320
  webrtc.stream(
321
+ GeminiHandler(), # A new instance of GeminiHandler for each stream session
322
+ inputs=[webrtc, image_input], # webrtc is args[0], image_input is args[1]
323
  outputs=[webrtc],
324
  time_limit=180 if get_space() else None,
325
  concurrency_limit=2 if get_space() else None,
 
329
 
330
 
331
  if __name__ == "__main__":
332
+ if not os.getenv("GEMINI_API_KEY"):
333
+ print("GEMINI_API_KEY not found in environment variables. Please set it in your .env file.")
334
+ if not GOOGLE_SEARCH_API_KEY or not GOOGLE_CSE_ID:
335
+ print("GOOGLE_SEARCH_API_KEY or GOOGLE_CSE_ID not found. Search functionality will be limited.")
336
+
337
  if (mode := os.getenv("MODE")) == "UI":
338
  stream.ui.launch(server_port=7860)
339
  elif mode == "PHONE":
340
  raise ValueError("Phone mode not supported for this demo")
341
  else:
342
+ # Default to UI launch if MODE is not set or unrecognized
343
+ print("Launching Gradio UI...")
344
+ stream.ui.launch(server_port=7860)