mgokg committed on
Commit bbf3c5f · verified · 1 Parent(s): 5982628

Update app.py

Files changed (1)
  1. app.py +151 -62
app.py CHANGED
@@ -3,82 +3,171 @@ import base64
 import os
 import time
 from io import BytesIO
-
+import gradio as gr
 import numpy as np
 import websockets
-import streamlit as st
 from dotenv import load_dotenv
-from PIL import Image
+from fastrtc import (
+    AsyncAudioVideoStreamHandler,
+    Stream,
+    WebRTC,
+    get_cloudflare_turn_credentials_async,
+    wait_for_item,
+)
 from google import genai
-
+from gradio.utils import get_space
+from PIL import Image
 load_dotenv()
-
-# Helper Functions
 def encode_audio(data: np.ndarray) -> dict:
-    """Encode Audio data"""
-    return {
-        "mime_type": "audio/pcm",
-        "data": base64.b64encode(data.tobytes()).decode("UTF-8"),
-    }
-
+    """Encode Audio data to send to the server"""
+    return {
+        "mime_type": "audio/pcm",
+        "data": base64.b64encode(data.tobytes()).decode("UTF-8"),
+    }
 def encode_image(data: np.ndarray) -> dict:
-    """Encode Image data"""
-    with BytesIO() as output_bytes:
-        pil_image = Image.fromarray(data)
-        pil_image.save(output_bytes, "JPEG")
-        bytes_data = output_bytes.getvalue()
-    return {"mime_type": "image/jpeg", "data": base64.b64encode(bytes_data).decode("utf-8")}
-
-# Streamlit UI
-st.title("Gen AI Voice Chat")
-st.subheader("Real-time audio & video streaming")
-
-# Initialize chat history
-if "messages" not in st.session_state:
-    st.session_state.messages = [
-        {"role": "assistant", "content": "Welcome! I'm your AI assistant. I can process images and audio in real-time. How can I help you today?"}
-    ]
+    with BytesIO() as output_bytes:
+        pil_image = Image.fromarray(data)
+        pil_image.save(output_bytes, "JPEG")
+        bytes_data = output_bytes.getvalue()
+    base64_str = str(base64.b64encode(bytes_data), "utf-8")
+    return {"mime_type": "image/jpeg", "data": base64_str}
+class GeminiHandler(AsyncAudioVideoStreamHandler):
+    def __init__(
+        self,
+    ) -> None:
+        super().__init__(
+            "mono",
+            output_sample_rate=24000,
+            input_sample_rate=16000,
+        )
+        self.audio_queue = asyncio.Queue()
+        self.video_queue = asyncio.Queue()
+        self.session = None
+        self.last_frame_time = 0
+        self.quit = asyncio.Event()
+    def copy(self) -> "GeminiHandler":
+        return GeminiHandler()

-# Display chat messages
-for message in st.session_state.messages:
-    with st.chat_message(message["role"]):
-        st.write(message["content"])
-
-# Sidebar for image upload
-with st.sidebar:
-    st.header("Configuration")
-    uploaded_image = st.file_uploader("Upload an Image", type=["jpg", "png"])
-    if uploaded_image:
-        # Add user message with image
-        st.session_state.messages.append({"role": "user", "content": "Uploaded an image"})
-        # Display image in chat
-        with st.chat_message("user"):
-            st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)
-
-# WebRTC Streaming Placeholder
-with st.expander("🎥 Live Video Stream"):
-    st.write("WebRTC video streaming placeholder - implement your video streaming here")
-
-# Async Audio Processing
-async def start_audio_processing():
-    client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))
+    async def start_up(self):
+        client = genai.Client(
+            api_key=os.getenv("GEMINI_API_KEY"), http_options={"api_version": "v1alpha"}
+        )
+        config = {"response_modalities": ["AUDIO"]}
         async with client.aio.live.connect(
-        model="gemini-2.0-flash-exp", config={"response_modalities": ["AUDIO"]}
+            model="gemini-2.0-flash-exp",
+            config=config,  # type: ignore
         ) as session:
-        while True:
-            turn = session.receive()
+            self.session = session
+            while not self.quit.is_set():
+                turn = self.session.receive()
                 try:
                     async for response in turn:
                         if data := response.data:
                             audio = np.frombuffer(data, dtype=np.int16).reshape(1, -1)
-                        # Add assistant response to chat
-                        st.session_state.messages.append({"role": "assistant", "content": audio})
-                        with st.chat_message("assistant"):
-                            st.audio(audio, format="audio/wav")
+                            self.audio_queue.put_nowait(audio)
                 except websockets.exceptions.ConnectionClosedOK:
-                st.error("Connection closed.")
+                    print("connection closed")
                     break

-# Run the Streamlit App
-if __name__ == "__main__":
-    asyncio.run(start_audio_processing())
+    async def video_receive(self, frame: np.ndarray):
+        self.video_queue.put_nowait(frame)
+
+        if self.session:
+            # send image every 1 second
+            print(time.time() - self.last_frame_time)
+            if time.time() - self.last_frame_time > 1:
+                self.last_frame_time = time.time()
+                await self.session.send(input=encode_image(frame))
+                if self.latest_args[1] is not None:
+                    await self.session.send(input=encode_image(self.latest_args[1]))
+
+    async def video_emit(self):
+        frame = await wait_for_item(self.video_queue, 0.01)
+        if frame is not None:
+            return frame
+        else:
+            return np.zeros((100, 100, 3), dtype=np.uint8)
+
+    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
+        _, array = frame
+        array = array.squeeze()
+        audio_message = encode_audio(array)
+        if self.session:
+            await self.session.send(input=audio_message)
+
+    async def emit(self):
+        array = await wait_for_item(self.audio_queue, 0.01)
+        if array is not None:
+            return (self.output_sample_rate, array)
+        return array
+
+    async def shutdown(self) -> None:
+        if self.session:
+            self.quit.set()
+            await self.session.close()
+            self.quit.clear()
+
+stream = Stream(
+    handler=GeminiHandler(),
+    modality="audio-video",
+    mode="send-receive",
+    rtc_configuration=get_cloudflare_turn_credentials_async,
+    time_limit=180 if get_space() else None,
+    additional_inputs=[
+        gr.Image(label="Image", type="numpy", sources=["upload", "clipboard"])
+    ],
+    ui_args={
+        "icon": "https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
+        "pulse_color": "rgb(255, 255, 255)",
+        "icon_button_color": "rgb(255, 255, 255)",
+        "title": "Gemini Audio Video Chat",
+    },
+)
+css = """
+#video-source {max-width: 600px !important; max-height: 600px !important;}
+"""
+with gr.Blocks(css=css) as demo:
+    gr.HTML(
+        """
+        <div>
+        <center>
+        <h1>Gen AI Voice Chat</h1>
+        <p>Real-time audio + video streaming</p>
+        </center>
+        </div>
+        """
+    )
+    with gr.Row() as row:
+        with gr.Column():
+            webrtc = WebRTC(
+                label="Video Chat",
+                modality="audio-video",
+                mode="send-receive",
+                elem_id="video-source",
+                rtc_configuration=get_cloudflare_turn_credentials_async,
+                icon="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
+                pulse_color="rgb(255, 255, 255)",
+                icon_button_color="rgb(255, 255, 255)",
+            )
+        # with gr.Column():
+        #     image_input = gr.Image(
+        #         label="Image", type="numpy", sources=["upload", "clipboard"]
+        #     )
+    webrtc.stream(
+        GeminiHandler(),
+        inputs=[webrtc],
+        # inputs=[webrtc, image_input],
+        outputs=[webrtc],
+        time_limit=180 if get_space() else None,
+        concurrency_limit=2 if get_space() else None,
+    )
+
+stream.ui = demo
+if __name__ == "__main__":
+    if (mode := os.getenv("MODE")) == "UI":
+        stream.ui.launch(server_port=7860)
+    elif mode == "PHONE":
+        raise ValueError("Phone mode not supported for this demo")
+    else:
+        stream.ui.launch(server_port=7860)
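Note: the payload helpers can be sanity-checked offline. The following sketch is illustrative only (not part of the commit); it assumes one second of 16 kHz mono int16 PCM, matching the handler's input_sample_rate=16000:

import base64
import numpy as np

# One second of 16 kHz mono int16 PCM -- the shape GeminiHandler.receive() forwards.
samples = (32767 * np.sin(2 * np.pi * 440 * np.arange(16000) / 16000)).astype(np.int16)

payload = {
    "mime_type": "audio/pcm",
    "data": base64.b64encode(samples.tobytes()).decode("UTF-8"),
}

# Decoding the base64 PCM recovers the original samples bit-for-bit.
decoded = np.frombuffer(base64.b64decode(payload["data"]), dtype=np.int16)
assert np.array_equal(decoded, samples)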
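Note: video_receive still indexes self.latest_args[1] even though the image_input column is commented out of webrtc.stream(...), so that lookup can fail at runtime. A defensive guard might look like this (a sketch only; it assumes fastrtc populates latest_args from the inputs list):

# Hypothetical guard inside video_receive, before touching the optional image input:
args = getattr(self, "latest_args", None)
if args is not None and len(args) > 1 and args[1] is not None:
    await self.session.send(input=encode_image(args[1]))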