mgokg committed
Commit 3d9d966 · verified · 1 Parent(s): d08d2ca

Update app.py

Files changed (1)
  1. app.py +34 -192
app.py CHANGED
@@ -3,7 +3,6 @@ import base64
 import os
 import time
 from io import BytesIO
-import functools # Added for to_thread
 
 import gradio as gr
 import numpy as np
@@ -17,72 +16,11 @@ from fastrtc import (
     wait_for_item,
 )
 from google import genai
-# Ensure genai.protos is accessible for Tool, FunctionDeclaration etc.
-# If not, you might need from google.generativeai.types import Tool, FunctionDeclaration, Schema, Part, Content
-# However, with live.connect using v1alpha, direct proto usage is often needed.
-from google.generativeai import protos # Explicitly import protos
 from gradio.utils import get_space
 from PIL import Image
-from googleapiclient.discovery import build # Added for Google Search
 
 load_dotenv()
 
-# --- Environment Variables for Google Search ---
-GOOGLE_SEARCH_API_KEY = os.getenv("GOOGLE_SEARCH_API_KEY")
-GOOGLE_CSE_ID = os.getenv("GOOGLE_CSE_ID")
-
-# --- Google Search Function ---
-async def perform_google_search_async(query: str, num_results: int = 3) -> str:
-    """
-    Performs a Google search using the Custom Search API and returns formatted results.
-    """
-    if not GOOGLE_SEARCH_API_KEY or not GOOGLE_CSE_ID:
-        print("Google Search API key or CSE ID not configured.")
-        return "Search functionality is not configured."
-    try:
-        loop = asyncio.get_running_loop()
-        # Create a partial function for the blocking call
-        partial_search = functools.partial(
-            build("customsearch", "v1", developerKey=GOOGLE_SEARCH_API_KEY).cse().list(
-                q=query, cx=GOOGLE_CSE_ID, num=num_results
-            ).execute
-        )
-        # Run the blocking call in a thread pool
-        res = await loop.run_in_executor(None, partial_search)
-
-        if 'items' in res and res['items']:
-            results = []
-            for item in res['items']:
-                title = item.get('title', 'N/A')
-                link = item.get('link', 'N/A')
-                snippet = item.get('snippet', 'N/A').replace("\n", " ")
-                results.append(f"Title: {title}\nLink: {link}\nSnippet: {snippet}\n---")
-            return "\n".join(results)
-        else:
-            return "No search results found."
-    except Exception as e:
-        print(f"Error during Google Search: {e}")
-        return f"An error occurred while searching: {str(e)}"
-
-# --- Define the Google Search Tool for Gemini ---
-# Using genai.protos directly as LiveSession client might expect raw protos
-google_search_tool = protos.Tool(
-    function_declarations=[
-        protos.FunctionDeclaration(
-            name="perform_google_search",
-            description="Performs a Google search for a given query and returns a summary of the top results. Use this for general web searches or finding specific information online.",
-            parameters=protos.Schema(
-                type=protos.Type.OBJECT,
-                properties={
-                    "query": protos.Schema(type=protos.Type.STRING, description="The search query to use for Google Search."),
-                    "num_results": protos.Schema(type=protos.Type.NUMBER, description="Optional. Number of search results to return (default is 3).")
-                },
-                required=["query"]
-            )
-        )
-    ]
-)
-
 
 def encode_audio(data: np.ndarray) -> dict:
     """Encode Audio data to send to the server"""
@@ -115,148 +53,68 @@ class GeminiHandler(AsyncAudioVideoStreamHandler):
         self.session = None
         self.last_frame_time = 0
         self.quit = asyncio.Event()
-        self.client = None # Store client
 
     def copy(self) -> "GeminiHandler":
         return GeminiHandler()
 
     async def start_up(self):
-        self.client = genai.Client( # Use self.client
+        client = genai.Client(
             api_key=os.getenv("GEMINI_API_KEY"), http_options={"api_version": "v1alpha"}
         )
-
-        # Configure Gemini to use the search tool
-        # Note: For v1alpha live client, config might be a dict or protos.StreamingConfig
-        # protos.ToolConfig and protos.FunctionCallingConfig might be needed for more control
-        # e.g. tool_config=protos.ToolConfig(function_calling_config=protos.FunctionCallingConfig(mode=protos.FunctionCallingConfig.Mode.ANY))
-
-        streaming_config = protos.StreamingConfig(
-            response_modalities=[protos.ResponseModality.AUDIO], # Use enum
-            tools=[google_search_tool]
-        )
-        # If you need to force tool usage or set mode:
-        # streaming_config.tool_config.CopyFrom(protos.ToolConfig(
-        #     function_calling_config=protos.FunctionCallingConfig(mode=protos.FunctionCallingConfig.Mode.ANY)
-        # ))
-
-
-        async with self.client.aio.live.connect(
-            model="gemini-2.0-flash-exp", # Or "gemini-1.5-flash-latest" which is known to support tools well
-            config=streaming_config,
+        config = {"response_modalities": ["AUDIO"]}
+        async with client.aio.live.connect(
+            model="gemini-2.0-flash-exp",
+            config=config, # type: ignore
         ) as session:
             self.session = session
-            print("Gemini session started.")
             while not self.quit.is_set():
                 turn = self.session.receive()
                 try:
-                    async for response_proto in turn: # response_proto is protos.Response
-                        # Check for function calls
-                        if response_proto.function_call and response_proto.function_call.name:
-                            fc = response_proto.function_call
-                            if fc.name == "perform_google_search":
-                                query = fc.args["query"]
-                                num_results = fc.args.get("num_results", 3) # Get optional num_results
-                                print(f"Gemini requested Google search for: '{query}' with {num_results} results.")
-
-                                search_results_text = await perform_google_search_async(query, int(num_results))
-                                print(f"Search results (first 200 chars): {search_results_text[:200]}...")
-
-                                # Send search results back to Gemini
-                                function_response_proto = protos.FunctionResponse(
-                                    name="perform_google_search",
-                                    response={"result": search_results_text} # Response must be a dict/struct
-                                )
-                                input_proto = protos.Input(function_response=function_response_proto)
-                                await self.session.send(input=input_proto)
-                                print("Sent search results back to Gemini.")
-
-                        # Handle audio data
-                        elif response_proto.audio_output and response_proto.audio_output.data:
-                            data = response_proto.audio_output.data
+                    async for response in turn:
+                        if data := response.data:
                             audio = np.frombuffer(data, dtype=np.int16).reshape(1, -1)
-                            self.audio_queue.put_nowait(audio)
-
-                        # You could also handle response_proto.text_output if needed
-
+                            self.audio_queue.put_nowait(audio)
                 except websockets.exceptions.ConnectionClosedOK:
-                    print("Gemini session connection closed normally.")
+                    print("connection closed")
                     break
-                except Exception as e:
-                    print(f"Error in Gemini session receive loop: {e}")
-                    # Consider how to handle errors, e.g., break or log and continue
-                    break # For now, break on error
-            print("Exited Gemini session receive loop.")
-
 
     async def video_receive(self, frame: np.ndarray):
         self.video_queue.put_nowait(frame)
 
         if self.session:
-            current_time = time.time()
-            if current_time - self.last_frame_time > 1: # Send image every 1 second
-                self.last_frame_time = current_time
-                # The original code sends a dict. For v1alpha, it might need to be wrapped in protos.Input
-                # For simplicity, keeping as dict and assuming SDK handles it.
-                # If issues, wrap: image_part = protos.Part(inline_data=protos.Blob(mime_type="image/jpeg", data=...))
-                # input_proto = protos.Input(parts=[image_part])
-                # await self.session.send(input=input_proto)
+            # send image every 1 second
+            print(time.time() - self.last_frame_time)
+            if time.time() - self.last_frame_time > 1:
+                self.last_frame_time = time.time()
                 await self.session.send(input=encode_image(frame))
-
-        # Handle additional image input from Gradio UI
-        if self.latest_args and len(self.latest_args) > 1 and self.latest_args[1] is not None:
-            # Assuming self.latest_args[1] is the numpy array from the gr.Image input
-            uploaded_image_data = self.latest_args[1]
-            await self.session.send(input=encode_image(uploaded_image_data))
-            # To avoid resending, you might want to clear it after sending
-            # self.latest_args[1] = None # Or handle state more robustly
-
+                if self.latest_args[1] is not None:
+                    await self.session.send(input=encode_image(self.latest_args[1]))
 
     async def video_emit(self):
         frame = await wait_for_item(self.video_queue, 0.01)
         if frame is not None:
             return frame
         else:
-            return np.zeros((100, 100, 3), dtype=np.uint8) # Default blank frame
+            return np.zeros((100, 100, 3), dtype=np.uint8)
 
     async def receive(self, frame: tuple[int, np.ndarray]) -> None:
-        # Audio from user's microphone
         _, array = frame
         array = array.squeeze()
-        audio_message_dict = encode_audio(array) # This is a dict
-
+        audio_message = encode_audio(array)
         if self.session:
-            # For v1alpha, input should be protos.Input.
-            # The SDK might convert the dict, but explicit is safer.
-            audio_data_bytes = base64.b64decode(audio_message_dict["data"])
-            audio_part = protos.Part(
-                audio_input=protos.AudioData(
-                    audio=audio_data_bytes,
-                    # sample_rate_hertz=self.input_sample_rate # If API needs it
-                )
-            )
-            input_proto = protos.Input(parts=[audio_part])
-            await self.session.send(input=input_proto)
-
+            await self.session.send(input=audio_message)
 
     async def emit(self):
-        # Audio to user's speakers (from Gemini)
         array = await wait_for_item(self.audio_queue, 0.01)
         if array is not None:
             return (self.output_sample_rate, array)
-        return None # Return None if no audio, Gradio handles it
+        return array
 
     async def shutdown(self) -> None:
-        print("Shutting down GeminiHandler...")
         if self.session:
             self.quit.set()
-            try:
-                await self.session.close()
-                print("Gemini session closed.")
-            except Exception as e:
-                print(f"Error closing Gemini session: {e}")
-            self.quit.clear()
-        # Clean up client if necessary, though it's managed by 'async with' in start_up
-        self.client = None
+            await self.session.close()
+            self.quit.clear()
 
 
 stream = Stream(
@@ -272,7 +130,7 @@ stream = Stream(
         "icon": "https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
         "pulse_color": "rgb(255, 255, 255)",
         "icon_button_color": "rgb(255, 255, 255)",
-        "title": "Gemini Audio Video Chat with Search", # Updated title
+        "title": "Gemini Audio Video Chat",
     },
 )
 
@@ -283,18 +141,12 @@ css = """
 with gr.Blocks(css=css) as demo:
     gr.HTML(
         """
-    <div style='display: flex; align-items: center; justify-content: center; gap: 20px'>
-        <div style="background-color: var(--block-background-fill); border-radius: 8px">
-            <img src="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png" style="width: 100px; height: 100px;">
-        </div>
         <div>
-            <h1>Gen AI SDK Voice Chat with Google Search</h1>
-            <p>Speak with Gemini using real-time audio + video streaming, now with Google Search capability!</p>
-            <p>Try saying: "Search for the weather in London" or "Google the latest AI news."</p>
-            <p>Powered by <a href="https://gradio.app/">Gradio</a> and <a href=https://freddyaboulton.github.io/gradio-webrtc/">WebRTC</a>⚡️</p>
-            <p>Get a Gemini API Key <a href="https://aistudio.google.com/app/apikey">here</a>. You'll also need a Google Search API Key and CSE ID.</p>
-        </div>
-    </div>
+        <center>
+        <h1>Gen AI Voice Chat</h1>
+        <p>real-time audio + video streaming</p>
+        </center>
+        </div>
     """
     )
     with gr.Row() as row:
@@ -309,17 +161,14 @@ with gr.Blocks(css=css) as demo:
             pulse_color="rgb(255, 255, 255)",
             icon_button_color="rgb(255, 255, 255)",
         )
-        with gr.Column():
-            image_input = gr.Image(
-                label="Image (optional, sent with video frames)", type="numpy", sources=["upload", "clipboard"]
-            )
+        #with gr.Column():
+            #image_input = gr.Image(
+                #label="Image", type="numpy", sources=["upload", "clipboard"]
+            #)
 
-    # The WebRTC.stream method will pass these inputs to the handler's methods.
-    # The handler's __init__ or other methods might need to store/access `image_input` if needed beyond `latest_args`.
-    # The `latest_args` in `video_receive` comes from the `inputs` list here.
     webrtc.stream(
-        GeminiHandler(), # A new instance of GeminiHandler for each stream session
-        inputs=[webrtc, image_input], # webrtc is args[0], image_input is args[1]
+        GeminiHandler(),
+        inputs=[webrtc],
         outputs=[webrtc],
         time_limit=180 if get_space() else None,
         concurrency_limit=2 if get_space() else None,
@@ -329,16 +178,9 @@ stream.ui = demo
 
 
 if __name__ == "__main__":
-    if not os.getenv("GEMINI_API_KEY"):
-        print("GEMINI_API_KEY not found in environment variables. Please set it in your .env file.")
-    if not GOOGLE_SEARCH_API_KEY or not GOOGLE_CSE_ID:
-        print("GOOGLE_SEARCH_API_KEY or GOOGLE_CSE_ID not found. Search functionality will be limited.")
-
     if (mode := os.getenv("MODE")) == "UI":
         stream.ui.launch(server_port=7860)
     elif mode == "PHONE":
         raise ValueError("Phone mode not supported for this demo")
     else:
-        # Default to UI launch if MODE is not set or unrecognized
-        print("Launching Gradio UI...")
-        stream.ui.launch(server_port=7860)
+        stream.ui.launch(server_port=7860)
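
Note on the pattern the commit lands on: the proto-based tool plumbing is replaced by the plain dict config that the google-genai live API accepts. Below is a minimal standalone sketch of that connect/receive loop, assuming the same GEMINI_API_KEY variable, v1alpha http_options, and gemini-2.0-flash-exp model as app.py; the text turn, the end_of_turn flag, and the main() wrapper are illustrative assumptions, not part of the committed file.

import asyncio
import os

import numpy as np
from google import genai


async def main() -> None:
    # Same client setup as app.py; the live API is served under v1alpha.
    client = genai.Client(
        api_key=os.getenv("GEMINI_API_KEY"), http_options={"api_version": "v1alpha"}
    )
    config = {"response_modalities": ["AUDIO"]}
    async with client.aio.live.connect(
        model="gemini-2.0-flash-exp", config=config
    ) as session:
        # Illustrative single text turn; app.py streams audio/image dicts instead.
        await session.send(input="Hello", end_of_turn=True)
        turn = session.receive()
        async for response in turn:
            if data := response.data:
                # Decode 16-bit PCM the same way the handler's receive loop does.
                audio = np.frombuffer(data, dtype=np.int16).reshape(1, -1)
                print(f"received {audio.shape[1]} samples")


if __name__ == "__main__":
    asyncio.run(main())

The dict config is enough here because the commit drops tool declarations entirely; anything beyond response_modalities (tools, tool_config) would reintroduce the typed-config question the removed code was wrestling with.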