File size: 13,697 Bytes
d342f59
 
 
405ac70
 
 
 
b7a0a78
2d3e55a
fb9d6a0
405ac70
 
 
 
 
 
 
 
 
eab6f19
 
405ac70
2d3e55a
405ac70
b7a0a78
 
 
 
 
2d3e55a
 
405ac70
 
b7a0a78
 
 
 
 
 
 
 
2d3e55a
405ac70
fb9d6a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7a0a78
fb9d6a0
 
 
 
 
 
 
 
 
 
ba0e279
 
 
fb9d6a0
 
 
 
b7a0a78
fb9d6a0
 
 
b7a0a78
fb9d6a0
 
 
 
 
 
 
b7a0a78
2d3e55a
 
b7a0a78
 
 
 
2d3e55a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7a0a78
 
2d3e55a
 
 
 
405ac70
 
b7a0a78
 
405ac70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7a0a78
405ac70
 
b7a0a78
 
 
 
 
 
 
 
2d3e55a
 
b7a0a78
405ac70
b7a0a78
 
 
 
2d3e55a
b7a0a78
 
 
405ac70
 
b7a0a78
405ac70
b7a0a78
 
 
8a5cf65
b7a0a78
8a5cf65
b7a0a78
 
 
 
 
 
2d3e55a
b7a0a78
 
2d3e55a
 
b7a0a78
405ac70
2d3e55a
b7a0a78
2d3e55a
b7a0a78
405ac70
b7a0a78
8a5cf65
 
 
 
b7a0a78
8a5cf65
 
b7a0a78
 
 
 
8a5cf65
 
 
405ac70
b7a0a78
2d3e55a
b7a0a78
 
 
 
405ac70
 
 
b7a0a78
 
 
 
 
 
2d3e55a
405ac70
 
 
 
 
 
b7a0a78
 
 
 
 
 
 
 
 
2d3e55a
405ac70
 
b7a0a78
2d3e55a
b7a0a78
8a5cf65
 
 
 
 
 
 
e45ab03
 
8a5cf65
e45ab03
2d3e55a
 
 
 
 
7bc94a1
2d3e55a
7bc94a1
2d3e55a
e45ab03
405ac70
b7a0a78
405ac70
 
 
 
2d3e55a
 
 
 
 
 
 
b7a0a78
2d3e55a
 
 
 
 
 
 
 
 
 
 
 
 
7bc94a1
 
2d3e55a
 
 
 
 
 
 
405ac70
2d3e55a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7a0a78
2d3e55a
 
 
 
 
 
7bc94a1
2d3e55a
7bc94a1
2d3e55a
 
 
 
7bc94a1
2d3e55a
7bc94a1
2d3e55a
 
 
 
 
 
 
 
 
 
 
 
 
7bc94a1
b7a0a78
2d3e55a
 
 
 
7bc94a1
b7a0a78
405ac70
2d3e55a
 
 
 
 
405ac70
7bc94a1
 
 
 
 
 
 
 
2d3e55a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
# https://huggingface.co/spaces/freddyaboulton/gemini-audio-video-chat
# related demos: https://github.com/freddyaboulton/gradio-webrtc

import asyncio
import base64
import os
import time
import logging
import traceback
import cv2

import gradio as gr
import numpy as np
from google import genai
from gradio_webrtc import (
    AsyncAudioVideoStreamHandler,
    WebRTC,
    async_aggregate_bytes_to_16bit,
    VideoEmitType,
    AudioEmitType,
    get_twilio_turn_credentials,
)
import requests  # Use requests for synchronous Twilio check

# --- Setup Logging ---
# Root logger emits INFO and above, each record prefixed with a timestamp
# and its level name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# --- Global State ---
# True while a Gemini live session is believed to be open.
gemini_connected = False
# None until the startup Twilio check has run, then True/False.
twilio_available = None  # Will be set *before* Gradio initialization


# --- Helper Functions ---
def encode_audio(data: np.ndarray) -> dict:
    """Wrap raw audio samples in an "audio/pcm" message dictionary.

    Args:
        data: Audio samples as a NumPy array; its raw bytes are used as-is.

    Returns:
        A dictionary with "mime_type" set to "audio/pcm" and "data" holding
        the Base64 text of the array's bytes.

    Raises:
        TypeError: If ``data`` is not a NumPy array.
    """
    if isinstance(data, np.ndarray):
        try:
            raw_pcm = data.tobytes()
            payload = base64.b64encode(raw_pcm).decode("UTF-8")
            return {"mime_type": "audio/pcm", "data": payload}
        except Exception as e:
            # Record the failure, then let the caller decide how to react.
            logger.error(f"Error encoding audio: {e}")
            raise
    raise TypeError("encode_audio expected a numpy.ndarray")

def encode_image(data: np.ndarray, quality: int = 85) -> dict:
    """
    Encodes a NumPy image array as a Base64 JPEG message using OpenCV.

    8-bit (uint8) images are encoded unchanged. Any other integer or
    floating-point image is min-max rescaled to the 0-255 range first.

    Args:
        data: A NumPy array of shape (height, width, 3).
        quality: JPEG quality (0-100).

    Returns:
        A dictionary with keys "mime_type" ("image/jpeg") and "data"
        (the Base64-encoded JPEG bytes as a UTF-8 string).

    Raises:
        TypeError: If input is not a NumPy array, or its dtype is neither
            floating-point nor integer.
        ValueError: If input shape is incorrect or contains NaN/Inf.
        RuntimeError: If JPEG encoding fails.
    """

    # Input validation (type, shape and dimensions)
    if not isinstance(data, np.ndarray):
        raise TypeError("Input must be a NumPy array.")
    if data.ndim != 3 or data.shape[2] != 3:
        raise ValueError("Input array must have shape (n, n, 3).")
    if 0 in data.shape:
        raise ValueError("Input array cannot have a dimension of size 0.")

    # Validate the dtype *before* running numeric ufuncs on the data, so a
    # non-numeric array fails with this message rather than an unrelated
    # NumPy ufunc error.
    is_float = np.issubdtype(data.dtype, np.floating)
    if not (is_float or np.issubdtype(data.dtype, np.integer)):
        raise TypeError("Input array must have a floating-point or integer data type.")

    # Only floating-point data can hold NaN/Inf; skip the scan for integers.
    if is_float and not np.isfinite(data).all():
        raise ValueError("Input array contains NaN or Inf")

    if data.dtype == np.uint8:
        # Already 8-bit: encode as-is. Min-max normalising here would stretch
        # the contrast of ordinary frames and turn a uniform image black.
        scaled_data = np.ascontiguousarray(data)
    else:
        scaled_data = cv2.normalize(data, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)

    # JPEG encoding (with quality control and error handling)
    try:
        ok, buf = cv2.imencode(
            ".jpg", scaled_data, [int(cv2.IMWRITE_JPEG_QUALITY), int(quality)]
        )
    except Exception as e:
        raise RuntimeError(f"JPEG encoding failed: {e}") from e
    if not ok:
        raise RuntimeError("JPEG encoding failed: cv2.imencode failed")

    # Base64 encoding
    jpeg_bytes = np.asarray(buf).tobytes()
    base64_str = base64.b64encode(jpeg_bytes).decode('utf-8')

    return {"mime_type": "image/jpeg", "data": base64_str}

def check_twilio_availability_sync() -> bool:
    """Checks Twilio TURN server availability (synchronous version).

    Calls ``get_twilio_turn_credentials()`` up to three times, pausing two
    seconds between attempts, and records the outcome in the module-level
    ``twilio_available`` flag.

    The credential payload itself is deliberately never logged: it is handed
    to the WebRTC component as its RTC configuration and is presumably
    secret (TODO confirm against the gradio-webrtc / Twilio docs).

    Returns:
        True if a non-empty credential response was obtained, otherwise False.
    """
    global twilio_available
    retries = 3
    delay = 2

    for attempt in range(1, retries + 1):
        try:
            logger.info(f"Attempting to get Twilio credentials (attempt {attempt})...")
            credentials = get_twilio_turn_credentials()
        except requests.exceptions.RequestException as e:
            # Network-level failure: worth retrying.
            logger.warning(f"Attempt {attempt}: {e}")
            logger.warning(traceback.format_exc())
        except Exception as e:
            # Anything else (e.g. configuration problems) will not be fixed
            # by retrying, so give up immediately.
            logger.exception(f"Unexpected error checking Twilio: {e}")
            twilio_available = False
            return False
        else:
            if credentials:
                twilio_available = True
                logger.info("Twilio TURN server available.")
                return True
            logger.warning(f"Attempt {attempt}: Twilio returned no credentials.")

        # Back off before the next attempt, whatever the reason for the miss.
        if attempt < retries:
            time.sleep(delay)

    twilio_available = False
    logger.warning("Twilio TURN server unavailable.")
    return False



# --- Gemini Handler Class ---
class GeminiHandler(AsyncAudioVideoStreamHandler):
    """Relays a WebRTC audio/video stream to a Gemini live session.

    Incoming audio chunks and (throttled) video frames are encoded and sent
    to the session; audio data received from the session is queued and
    handed back through ``emit``. Video frames are echoed back unchanged
    through ``video_emit``.
    """

    def __init__(
        self, expected_layout="mono", output_sample_rate=24000, output_frame_size=480
    ) -> None:
        super().__init__(
            expected_layout,
            output_sample_rate,
            output_frame_size,
            input_sample_rate=16000,
        )
        self.audio_queue = asyncio.Queue()  # audio received from Gemini
        self.video_queue = asyncio.Queue()  # frames echoed back to the client
        self.quit = asyncio.Event()         # set once the handler must stop
        self.session = None                 # live Gemini session, or None
        self.last_frame_time = 0            # time.time() of last frame sent
        # The event loop only keeps weak references to tasks, so hold strong
        # ones here to stop background work being garbage-collected mid-run.
        self._background_tasks = set()
        self._connect_task = None

    def _start_background_task(self, coro) -> "asyncio.Task":
        """Schedule ``coro`` as a task and keep it referenced until it ends."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    def copy(self) -> "GeminiHandler":
        """Return a fresh handler configured with the same audio settings."""
        return GeminiHandler(
            expected_layout=self.expected_layout,
            output_sample_rate=self.output_sample_rate,
            output_frame_size=self.output_frame_size,
        )

    async def video_receive(self, frame: np.ndarray):
        """Forward at most one frame per second to Gemini and echo every frame.

        When a frame is forwarded, the third stream argument (the image
        input wired in ``webrtc.stream``) is also sent if it is not None.
        """
        # Bind once: connect() may reset self.session while we are awaiting.
        session = self.session
        if session:
            try:
                now = time.time()
                if now - self.last_frame_time > 1:
                    self.last_frame_time = now
                    await session.send(encode_image(frame))
                    if self.latest_args[2] is not None:
                        await session.send(encode_image(self.latest_args[2]))
            except Exception as e:
                logger.error(f"Error sending video frame: {e}")
                gr.Warning("Error sending video to Gemini.")
        self.video_queue.put_nowait(frame)

    async def video_emit(self) -> VideoEmitType:
        """Return the next queued video frame, or None if waiting fails."""
        try:
            return await self.video_queue.get()
        except asyncio.CancelledError:
            logger.info("Video emit cancelled.")
            return None
        except Exception as e:
            logger.exception(f"Error in video_emit: {e}")
            return None

    async def connect(self, api_key: str):
        """Open a Gemini live session and keep it open until ``quit`` is set.

        Does nothing if a session is already active. On failure the handler
        is shut down and a Gradio warning is raised.
        """
        global gemini_connected
        if self.session is not None:
            return
        try:
            client = genai.Client(api_key=api_key, http_options={"api_version": "v1alpha"})
            config = {"response_modalities": ["AUDIO"]}
            async with client.aio.live.connect(
                model="gemini-2.0-flash-exp", config=config
            ) as session:
                self.session = session
                gemini_connected = True
                self._start_background_task(self.receive_audio())
                await self.quit.wait()
        except Exception as e:
            logger.error(f"Error connecting to Gemini: {e}")
            gemini_connected = False
            self.shutdown()
            gr.Warning(f"Failed to connect to Gemini: {e}")
        finally:
            # The context manager has exited (or never opened): drop the
            # reference so receive()/video_receive() stop sending to a
            # closed session, and reflect that in the shared flag.
            self.session = None
            gemini_connected = False
            update_gemini_status_sync()

    async def generator(self):
        """Yield data payloads from session responses until quit or error."""
        if not self.session:
            logger.warning("Gemini session is not initialized.")
            return

        while not self.quit.is_set():
            try:
                await asyncio.sleep(0)  # Yield to the event loop
                if self.quit.is_set():
                    break
                turn = self.session.receive()
                async for response in turn:
                    if self.quit.is_set():
                        break  # Exit inner loop if quit is set.
                    if data := response.data:
                        yield data
            except Exception as e:
                logger.error(f"Error receiving from Gemini: {e}")
                self.quit.set()  # Stop the whole handler on a receive error.
                break

    async def receive_audio(self):
        """Aggregate session data into 16-bit audio chunks and queue them."""
        try:
            async for audio_response in async_aggregate_bytes_to_16bit(self.generator()):
                self.audio_queue.put_nowait(audio_response)
        except Exception as e:
            logger.exception(f"Error in receive_audio: {e}")

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Encode an incoming audio frame and send it to the session, if any."""
        session = self.session
        if not session:
            return  # Nothing to send to yet; skip the encoding work.
        _, array = frame
        array = array.squeeze()
        try:
            await session.send(encode_audio(array))
        except Exception as e:
            logger.error(f"Error sending audio: {e}")
            gr.Warning("Error sending audio to Gemini.")

    async def emit(self) -> AudioEmitType:
        """Return the next audio chunk from Gemini as (sample_rate, samples).

        Starts the Gemini connection on first use, using the second stream
        argument (the API key input wired in ``webrtc.stream``). Returns an
        empty array if waiting for audio is cancelled or fails.
        """
        if not self.args_set.is_set():
            await self.wait_for_args()

        connecting = self._connect_task is not None and not self._connect_task.done()
        # Connect once: not after quit was requested, and not while an
        # earlier connection attempt is still in flight.
        if self.session is None and not connecting and not self.quit.is_set():
            self._connect_task = self._start_background_task(
                self.connect(self.latest_args[1])
            )

        try:
            array = await self.audio_queue.get()
            return (self.output_sample_rate, array)
        except asyncio.CancelledError:
            logger.info("Audio emit cancelled.")
            return (self.output_sample_rate, np.array([]))
        except Exception as e:
            logger.exception(f"Error in emit: {e}")
            return (self.output_sample_rate, np.array([]))

    def shutdown(self) -> None:
        """Signal all handler coroutines to stop and reset stream state."""
        global gemini_connected
        gemini_connected = False
        logger.info("Shutting down GeminiHandler.")
        # The session is closed by connect() when its `async with` block
        # exits; setting quit is what wakes it. The flag is left set so that
        # generator()'s is_set() polls also observe the shutdown.
        self.quit.set()
        self.connection = None
        self.args_set.clear()
        update_gemini_status_sync()


def update_gemini_status_sync():
    """Reports the current Gemini connection status (synchronous version).

    Builds the status line from the module-level ``gemini_connected`` flag,
    logs it, and returns it.

    The previous implementation tested ``'demo' in locals()``, which is
    always False inside a function, and discarded the result of
    ``gr.update(...)``, so it silently did nothing. This function cannot
    push a value into the UI on its own; a Gradio event handler that wants
    to display the status can return this function's result.

    Returns:
        The human-readable status string.
    """
    status = "✅ Gemini: Connected" if gemini_connected else "❌ Gemini: Disconnected"
    logger.info(f"Gemini status: {'connected' if gemini_connected else 'disconnected'}")
    return status



# --- Gradio UI ---
# Caps the size of the WebRTC video element (both limits need a CSS unit).
css = """
#video-source {max-width: 600px !important; max-height: 600px !important;}
"""

# Perform Twilio check *before* Gradio UI definition (synchronously)
if __name__ == "__main__":
    check_twilio_availability_sync()


def _twilio_status_text() -> str:
    """Return the Twilio status line for the current ``twilio_available`` value."""
    if twilio_available:
        return "✅ Twilio: Available"
    return "❌ Twilio: Unavailable (connection may be less reliable)"


with gr.Blocks(css=css) as demo:
    gr.HTML(
        """
    <div style='display: flex; align-items: center; justify-content: center; gap: 20px'>
        <div style="background-color: var(--block-background-fill); border-radius: 8px">
            <img src="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png" style="width: 100px; height: 100px;">
        </div>
        <div>
            <h1>Gen AI SDK Voice Chat</h1>
            <p>Speak with Gemini using real-time audio + video streaming</p>
            <p>Powered by <a href="https://gradio.app/">Gradio</a> and <a href="https://freddyaboulton.github.io/gradio-webrtc/">WebRTC</a>⚡️</p>
            <p>Get an API Key <a href="https://support.google.com/googleapi/answer/6158862?hl=en">here</a></p>
        </div>
    </div>
    """
    )
    twilio_status_message = gr.Markdown("❓ Twilio: Checking...")
    gemini_status_message = gr.Markdown("❓ Gemini: Checking...")

    with gr.Row() as api_key_row:
        api_key = gr.Textbox(
            label="API Key",
            type="password",
            placeholder="Enter your API Key",
            value=os.getenv("GOOGLE_API_KEY"),
        )
    with gr.Row(visible=False) as row:
        with gr.Column():
            # Set rtc_configuration based on the *pre-checked* twilio_available.
            # The credentials are fetched again here; if that call fails, fall
            # back to no TURN configuration instead of crashing UI construction.
            rtc_config = None
            if twilio_available:
                try:
                    rtc_config = get_twilio_turn_credentials()
                except Exception as e:
                    logger.warning(f"Could not fetch Twilio TURN credentials: {e}")
            # Explicitly specify codecs (example - you might need to adjust)
            # NOTE(review): confirm the WebRTC component actually honours a
            # 'codecs' key in its RTC configuration.
            if rtc_config:
                rtc_config['codecs'] = ['VP8', 'H264']  # Prefer VP8, then H.264
            webrtc = WebRTC(
                label="Video Chat",
                modality="audio-video",
                mode="send-receive",
                elem_id="video-source",
                rtc_configuration=rtc_config,
                icon="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
                pulse_color="rgb(35, 157, 225)",
                icon_button_color="rgb(35, 157, 225)",
            )
        with gr.Column():
            image_input = gr.Image(label="Image", type="numpy", sources=["upload", "clipboard"])


    def update_twilio_status_ui():
        """Show the result of the startup Twilio check in the status line."""
        return gr.update(value=_twilio_status_text())

    demo.load(update_twilio_status_ui, [], [twilio_status_message])

    # Stream arguments arrive in this order: webrtc stream, API key, image.
    handler = GeminiHandler()
    webrtc.stream(
        handler,
        inputs=[webrtc, api_key, image_input],
        outputs=[webrtc],
        time_limit=90,
        concurrency_limit=None,
    )


    def check_api_key(api_key_str):
        """Reveal the chat row once a non-blank API key has been submitted.

        Returns updates for, in order: the API-key row, the chat row, the
        Twilio status line and the Gemini status line. The Twilio line always
        keeps showing the Twilio state; the API-key prompt goes in the Gemini
        line.
        """
        twilio_update = gr.update(value=_twilio_status_text())
        if not (api_key_str or "").strip():
            return (
                gr.update(visible=True),
                gr.update(visible=False),
                twilio_update,
                gr.update(value="Please enter a valid API key"),
            )
        return (
            gr.update(visible=False),
            gr.update(visible=True),
            twilio_update,
            gr.update(value="❓ Gemini: Checking..."),
        )

    api_key.submit(
        check_api_key,
        [api_key],
        [api_key_row, row, twilio_status_message, gemini_status_message],
    )

    # If API key is already set via environment variables, hide the API key row and show content
    if os.getenv("GOOGLE_API_KEY"):
        demo.load(
            lambda: (gr.update(visible=False), gr.update(visible=True)),
            None,
            [api_key_row, row],
        )

demo.launch()