Update app.py
app.py CHANGED
@@ -1,843 +1,191 @@
-import os
 import asyncio
 import base64
-import
-import
-import
-
-import PIL.Image
-import mss
-import mss.tools
 import gradio as gr
 from google import genai
-from
 
-GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
-if not GEMINI_API_KEY:
-    raise ValueError("GEMINI_API_KEY environment variable not set.")
 
-# Audio settings
-PYAUDIO_FORMAT = pyaudio.paInt16
-CHANNELS = 1
-SEND_SAMPLE_RATE = 16000  # Sample rate for audio sent to Gemini
-RECEIVE_SAMPLE_RATE = 24000  # Sample rate for audio received from Gemini (Puck voice)
-CHUNK_SIZE = 1024
-
 
-# Streaming Modes
-VIDEO_MODE_CAMERA = "camera"
-VIDEO_MODE_SCREEN = "screen"
-VIDEO_MODE_NONE = "none"  # Added for audio/text only
-DEFAULT_VIDEO_MODE = VIDEO_MODE_CAMERA
-
 
-# --- GeminiStreamingClient Class ---
-class GeminiStreamingClient:
-    def __init__(self, video_mode=DEFAULT_VIDEO_MODE,
-                 on_text_received=None, on_audio_received=None, on_error=None):
-        self.video_mode = video_mode
-        self.on_text_received = on_text_received
-        self.on_audio_received = on_audio_received
-        self.on_error = on_error
-
-        self
-
-        self.capture_device = None  # For camera
-
-        self.genai_client = genai.GenerativeModel(
-            MODEL_NAME,
-            system_instruction=types.Content(
-                parts=[types.Part.from_text(text="Du bist ein hilfreicher Assistent. Antworte immer auf Deutsch.")],
-                role="user"  # System instructions are typically role 'user' or 'model' then 'user'
-            )
         )
-
-    async def _get_frame_bytes(self):
-        if not self.capture_device or not self.capture_device.isOpened():
-            print("Camera not initialized or opened.")
-            await asyncio.sleep(1)  # Prevent tight loop if camera fails
-            return None
-        ret, frame = await asyncio.to_thread(self.capture_device.read)
-        if not ret:
-            print("Failed to grab frame from camera.")
-            return None
-        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-        img = PIL.Image.fromarray(frame_rgb)
-        img.thumbnail((1024, 1024))  # Resize
-        image_io = io.BytesIO()
-        img.save(image_io, format="JPEG")
-        image_io.seek(0)
-        return {"mime_type": "image/jpeg", "data": base64.b64encode(image_io.read()).decode()}
-
-    async def _get_screen_bytes(self):
-        with mss.mss() as sct:
-            monitor = sct.monitors[1]  # Primary monitor
-            sct_img = sct.grab(monitor)
-            img = PIL.Image.frombytes("RGB", sct_img.size, sct_img.rgb, "raw", "RGB")
-            image_io = io.BytesIO()
-            img.save(image_io, format="JPEG")
-            image_io.seek(0)
-            return {"mime_type": "image/jpeg", "data": base64.b64encode(image_io.read()).decode()}
-
-    async def _stream_visual_media(self):
-        if self.video_mode == VIDEO_MODE_CAMERA:
-            self.capture_device = await asyncio.to_thread(cv2.VideoCapture, 0)
-            if not self.capture_device.isOpened():
-                print("Error: Could not open camera.")
-                if self.on_error: self.on_error("Could not open camera.")
-                return
-            get_media_bytes = self._get_frame_bytes
-        elif self.video_mode == VIDEO_MODE_SCREEN:
-            get_media_bytes = self._get_screen_bytes
-        else:  # VIDEO_MODE_NONE or unknown
-            return  # No visual media to stream
-
-        while self.is_running:
-            try:
-                media_data = await get_media_bytes()
-                if media_data:
-                    await self.media_out_queue.put(media_data)
-                await asyncio.sleep(1.0)  # Capture frame every second
-            except Exception as e:
-                print(f"Error in visual media stream: {e}")
-                traceback.print_exc()
-                if self.on_error: self.on_error(f"Visual media stream error: {e}")
-                await asyncio.sleep(1)  # Avoid tight loop on error
-
-        if self.video_mode == VIDEO_MODE_CAMERA and self.capture_device:
-            self.capture_device.release()
-            self.capture_device = None
-
-    async def _listen_microphone(self):
-        try:
-            mic_info = self.pya.get_default_input_device_info()
-            self.mic_stream = self.pya.open(
-                format=PYAUDIO_FORMAT,
-                channels=CHANNELS,
-                rate=SEND_SAMPLE_RATE,
-                input=True,
-                input_device_index=mic_info["index"],
-                frames_per_buffer=CHUNK_SIZE,
-            )
-        except Exception as e:
-            print(f"Error opening microphone stream: {e}")
-            if self.on_error: self.on_error(f"Microphone error: {e}")
-            return
-
-        print("Microphone listener started.")
-        while self.is_running and self.mic_stream:
-            try:
-                data = await asyncio.to_thread(self.mic_stream.read, CHUNK_SIZE, exception_on_overflow=False)
-                await self.media_out_queue.put({"mime_type": "audio/pcm", "data": data, "sample_rate": SEND_SAMPLE_RATE})
-            except IOError as e:  # Stream closed or other issue
-                if e.errno == pyaudio.paInputOverflowed:
-                    print("Input overflowed. Skipping.")
-                    continue
-                print(f"Error reading from microphone: {e}")
-                if self.on_error: self.on_error(f"Mic read error: {e}")
-                break  # Exit loop on significant error
-            except Exception as e:
-                print(f"Unexpected error in microphone listener: {e}")
-                traceback.print_exc()
-                if self.on_error: self.on_error(f"Mic listener error: {e}")
-                break
-        print("Microphone listener stopped.")
-
-    async def _play_gemini_audio(self):
-        try:
-            self.speaker_stream = self.pya.open(
-                format=PYAUDIO_FORMAT,
-                channels=CHANNELS,
-                rate=RECEIVE_SAMPLE_RATE,
-                output=True
-            )
-        except Exception as e:
-            print(f"Error opening speaker stream: {e}")
-            if self.on_error: self.on_error(f"Speaker error: {e}")
-            return
-
-        print("Audio playback started.")
-        while self.is_running or not self.audio_in_queue.empty():  # Process remaining queue even if stopping
-            try:
-                audio_chunk = await self.audio_in_queue.get()
-                if audio_chunk is None:  # Sentinel for stopping
-                    break
-                if self.speaker_stream:
-                    await asyncio.to_thread(self.speaker_stream.write, audio_chunk)
-                self.audio_in_queue.task_done()
-            except Exception as e:
-                print(f"Error playing audio: {e}")
-                if self.on_error: self.on_error(f"Audio playback error: {e}")
-                # Don't break the loop for playback errors, just log and continue
-        print("Audio playback stopped.")
-
-    async def _process_gemini_responses(self):
-        print("Starting to process Gemini responses...")
-        try:
-            # The new API uses generate_content with a stream of Parts.
-            # We need to build up the content to send.
-            # This part needs careful handling of how media_out_queue items are consumed.
-            # For a continuous conversation, we'd typically send an initial prompt,
-            # then subsequent inputs (audio/video/text) as parts of the ongoing conversation.
-
-            # This simplified model sends one "turn" at a time based on media_out_queue.
-            # A more robust solution would manage conversation history.
-
-            # `generate_content(stream=True)` expects an iterable of `Part` or `Content` objects.
-            # We'll create a generator that yields content from our `media_out_queue`.
-            async def content_generator():
-                while self.is_running or not self.media_out_queue.empty():
-                    try:
-                        item = await asyncio.wait_for(self.media_out_queue.get(), timeout=0.1)
-                        if item is None:  # Sentinel
-                            break
-
-                        content_parts = []
-                        if "text" in item:
-                            content_parts.append(types.Part.from_text(item["text"]))
-                        elif item["mime_type"].startswith("image/"):
-                            content_parts.append(types.Part.from_data(
-                                data=base64.b64decode(item["data"]),
-                                mime_type=item["mime_type"]
-                            ))
-                        elif item["mime_type"] == "audio/pcm":
-                            # For audio, it's better to send it as part of the overall content.
-                            # The API expects audio data directly.
-                            content_parts.append(types.Part.from_data(
-                                data=item["data"],
-                                mime_type=item["mime_type"]  # or audio/wav if converted
-                            ))
-
-                        if content_parts:
-                            # print(f"Sending content to Gemini: {content_parts}")
-                            yield types.Content(parts=content_parts, role="user")  # Each item from queue is a new user turn
-                        self.media_out_queue.task_done()
-
-                    except asyncio.TimeoutError:
-                        continue  # No new media, continue checking
-                    except Exception as e:
-                        print(f"Error in content_generator: {e}")
-                        if self.on_error: self.on_error(f"Content gen error: {e}")
-                        break
-                print("Content generator finished.")
-
-            # Configuration for audio output (if the model supports it directly).
-            # This is a bit different from the old LiveConnectConfig.
-            generation_config = types.GenerationConfig(
-                # candidate_count=1,  # default
-                # stop_sequences=[],
-                # max_output_tokens=2048,  # default
-                # temperature=0.9,  # default
-                response_mime_type="audio/pcm",  # Request audio output
-                response_schema=types.Schema(
-                    type=types.Type.OBJECT,
-                    properties={
-                        'audio_data': types.Schema(type=types.Type.STRING, format="byte", description="The audio data in PCM format."),
-                        'text_response': types.Schema(type=types.Type.STRING, description="The textual part of the response.")
-                    }
-                )
-            )
-            # Note: the above response_schema is an example. The actual way to get audio
-            # might be simpler if the model directly outputs it when asked.
-            # The "Puck" voice was part of LiveConnect; with the GenAI API it's more complex.
-            # Let's assume the model can provide audio if `response_mime_type="audio/pcm"` is set
-            # and the model supports it. If not, we'd need a separate TTS step.
-            # For now, we'll primarily focus on text responses and playing them if audio is somehow provided.
-
-            # `stream=True` with `generate_content` is for streaming *responses*.
-            # For streaming *requests*, the input to `generate_content` should be an iterable.
-
-            # This is a conceptual challenge: `generate_content` is typically called once per "turn".
-            # To have a continuous stream of input media, we might need to use the lower-level
-            # `chat_session` or structure this differently.
-            # For simplicity, let's assume we send a batch of media from the queue as one turn.
-
-            # Let's re-think: the original code used `client.aio.live.connect`, which is a persistent session.
-            # `GenerativeModel.generate_content` is more request-response, even if streamed.
-            # To replicate the live feel, we might need to send messages in a loop.
-
-            # For now, let's simplify: `_process_gemini_responses` handles one "turn" when `send_text_input` is called.
-            # The background audio/video will be collected and sent with that text.
-            # This is a deviation from the original continuous stream but fits `generate_content` better.
-            # A true continuous bi-directional stream might require the (now less common) Live API or a different approach.
-
-            # Let's revert to a model closer to the original `session.receive()` if possible,
-            # or adapt to how `generate_content(stream=True)` works for responses.
-            # The `stream=True` in `generate_content` means the *response* is streamed.
-            # We need to send our `media_out_queue` items as part of the *request*.
-
-            # This part is tricky with the standard GenAI Python SDK for a "live" feel.
-            # The original code used a specific `live.connect` endpoint.
-            # If we stick to `GenerativeModel`, we'd typically do:
-            #   model.start_chat()
-            #   response = chat.send_message(..., stream=True)
-            # This is still turn-based.
-
-            # Let's assume the goal is to send a collection of media (text, last audio, last image)
-            # and get a streamed response.
-
-            # This method will now be triggered by `send_text_input`.
-            # The `media_out_queue` will be drained to form the content for `send_message`.
-            # This is a significant change from the original's continuous background sending.
-            pass  # This method will be effectively merged into `send_text_input` logic for now.
-
-    async def start(self):
-        if self.is_running:
-            print("Session already running.")
-            return
-        print("Starting Gemini streaming client...")
-        self.is_running = True
-
-        self.tasks.append(asyncio.create_task(self._listen_microphone()))
-        if self.video_mode != VIDEO_MODE_NONE:
-            self.tasks.append(asyncio.create_task(self._stream_visual_media()))
-        self.tasks.append(asyncio.create_task(self._play_gemini_audio()))
-        # self.tasks.append(asyncio.create_task(self._process_gemini_responses()))  # Now handled differently
-
-        print(f"Client started with video mode: {self.video_mode}. Tasks: {len(self.tasks)}")
-
-    async def stop(self):
-        if not self.is_running:
-            print("Session not running.")
-            return
-        print("Stopping Gemini streaming client...")
-        self.is_running = False
-
-        # Signal media processing loops to stop
-        if self.video_mode != VIDEO_MODE_NONE and self.capture_device:
-            if self.video_mode == VIDEO_MODE_CAMERA:  # Only release if it's cv2.VideoCapture
-                if self.capture_device.isOpened():
-                    self.capture_device.release()
-            self.capture_device = None
-
-        if self.mic_stream:
-            self.mic_stream.stop_stream()
-            self.mic_stream.close()
-            self.mic_stream = None
-
-        await self.media_out_queue.put(None)  # Sentinel for content generator if it were still separate
-        await self.audio_in_queue.put(None)  # Sentinel for audio player
-
-        # Cancel and await tasks
-        for task in self.tasks:
-            task.cancel()
-        await asyncio.gather(*self.tasks, return_exceptions=True)
-        self.tasks = []
-
-        if self.speaker_stream:
-            self.speaker_stream.stop_stream()
-            self.speaker_stream.close()
-            self.speaker_stream = None
-
-        self.pya.terminate()
-        print("Client stopped.")
 
-
-            if self.on_error: self.on_error("Session not active. Cannot send text.")
-            return
 
                     break
-                temp_media_holder.append(media_item)
-                self.media_out_queue.task_done()
-            except asyncio.QueueEmpty:
-                break
-
-        # Prioritize one image and some recent audio
-        last_image_part = None
-        audio_parts = []
-
-        for item in reversed(temp_media_holder):  # Process most recent first
-            if item["mime_type"].startswith("image/") and not last_image_part:
-                last_image_part = types.Part.from_data(
-                    data=base64.b64decode(item["data"]),
-                    mime_type=item["mime_type"]
-                )
-            elif item["mime_type"] == "audio/pcm" and len(audio_parts) < 5:  # Limit audio segments
-                # The API expects raw bytes for audio/pcm
-                audio_parts.append(types.Part.from_data(data=item["data"], mime_type=item["mime_type"]))
-
-        if last_image_part:
-            content_parts.append(last_image_part)
-        content_parts.extend(reversed(audio_parts))  # Add audio in chronological order
-
-        # Re-queue any unused media items (not ideal, but simple for now)
-        # for item in temp_media_holder:
-        #     if (item["mime_type"].startswith("image/") and item != last_image_part_source) or \
-        #        (item["mime_type"] == "audio/pcm" and item not in audio_parts_sources):
-        #         await self.media_out_queue.put(item)
 
-        #
-
-            for part in chunk.parts:
-                if part.text:
-                    # print(part.text, end="", flush=True)  # Stream text to console
-                    full_response_text += part.text
-                    if self.on_text_received:  # Callback for Gradio UI
-                        # Send incremental text for streaming display
-                        await self.on_text_received(part.text, is_final=False)
-
-                # Check for audio data - this part is speculative for generate_content,
-                # as direct audio output like the "Puck" voice was specific to LiveConnect.
-                # If the model returns audio, it would likely be in `part.data` or `part.audio_data`.
-                # For example, if `response_mime_type="audio/pcm"` worked and returned raw bytes:
-                if hasattr(part, 'data') and part.mime_type and part.mime_type.startswith("audio/"):
-                    print(f"Received audio chunk of type {part.mime_type}")
-                    await self.audio_in_queue.put(part.data)
-                elif hasattr(part, 'inline_data') and part.inline_data.mime_type.startswith("audio/"):
-                    # This is how function calling results with audio might look
-                    print(f"Received inline audio chunk of type {part.inline_data.mime_type}")
-                    await self.audio_in_queue.put(part.inline_data.data)
-
-            if self.on_text_received and full_response_text:  # Send final accumulated text
-                await self.on_text_received(full_response_text, is_final=True)
-            # print()  # Newline after streaming full response
-
-        except Exception as e:
-            print(f"Error during Gemini communication: {e}")
-            traceback.print_exc()
-            if self.on_error: self.on_error(f"Gemini API error: {e}")
-            if self.on_text_received:  # Clear any partial text
-                await self.on_text_received(f"Error: {e}", is_final=True)
-
-
-# --- Gradio UI ---
-async def build_gradio_app():
-    # Gradio State
-    chat_history_var = gr.State([])
-    client_session_var = gr.State(None)
-    current_bot_message_var = gr.State("")  # To accumulate streaming response
-
-    async def handle_text_input(text_input, chat_history, client_session, current_bot_message):
-        if not client_session:
-            gr.Warning("Session not started. Please start the session first.")
-            return chat_history, "", current_bot_message  # No change to text input
-
-        # Add user message to chat
-        chat_history.append({"role": "user", "content": text_input})
-
-        # Clear current bot message accumulator before new response
-        current_bot_message = ""
-
-        # Send text to the Gemini client (which will also pick up queued media).
-        # The client will use callbacks to update the chat_history for the bot's response.
-        asyncio.create_task(client_session.send_text_input(text_input))
-
-        # Return updated history (user part) and clear the input box.
-        # The bot response will be added via callback.
-        return chat_history, "", current_bot_message
-
-    async def update_chatbot_display(text_chunk, is_final, chat_history, current_bot_message):
-        if not chat_history:  # Should not happen if user message was added
-            chat_history.append({"role": "assistant", "content": ""})
-
-        if is_final:
-            # If it's the final message, ensure the last entry is the complete bot message
-            if chat_history and chat_history[-1]["role"] == "assistant":
-                chat_history[-1]["content"] = current_bot_message + text_chunk  # Append final chunk
-            else:  # Should not happen if streaming correctly
-                chat_history.append({"role": "assistant", "content": current_bot_message + text_chunk})
-            current_bot_message = ""  # Reset accumulator
         else:
-
-        # The final text will update the chatbot.
-        # A more advanced Gradio setup would use `gr.Textbox.stream` or similar.
-
-        # This callback will be passed to the client.
-        # It needs to update `chat_history_var` and `current_bot_message_var`
-        # and then trigger an update of the `chatbot_display`.
-        # This is complex because Gradio's state updates are tied to its event loop.
-
-        # Let's simplify: the callback will update a shared queue, and Gradio will poll it.
-        # Or, for this example, let the callback directly try to update,
-        # understanding it might have issues if not on Gradio's main thread.
-        # `update_chatbot_display` will be the target.
-
-        # This is where the refactor gets tricky with Gradio's model.
-        # The `update_chatbot_display` function is designed to be a Gradio output function.
-        # We can't easily call it directly with new state.
-
-        # Alternative: the `on_text_received` callback in the client will update `chat_history_var`
-        # and `current_bot_message_var` directly. Then, we need a way to make Gradio re-render
-        # the chatbot. This is often done by having the function that *triggers* the action
-        # also return the updated chatbot state.
-
-        # For now, let's make the callback simple:
-        # it will print, and we'll handle the final update in the `handle_text_input` response.
-        # This means no live streaming text in the Gradio UI, only the final response.
-        # To get live streaming, `handle_text_input` would need to yield updates.
-
-        # Let's try to make `update_chatbot_display` usable as a callback target
-        # by having it update the state variables.
-
-        # This is a conceptual placeholder. The actual update will be managed
-        # by how `handle_text_input` is structured if it were to support streaming yields.
-        # For now, the client's `on_text_received` will be simpler.
-        pass
-
-    async def start_stop_session(action, video_mode, current_client_session, current_chat_history, current_bot_msg):
-        if action == "Start Session":
-            if current_client_session:
-                gr.Info("Session already active.")
-                return "Stop Session", current_client_session, current_chat_history, current_bot_msg, gr.Button(interactive=True)
-
-            gr.Info(f"Starting session with mode: {video_mode}...")
-
-            # This list will store text chunks for the chatbot.
-            # It will be updated by the callback.
-            _chat_history_accumulator = list(current_chat_history)  # Make a mutable copy
-            _current_bot_message_accumulator = str(current_bot_msg)
-
-            async def ui_on_text_received(text_chunk, is_final):
-                nonlocal _current_bot_message_accumulator  # Allow modification
-                # This callback is tricky because it needs to update Gradio UI elements,
-                # which are typically updated by returning values from event handlers.
-                # For streaming, event handlers can be generators (yield).
-                # Here, the text comes from a background task.
-                print(f"UI_TEXT_CB: {text_chunk[:50]}... (final: {is_final})")
-
-                if is_final:
-                    if _chat_history_accumulator and _chat_history_accumulator[-1]["role"] == "assistant":
-                        _chat_history_accumulator[-1]["content"] = _current_bot_message_accumulator + text_chunk
-                    else:
-                        _chat_history_accumulator.append({"role": "assistant", "content": _current_bot_message_accumulator + text_chunk})
-                    _current_bot_message_accumulator = ""
-                else:
-                    _current_bot_message_accumulator += text_chunk
-                    if _chat_history_accumulator and _chat_history_accumulator[-1]["role"] == "assistant":
-                        _chat_history_accumulator[-1]["content"] = _current_bot_message_accumulator
-                    elif not _chat_history_accumulator or _chat_history_accumulator[-1]["role"] == "user":
-                        _chat_history_accumulator.append({"role": "assistant", "content": _current_bot_message_accumulator})
-
-                # This callback doesn't directly return to Gradio to update the UI.
-                # The UI update for the chatbot will happen when `handle_text_input` completes,
-                # or if `start_stop_session` could yield updates (it can't easily for background events).
-                # This is a common challenge with Gradio and background tasks updating the UI.
-                # For now, `_chat_history_accumulator` is updated, and `handle_text_input` will use it.
-
-            def ui_on_audio_received(audio_chunk):
-                # PyAudio in the client handles playback. This is for potential Gradio audio out.
-                # print(f"UI_AUDIO_CB: Received audio chunk of {len(audio_chunk)} bytes.")
-                # This would yield to gr.Audio if we used it for playback.
-                pass  # Let PyAudio in client handle playback
-
-            def ui_on_error(error_msg):
-                gr.Error(f"Session Error: {error_msg}")
-                # Potentially try to stop the session here or update UI state
-                print(f"UI_ERROR_CB: {error_msg}")
-
-            client = GeminiStreamingClient(
-                video_mode=video_mode,
-                on_text_received=ui_on_text_received,  # This callback needs to update Gradio state
-                on_audio_received=ui_on_audio_received,
-                on_error=ui_on_error
             )
-
-            gr.
-
-            # This is indirect. The callback `ui_on_text_received` should ideally update the gr.State.
-            # For now, we pass the accumulator list.
-            return "Stop Session", client, _chat_history_accumulator, _current_bot_message_accumulator, gr.Button(interactive=True)
-
-        elif action == "Stop Session":
-            if current_client_session:
-                gr.Info("Stopping session...")
-                await current_client_session.stop()
-                gr.Info("Session stopped.")
-                return "Start Session", None, current_chat_history, "", gr.Button(interactive=True)  # Clear client, keep history
-            gr.Info("No active session to stop.")
-            return "Start Session", None, current_chat_history, "", gr.Button(interactive=True)
-
-    with gr.Blocks(theme=gr.themes.Soft()) as app:
-        gr.Markdown("# Gemini Live Streaming Chat")
-        gr.Markdown(f"Using Model: `{MODEL_NAME}`. Ensure your `GEMINI_API_KEY` is set.")
-
-        with gr.Row():
-            video_mode_dropdown = gr.Dropdown(
-                choices=[VIDEO_MODE_CAMERA, VIDEO_MODE_SCREEN, VIDEO_MODE_NONE],
-                value=DEFAULT_VIDEO_MODE,
-                label="Video/Screen Input Mode",
-                interactive=True
             )
-            start_stop_button = gr.Button("Start Session")
-
-        chatbot_display = gr.Chatbot(
-            label="Conversation",
-            bubble_full_width=False,
-            height=600
-        )
-        # This audio output is for if Gemini sends audio that Gradio should play.
-        # Our client plays it via PyAudio, so this might be redundant or for a different use.
-        # audio_output_display = gr.Audio(label="Gemini Response Audio", autoplay=True, streaming=True, interactive=False)
-
-        text_input_box = gr.Textbox(
-            label="Send a message",
-            placeholder="Type your message here or just talk (if mic is active)...",
-            interactive=True
-        )
-        submit_button = gr.Button("Send", interactive=False)  # Disabled until session starts
 
-            fn=start_stop_session,
-            inputs=[start_stop_button, video_mode_dropdown, client_session_var, chat_history_var, current_bot_message_var],
-            outputs=[start_stop_button, client_session_var, chat_history_var, current_bot_message_var, text_input_box, submit_button]  # Also update interactivity of text input
-        ).then(
-            fn=start_stop_button_update,
-            inputs=[client_session_var],
-            outputs=[text_input_box, submit_button]
-        )
-
-        # When text is submitted (via Enter or the Send button).
-        # This is where the main interaction logic for text input happens.
-        # It needs to be a generator to stream responses to the chatbot.
-        async def process_and_stream_text(text_input_val, chat_history_list, client_session_obj, current_bot_msg_val):
-            if not client_session_obj:
-                gr.Warning("Session not started.")
-                # Yield current state to avoid clearing input if session not started
-                gradio_chat_tuples = []
-                for msg in chat_history_list:  # Convert history to display format
-                    gradio_chat_tuples.append((msg.get("content") if msg.get("role") == "user" else None,
-                                               msg.get("content") if msg.get("role") == "assistant" else None))
-                yield gradio_chat_tuples, chat_history_list, text_input_val, current_bot_msg_val
-                return
-
-            # Add user message to chat history state
-            chat_history_list.append({"role": "user", "content": text_input_val})
-
-            # Convert to Gradio display format
-            gradio_chat_tuples = []
-            for msg in chat_history_list:
-                gradio_chat_tuples.append((msg.get("content") if msg.get("role") == "user" else None,
-                                           msg.get("content") if msg.get("role") == "assistant" else None))
-
-            # Yield user message immediately
-            yield gradio_chat_tuples, chat_history_list, "", current_bot_msg_val  # Clear input box
-
-            # Prepare for the bot's response (streaming).
-            # The client's on_text_received callback will update chat_history_list and current_bot_msg_val.
-            # We need to make sure this function can access those updates.
-            # The `ui_on_text_received` callback in `start_stop_session` updates `_chat_history_accumulator`,
-            # which is then assigned to `chat_history_var`.
-
-            # This is the tricky part: how to make this generator aware of updates from the background callback?
-            # One way: the callback sets a flag or puts data in a queue that this generator polls.
-
-            # Let's redefine the client's text callback for this specific Gradio streaming context.
-            # This means the `GeminiStreamingClient` needs to be flexible with its callback.
-
-            # For now, let's assume `client_session_obj.send_text_input` will trigger the callback,
-            # and the callback updates `chat_history_list` and `current_bot_msg_val` (passed by reference).
-
-            # The `ui_on_text_received` callback (defined within `start_stop_session`)
-            # is already designed to modify `_chat_history_accumulator` and `_current_bot_message_accumulator`.
-            # When `start_stop_session` returns, these are put into `chat_history_var` and `current_bot_message_var`.
-            # So, `chat_history_list` and `current_bot_msg_val` here *should* be the updated ones.
-
-            # Trigger Gemini.
-            # The actual streaming of Gemini's response to the UI will be handled by how `send_text_input`
-            # and its callbacks are set up. If `ui_on_text_received` can trigger a re-yield here, that's ideal.
-            # Gradio's streaming usually involves the event handler itself yielding multiple times.
-
-            # Let's simplify: `send_text_input` is fire-and-forget here.
-            # The `ui_on_text_received` callback will be responsible for updating the shared state.
-            # This generator needs to periodically check that shared state and yield.
-
-            # This is a simplified approach: we send the text, then we assume the callback
-            # `ui_on_text_received` (which is set up when the session starts) will update
-            # `chat_history_list` (which is `chat_history_var`).
-            # We then just need to yield the final state of `chat_history_list`.
-            # This won't give live character-by-character streaming in the UI from this function alone.
-
-            # To achieve true streaming in the Gradio UI from a background task:
-            # 1. The background task (Gemini client) puts response chunks into an asyncio.Queue.
-            # 2. This Gradio event handler (`process_and_stream_text`) reads from that queue and yields.
-
-            # Let's modify `GeminiStreamingClient` to accept a queue for text output.
-            # For now, let's stick to the callback updating the shared `chat_history_list`.
-            # The `ui_on_text_received` callback needs to be robust.
-
-            # Send text to the client. The callback `ui_on_text_received` (configured during session start)
-            # will update `chat_history_list` and `current_bot_msg_val` "in the background".
-            await client_session_obj.send_text_input(text_input_val)
-
-            # After `send_text_input` completes (which means Gemini finished responding),
-            # `chat_history_list` should contain the full conversation.
-            # The `ui_on_text_received` callback should have populated it.
-
-            final_gradio_tuples = []
-            for msg in chat_history_list:  # chat_history_list is chat_history_var's value
-                final_gradio_tuples.append((msg.get("content") if msg.get("role") == "user" else None,
-                                            msg.get("content") if msg.get("role") == "assistant" else None))
-
-            # Yield the final state
-            yield final_gradio_tuples, chat_history_list, "", ""  # Clear bot message accumulator too
-
-        text_input_box.submit(
-            fn=process_and_stream_text,
-            inputs=[text_input_box, chat_history_var, client_session_var, current_bot_message_var],
-            outputs=[chatbot_display, chat_history_var, text_input_box, current_bot_message_var]
-        )
-        submit_button.click(
-            fn=process_and_stream_text,
-            inputs=[text_input_box, chat_history_var, client_session_var, current_bot_message_var],
-            outputs=[chatbot_display, chat_history_var, text_input_box, current_bot_message_var]
         )
-
-        # Graceful shutdown
-        async def on_close():
-            print("Gradio app is closing. Stopping client session if active.")
-            client_session = client_session_var.value  # How to get state here? This is tricky.
-            # Gradio doesn't have a clean "on_shutdown" hook
-            # that easily accesses gr.State values from Python.
-            # This would typically be handled by the user clicking "Stop Session".
-            # For now, this is a placeholder. Proper cleanup requires careful state management.
-            # A better way is to ensure the user stops the session via the button.
-            # If `client_session_var` could be accessed here, we'd do:
-            # if client_session_var.value:
-            #     await client_session_var.value.stop()
-            print("Cleanup logic in on_close needs robust state access or manual stop.")
 
-    return app
 
-# --- Main Execution ---
 if __name__ == "__main__":
-    if
-
     else:
-
-        gradio_app = asyncio.run(build_gradio_app())
-        gradio_app.queue()  # Enable queuing for handling multiple users or long processes
-        gradio_app.launch(debug=True)  # Share=True for public link

 import asyncio
 import base64
+import os
+import time
+from io import BytesIO
+
 import gradio as gr
+import numpy as np
+import websockets
+from dotenv import load_dotenv
+from fastrtc import (
+    AsyncAudioVideoStreamHandler,
+    Stream,
+    WebRTC,
+    get_cloudflare_turn_credentials_async,
+    wait_for_item,
+)
 from google import genai
+from gradio.utils import get_space
+from PIL import Image
 
+load_dotenv()
 
+def encode_audio(data: np.ndarray) -> dict:
+    """Encode Audio data to send to the server"""
+    return {
+        "mime_type": "audio/pcm",
+        "data": base64.b64encode(data.tobytes()).decode("UTF-8"),
+    }
 
+def encode_image(data: np.ndarray) -> dict:
+    with BytesIO() as output_bytes:
+        pil_image = Image.fromarray(data)
+        pil_image.save(output_bytes, "JPEG")
+        bytes_data = output_bytes.getvalue()
+    base64_str = str(base64.b64encode(bytes_data), "utf-8")
+    return {"mime_type": "image/jpeg", "data": base64_str}
 
+class GeminiHandler(AsyncAudioVideoStreamHandler):
+    def __init__(
+        self,
+    ) -> None:
+        super().__init__(
+            "mono",
+            output_sample_rate=24000,
+            input_sample_rate=16000,
         )
+        self.audio_queue = asyncio.Queue()
+        self.video_queue = asyncio.Queue()
+        self.session = None
+        self.last_frame_time = 0
+        self.quit = asyncio.Event()
 
+    def copy(self) -> "GeminiHandler":
+        return GeminiHandler()
 
+    async def start_up(self):
+        client = genai.Client(
+            api_key=os.getenv("GEMINI_API_KEY"), http_options={"api_version": "v1alpha"}
+        )
+        config = {"response_modalities": ["AUDIO"]}
+        async with client.aio.live.connect(
+            model="gemini-2.0-flash-exp",
+            config=config,  # type: ignore
+        ) as session:
+            self.session = session
+            while not self.quit.is_set():
+                turn = self.session.receive()
+                try:
+                    async for response in turn:
+                        if data := response.data:
+                            audio = np.frombuffer(data, dtype=np.int16).reshape(1, -1)
+                            self.audio_queue.put_nowait(audio)
+                except websockets.exceptions.ConnectionClosedOK:
+                    print("connection closed")
                     break
 
+    async def video_receive(self, frame: np.ndarray):
+        self.video_queue.put_nowait(frame)
+
+        if self.session:
+            # send image every 1 second
+            print(time.time() - self.last_frame_time)
+            if time.time() - self.last_frame_time > 1:
+                self.last_frame_time = time.time()
+                await self.session.send(input=encode_image(frame))
+                if self.latest_args[1] is not None:
+                    await self.session.send(input=encode_image(self.latest_args[1]))
+
+    async def video_emit(self):
+        frame = await wait_for_item(self.video_queue, 0.01)
+        if frame is not None:
+            return frame
         else:
+            return np.zeros((100, 100, 3), dtype=np.uint8)
+
+    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
+        _, array = frame
+        array = array.squeeze()
+        audio_message = encode_audio(array)
+        if self.session:
+            await self.session.send(input=audio_message)
+
+    async def emit(self):
+        array = await wait_for_item(self.audio_queue, 0.01)
+        if array is not None:
+            return (self.output_sample_rate, array)
+        return array
+
+    async def shutdown(self) -> None:
+        if self.session:
+            self.quit.set()
+            await self.session.close()
+            self.quit.clear()
+
+
+stream = Stream(
+    handler=GeminiHandler(),
+    modality="audio-video",
+    mode="send-receive",
+    rtc_configuration=get_cloudflare_turn_credentials_async,
+    time_limit=180 if get_space() else None,
+    additional_inputs=[
+        gr.Image(label="Image", type="numpy", sources=["upload", "clipboard"])
+    ],
+    ui_args={
+        "icon": "https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
+        "pulse_color": "rgb(255, 255, 255)",
+        "icon_button_color": "rgb(255, 255, 255)",
+        "title": "Gemini Audio Video Chat",
+    },
+)
+
+css = """
+#video-source {max-width: 600px !important; max-height: 600 !important;}
+"""
+
+with gr.Blocks(css=css) as demo:
+    gr.HTML(
+        """
+    <div style='display: flex; align-items: center; justify-content: center; gap: 20px'>
+        <div style="background-color: var(--block-background-fill); border-radius: 8px">
+            <img src="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png" style="width: 100px; height: 100px;">
+        </div>
+        <div>
+            <h1>Gen AI SDK Voice Chat</h1>
+            <p>Speak with Gemini using real-time audio + video streaming</p>
+            <p>Powered by <a href="https://gradio.app/">Gradio</a> and <a href=https://freddyaboulton.github.io/gradio-webrtc/">WebRTC</a>⚡️</p>
+            <p>Get an API Key <a href="https://support.google.com/googleapi/answer/6158862?hl=en">here</a></p>
+        </div>
+    </div>
+    """
+    )
+    with gr.Row() as row:
+        with gr.Column():
+            webrtc = WebRTC(
+                label="Video Chat",
+                modality="audio-video",
+                mode="send-receive",
+                elem_id="video-source",
+                rtc_configuration=get_cloudflare_turn_credentials_async,
+                icon="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
+                pulse_color="rgb(255, 255, 255)",
+                icon_button_color="rgb(255, 255, 255)",
            )
+        with gr.Column():
+            image_input = gr.Image(
+                label="Image", type="numpy", sources=["upload", "clipboard"]
            )
 
+    webrtc.stream(
+        GeminiHandler(),
+        inputs=[webrtc, image_input],
+        outputs=[webrtc],
+        time_limit=180 if get_space() else None,
+        concurrency_limit=2 if get_space() else None,
    )
 
+stream.ui = demo
 
 if __name__ == "__main__":
+    if (mode := os.getenv("MODE")) == "UI":
+        stream.ui.launch(server_port=7860)
+    elif mode == "PHONE":
+        raise ValueError("Phone mode not supported for this demo")
     else:
+        stream.ui.launch(server_port=7860)
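
For orientation, the sketch below (not part of the committed app.py) illustrates the payload formats the new handler exchanges with the Gemini live session: encode_audio wraps an int16 PCM chunk as a base64 "audio/pcm" message, encode_image wraps a frame as base64 JPEG, and start_up decodes returned PCM bytes back into a (1, N) int16 array. The helper bodies are copied from the new app.py above; the numpy arrays and sizes are synthetic, chosen only to match the handler's 16 kHz input rate.

# Standalone sketch, assuming only numpy and Pillow are installed.
import base64
from io import BytesIO

import numpy as np
from PIL import Image


def encode_audio(data: np.ndarray) -> dict:
    # Same shape of message that GeminiHandler.receive() sends over the session
    return {
        "mime_type": "audio/pcm",
        "data": base64.b64encode(data.tobytes()).decode("UTF-8"),
    }


def encode_image(data: np.ndarray) -> dict:
    # Same JPEG-over-base64 message that video_receive() sends once per second
    with BytesIO() as output_bytes:
        Image.fromarray(data).save(output_bytes, "JPEG")
        bytes_data = output_bytes.getvalue()
    return {"mime_type": "image/jpeg", "data": str(base64.b64encode(bytes_data), "utf-8")}


# A one-second mono chunk at input_sample_rate=16000 (synthetic silence)
audio_chunk = np.zeros(16000, dtype=np.int16)
print(encode_audio(audio_chunk)["mime_type"])  # audio/pcm

# A synthetic RGB frame, as the WebRTC track would deliver to video_receive()
frame = np.zeros((480, 640, 3), dtype=np.uint8)
print(encode_image(frame)["mime_type"])  # image/jpeg

# Receive side: the same frombuffer/reshape step start_up() applies to returned audio
pcm_bytes = audio_chunk.tobytes()
decoded = np.frombuffer(pcm_bytes, dtype=np.int16).reshape(1, -1)
print(decoded.shape)  # (1, 16000)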