mgokg committed on
Commit
e4188aa
·
verified ·
1 Parent(s): 776b2e8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +162 -22
app.py CHANGED
@@ -1,3 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  async def start_up(self):
2
  client = genai.Client(
3
  api_key=os.getenv("GEMINI_API_KEY"), http_options={"api_version": "v1alpha"}
@@ -34,36 +94,116 @@
34
  turn = self.session.receive()
35
  try:
36
  async for response in turn:
37
- # Check if data exists before trying to process it as audio
38
  if data := response.data:
39
  audio = np.frombuffer(data, dtype=np.int16).reshape(1, -1)
40
- self.audio_queue.put_nowait(audio) # Only put if audio was created
41
- # You might want to handle other parts of the response here
42
- # e.g., response.text, response.tool_code, etc.
43
- # For now, we just ensure we don't crash if data is None.
44
-
45
  except websockets.exceptions.ConnectionClosedOK:
46
  print("connection closed")
47
  break
48
- except Exception as e:
49
- # Catch other potential errors during response processing
50
- print(f"Error processing response: {e}")
51
- # Depending on the error, you might want to break or continue
52
- # For now, let's break to prevent infinite loops on persistent errors
53
- break
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  async def receive(self, frame: tuple[int, np.ndarray]) -> None:
57
  _, array = frame
58
  array = array.squeeze()
59
  audio_message = encode_audio(array)
60
- # Add a check to ensure the session is still active before sending
61
- if self.session and not self.session._ws.closed: # Check if session exists and websocket is not closed
62
- try:
63
- await self.session.send(input=audio_message)
64
- except websockets.exceptions.ConnectionClosedOK:
65
- print("Attempted to send on a closed connection.")
66
- except Exception as e:
67
- print(f"Error sending audio message: {e}")
68
- else:
69
- print("Session not active, cannot send audio message.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import base64
3
+ import os
4
+ import time
5
+ from io import BytesIO
6
+
7
+ import gradio as gr
8
+ import numpy as np
9
+ import websockets
10
+ from dotenv import load_dotenv
11
+ from fastrtc import (
12
+ AsyncAudioVideoStreamHandler,
13
+ Stream,
14
+ WebRTC,
15
+ get_cloudflare_turn_credentials_async,
16
+ wait_for_item,
17
+ )
18
+ from google import genai
19
+ from google.genai import types # Import the types module
20
+ from gradio.utils import get_space
21
+ from PIL import Image
22
+
23
+ load_dotenv()
24
+
25
+
26
def encode_audio(data: np.ndarray) -> dict:
    """Pack raw PCM samples into the base64 payload the Gemini server expects.

    Args:
        data: 1-D int16 PCM sample buffer.

    Returns:
        dict with ``mime_type`` ("audio/pcm") and base64-encoded ``data``.
    """
    payload = base64.b64encode(data.tobytes()).decode("UTF-8")
    return {"mime_type": "audio/pcm", "data": payload}
32
+
33
+
34
def encode_image(data: np.ndarray) -> dict:
    """Serialize an RGB frame as a base64-encoded JPEG payload for Gemini.

    Args:
        data: HxWx3 uint8 image array.

    Returns:
        dict with ``mime_type`` ("image/jpeg") and base64-encoded ``data``.
    """
    with BytesIO() as buffer:
        Image.fromarray(data).save(buffer, "JPEG")
        jpeg_bytes = buffer.getvalue()
    encoded = base64.b64encode(jpeg_bytes).decode("utf-8")
    return {"mime_type": "image/jpeg", "data": encoded}
41
+
42
+
43
+ class GeminiHandler(AsyncAudioVideoStreamHandler):
44
+ def __init__(
45
+ self,
46
+ ) -> None:
47
+ super().__init__(
48
+ "mono",
49
+ output_sample_rate=24000,
50
+ input_sample_rate=16000,
51
+ )
52
+ self.audio_queue = asyncio.Queue()
53
+ self.video_queue = asyncio.Queue()
54
+ self.session = None
55
+ self.last_frame_time = 0
56
+ self.quit = asyncio.Event()
57
+
58
+ def copy(self) -> "GeminiHandler":
59
+ return GeminiHandler()
60
+
61
  async def start_up(self):
62
  client = genai.Client(
63
  api_key=os.getenv("GEMINI_API_KEY"), http_options={"api_version": "v1alpha"}
 
94
  turn = self.session.receive()
95
  try:
96
  async for response in turn:
 
97
  if data := response.data:
98
  audio = np.frombuffer(data, dtype=np.int16).reshape(1, -1)
99
+ self.audio_queue.put_nowait(audio)
 
 
 
 
100
  except websockets.exceptions.ConnectionClosedOK:
101
  print("connection closed")
102
  break
 
 
 
 
 
 
103
 
104
+ async def video_receive(self, frame: np.ndarray):
105
+ self.video_queue.put_nowait(frame)
106
+
107
+ if self.session:
108
+ # send image every 1 second
109
+ print(time.time() - self.last_frame_time)
110
+ if time.time() - self.last_frame_time > 1:
111
+ self.last_frame_time = time.time()
112
+ await self.session.send(input=encode_image(frame))
113
+ if self.latest_args[1] is not None:
114
+ await self.session.send(input=encode_image(self.latest_args[1]))
115
+
116
+ async def video_emit(self):
117
+ frame = await wait_for_item(self.video_queue, 0.01)
118
+ if frame is not None:
119
+ return frame
120
+ else:
121
+ return np.zeros((100, 100, 3), dtype=np.uint8)
122
 
123
  async def receive(self, frame: tuple[int, np.ndarray]) -> None:
124
  _, array = frame
125
  array = array.squeeze()
126
  audio_message = encode_audio(array)
127
+ if self.session:
128
+ await self.session.send(input=audio_message)
129
+
130
+ async def emit(self):
131
+ array = await wait_for_item(self.audio_queue, 0.01)
132
+ if array is not None:
133
+ return (self.output_sample_rate, array)
134
+ return array
135
+
136
+ async def shutdown(self) -> None:
137
+ if self.session:
138
+ self.quit.set()
139
+ await self.session.close()
140
+ self.quit.clear()
141
+
142
+
143
# Standalone fastrtc Stream: audio-only send/receive against Gemini, with an
# optional reference image supplied as an extra input.
stream = Stream(
    handler=GeminiHandler(),
    modality="audio",
    mode="send-receive",
    rtc_configuration=get_cloudflare_turn_credentials_async,
    # Cap sessions at 3 minutes when hosted on a HF Space; unlimited locally.
    time_limit=180 if get_space() else None,
    additional_inputs=[
        gr.Image(label="Image", type="numpy", sources=["upload", "clipboard"])
    ],
    # Branding for the auto-generated fastrtc UI.
    ui_args={
        "icon": "https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
        "pulse_color": "rgb(255, 255, 255)",
        "icon_button_color": "rgb(255, 255, 255)",
        "title": "Gemini Audio Video Chat",
    },
)
159
+
160
css = """
#video-source {max-width: 500px !important; max-height: 500px !important;}
"""

# Custom Gradio layout that replaces the default fastrtc UI.
with gr.Blocks(css=css) as demo:
    gr.HTML(
        """
        <div>
        <center>
        <h1>Gen AI Voice Chat</h1>
        <p>real-time audio streaming</p>
        </center>
        </div>
        """
    )
    with gr.Row() as row:
        with gr.Column():
            webrtc = WebRTC(
                label="Voice Chat",
                modality="audio",
                mode="send-receive",
                elem_id="video-source",
                rtc_configuration=get_cloudflare_turn_credentials_async,
                icon="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
                pulse_color="rgb(255, 255, 255)",
                icon_button_color="rgb(255, 255, 255)",
            )

    # Wire the WebRTC component to a fresh handler; the same component serves
    # as both input and output of the audio stream.
    webrtc.stream(
        GeminiHandler(),
        inputs=[webrtc],
        outputs=[webrtc],
        time_limit=180 if get_space() else None,
        concurrency_limit=2 if get_space() else None,
    )

# Mount the custom Blocks layout as the Stream's UI.
stream.ui = demo
201
+
202
+
203
if __name__ == "__main__":
    # MODE selects the launch target. The original "UI" branch and the default
    # branch were identical, so they are collapsed: anything except "PHONE"
    # launches the Gradio UI on port 7860.
    if os.getenv("MODE") == "PHONE":
        raise ValueError("Phone mode not supported for this demo")
    stream.ui.launch(server_port=7860)