mgokg committed
Commit 69236b5 · verified · 1 Parent(s): bbf3c5f

Update app.py

Files changed (1):
  1. app.py +161 -148
app.py CHANGED
@@ -3,171 +3,184 @@ import base64
  import os
  import time
  from io import BytesIO
  import gradio as gr
  import numpy as np
  import websockets
  from dotenv import load_dotenv
  from fastrtc import (
- AsyncAudioVideoStreamHandler,
- Stream,
- WebRTC,
- get_cloudflare_turn_credentials_async,
- wait_for_item,
  )
  from google import genai
  from gradio.utils import get_space
  from PIL import Image
  load_dotenv()
  def encode_audio(data: np.ndarray) -> dict:
- """Encode Audio data to send to the server"""
- return {
- "mime_type": "audio/pcm",
- "data": base64.b64encode(data.tobytes()).decode("UTF-8"),
- }
  def encode_image(data: np.ndarray) -> dict:
- with BytesIO() as output_bytes:
- pil_image = Image.fromarray(data)
- pil_image.save(output_bytes, "JPEG")
- bytes_data = output_bytes.getvalue()
- base64_str = str(base64.b64encode(bytes_data), "utf-8")
- return {"mime_type": "image/jpeg", "data": base64_str}
  class GeminiHandler(AsyncAudioVideoStreamHandler):
- def init(
- self,
- ) -> None:
- super().init(
- "mono",
- output_sample_rate=24000,
- input_sample_rate=16000,
- )
- self.audio_queue = asyncio.Queue()
- self.video_queue = asyncio.Queue()
- self.session = None
- self.last_frame_time = 0
- self.quit = asyncio.Event()
- def copy(self) -> "GeminiHandler":
- return GeminiHandler()
-
- async def start_up(self):
- client = genai.Client(
- api_key=os.getenv("GEMINI_API_KEY"), http_options={"api_version": "v1alpha"}
- )
- config = {"response_modalities": ["AUDIO"]}
- async with client.aio.live.connect(
- model="gemini-2.0-flash-exp",
- config=config, # type: ignore
- ) as session:
- self.session = session
- while not self.quit.is_set():
- turn = self.session.receive()
- try:
- async for response in turn:
- if data := response.data:
- audio = np.frombuffer(data, dtype=np.int16).reshape(1, -1)
- self.audio_queue.put_nowait(audio)
- except websockets.exceptions.ConnectionClosedOK:
- print("connection closed")
- break
-
- async def video_receive(self, frame: np.ndarray):
- self.video_queue.put_nowait(frame)
-
- if self.session:
- # send image every 1 second
- print(time.time() - self.last_frame_time)
- if time.time() - self.last_frame_time > 1:
- self.last_frame_time = time.time()
- await self.session.send(input=encode_image(frame))
- if self.latest_args[1] is not None:
- await self.session.send(input=encode_image(self.latest_args[1]))
-
- async def video_emit(self):
- frame = await wait_for_item(self.video_queue, 0.01)
- if frame is not None:
- return frame
- else:
- return np.zeros((100, 100, 3), dtype=np.uint8)
-
- async def receive(self, frame: tuple[int, np.ndarray]) -> None:
- _, array = frame
- array = array.squeeze()
- audio_message = encode_audio(array)
- if self.session:
- await self.session.send(input=audio_message)
-
- async def emit(self):
- array = await wait_for_item(self.audio_queue, 0.01)
- if array is not None:
- return (self.output_sample_rate, array)
- return array
-
- async def shutdown(self) -> None:
- if self.session:
- self.quit.set()
- await self.session.close()
- self.quit.clear()
- Use code with caution.
  stream = Stream(
- handler=GeminiHandler(),
- modality="audio-video",
- mode="send-receive",
- rtc_configuration=get_cloudflare_turn_credentials_async,
- time_limit=180 if get_space() else None,
- additional_inputs=[
- gr.Image(label="Image", type="numpy", sources=["upload", "clipboard"])
- ],
- ui_args={
- "icon": "https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
- "pulse_color": "rgb(255, 255, 255)",
- "icon_button_color": "rgb(255, 255, 255)",
- "title": "Gemini Audio Video Chat",
- },
  )
  css = """
- #video-source {max-width: 600px !important; max-height: 600 !important;}
  """
  with gr.Blocks(css=css) as demo:
- gr.HTML(
- """
- <div>
- <center>
- <h1>Gen AI Voice Chat</h1>
- <p>Real-time audio + video streaming</p>
- <center>
- </div>
- """
- )
- with gr.Row() as row:
- with gr.Column():
- webrtc = WebRTC(
- label="Video Chat",
- modality="audio-video",
- mode="send-receive",
- elem_id="video-source",
- rtc_configuration=get_cloudflare_turn_credentials_async,
- icon="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
- pulse_color="rgb(255, 255, 255)",
- icon_button_color="rgb(255, 255, 255)",
- )
- #with gr.Column():
- #image_input = gr.Image(
- #label="Image", type="numpy", sources=["upload", "clipboard"]
- #)
- webrtc.stream(
- GeminiHandler(),
- inputs=[webrtc],
- #inputs=[webrtc, image_input],
- outputs=[webrtc],
- time_limit=180 if get_space() else None,
- concurrency_limit=2 if get_space() else None,
  )
- Use code with caution.
  stream.ui = demo
- if name == "main":
- if (mode := os.getenv("MODE")) == "UI":
- stream.ui.launch(server_port=7860)
- elif mode == "PHONE":
- raise ValueError("Phone mode not supported for this demo")
- else:
- stream.ui.launch(server_port=7860)
- warning
 
 
  import os
  import time
  from io import BytesIO
+
  import gradio as gr
  import numpy as np
  import websockets
  from dotenv import load_dotenv
  from fastrtc import (
+    AsyncAudioVideoStreamHandler,
+    Stream,
+    WebRTC,
+    get_cloudflare_turn_credentials_async,
+    wait_for_item,
  )
  from google import genai
  from gradio.utils import get_space
  from PIL import Image
+
  load_dotenv()
+
+
  def encode_audio(data: np.ndarray) -> dict:
+    """Encode Audio data to send to the server"""
+    return {
+        "mime_type": "audio/pcm",
+        "data": base64.b64encode(data.tobytes()).decode("UTF-8"),
+    }
+
+
  def encode_image(data: np.ndarray) -> dict:
+    with BytesIO() as output_bytes:
+        pil_image = Image.fromarray(data)
+        pil_image.save(output_bytes, "JPEG")
+        bytes_data = output_bytes.getvalue()
+    base64_str = str(base64.b64encode(bytes_data), "utf-8")
+    return {"mime_type": "image/jpeg", "data": base64_str}
+
+
  class GeminiHandler(AsyncAudioVideoStreamHandler):
+    def __init__(
+        self,
+    ) -> None:
+        super().__init__(
+            "mono",
+            output_sample_rate=24000,
+            input_sample_rate=16000,
+        )
+        self.audio_queue = asyncio.Queue()
+        self.video_queue = asyncio.Queue()
+        self.session = None
+        self.last_frame_time = 0
+        self.quit = asyncio.Event()
+
+    def copy(self) -> "GeminiHandler":
+        return GeminiHandler()
+
+    async def start_up(self):
+        client = genai.Client(
+            api_key=os.getenv("GEMINI_API_KEY"), http_options={"api_version": "v1alpha"}
+        )
+        config = {"response_modalities": ["AUDIO"]}
+        async with client.aio.live.connect(
+            model="gemini-2.0-flash-exp",
+            config=config, # type: ignore
+        ) as session:
+            self.session = session
+            while not self.quit.is_set():
+                turn = self.session.receive()
+                try:
+                    async for response in turn:
+                        if data := response.data:
+                            audio = np.frombuffer(data, dtype=np.int16).reshape(1, -1)
+                            self.audio_queue.put_nowait(audio)
+                except websockets.exceptions.ConnectionClosedOK:
+                    print("connection closed")
+                    break
+
+    async def video_receive(self, frame: np.ndarray):
+        self.video_queue.put_nowait(frame)
+
+        if self.session:
+            # send image every 1 second
+            print(time.time() - self.last_frame_time)
+            if time.time() - self.last_frame_time > 1:
+                self.last_frame_time = time.time()
+                await self.session.send(input=encode_image(frame))
+                if self.latest_args[1] is not None:
+                    await self.session.send(input=encode_image(self.latest_args[1]))
+
+    async def video_emit(self):
+        frame = await wait_for_item(self.video_queue, 0.01)
+        if frame is not None:
+            return frame
+        else:
+            return np.zeros((100, 100, 3), dtype=np.uint8)
+
+    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
+        _, array = frame
+        array = array.squeeze()
+        audio_message = encode_audio(array)
+        if self.session:
+            await self.session.send(input=audio_message)
+
+    async def emit(self):
+        array = await wait_for_item(self.audio_queue, 0.01)
+        if array is not None:
+            return (self.output_sample_rate, array)
+        return array
+
+    async def shutdown(self) -> None:
+        if self.session:
+            self.quit.set()
+            await self.session.close()
+            self.quit.clear()
+
+
  stream = Stream(
+    handler=GeminiHandler(),
+    modality="audio-video",
+    mode="send-receive",
+    rtc_configuration=get_cloudflare_turn_credentials_async,
+    time_limit=180 if get_space() else None,
+    additional_inputs=[
+        gr.Image(label="Image", type="numpy", sources=["upload", "clipboard"])
+    ],
+    ui_args={
+        "icon": "https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
+        "pulse_color": "rgb(255, 255, 255)",
+        "icon_button_color": "rgb(255, 255, 255)",
+        "title": "Gemini Audio Video Chat",
+    },
  )
+
  css = """
+#video-source {max-width: 500px !important; max-height: 500px !important;}
  """
+
  with gr.Blocks(css=css) as demo:
+    gr.HTML(
+        """
+    <div>
+    <center>
+    <h1>Gen AI Voice Chat</h1>
+    <p>real-time audio + video streaming</p>
+    </center>
+    </div>
+    """
  )
+    with gr.Row() as row:
+        with gr.Column():
+            webrtc = WebRTC(
+                label="Video Chat",
+                modality="audio-video",
+                mode="send-receive",
+                elem_id="video-source",
+                rtc_configuration=get_cloudflare_turn_credentials_async,
+                icon="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
+                pulse_color="rgb(255, 255, 255)",
+                icon_button_color="rgb(255, 255, 255)",
+            )
+        #with gr.Column():
+            #image_input = gr.Image(
+                #label="Image", type="numpy", sources=["upload", "clipboard"]
+            #)
+
+        webrtc.stream(
+            GeminiHandler(),
+            inputs=[webrtc],
+            outputs=[webrtc],
+            time_limit=180 if get_space() else None,
+            concurrency_limit=2 if get_space() else None,
+        )
+
  stream.ui = demo
+
+
+if __name__ == "__main__":
+    if (mode := os.getenv("MODE")) == "UI":
+        stream.ui.launch(server_port=7860)
+    elif mode == "PHONE":
+        raise ValueError("Phone mode not supported for this demo")
+    else:
+        stream.ui.launch(server_port=7860)