mgokg committed
Commit 30a15aa · verified · 1 Parent(s): cfcda0e

Update app.py

Files changed (1)
  1. app.py +59 -161
app.py CHANGED
@@ -4,183 +4,81 @@ import os
 import time
 from io import BytesIO
 
-import gradio as gr
 import numpy as np
 import websockets
 from dotenv import load_dotenv
-from fastrtc import (
-    AsyncAudioVideoStreamHandler,
-    Stream,
-    WebRTC,
-    get_cloudflare_turn_credentials_async,
-    wait_for_item,
-)
-from google import genai
-from gradio.utils import get_space
 from PIL import Image
 
 load_dotenv()
 
-
 def encode_audio(data: np.ndarray) -> dict:
-    """Encode Audio data to send to the server"""
     return {
         "mime_type": "audio/pcm",
        "data": base64.b64encode(data.tobytes()).decode("UTF-8"),
     }
 
-
 def encode_image(data: np.ndarray) -> dict:
     with BytesIO() as output_bytes:
         pil_image = Image.fromarray(data)
         pil_image.save(output_bytes, "JPEG")
         bytes_data = output_bytes.getvalue()
-    base64_str = str(base64.b64encode(bytes_data), "utf-8")
-    return {"mime_type": "image/jpeg", "data": base64_str}
-
-
-class GeminiHandler(AsyncAudioVideoStreamHandler):
-    def __init__(
-        self,
-    ) -> None:
-        super().__init__(
-            "mono",
-            output_sample_rate=24000,
-            input_sample_rate=16000,
-        )
-        self.audio_queue = asyncio.Queue()
-        self.video_queue = asyncio.Queue()
-        self.session = None
-        self.last_frame_time = 0
-        self.quit = asyncio.Event()
-
-    def copy(self) -> "GeminiHandler":
-        return GeminiHandler()
-
-    async def start_up(self):
-        client = genai.Client(
-            api_key=os.getenv("GEMINI_API_KEY"), http_options={"api_version": "v1alpha"}
-        )
-        config = {"response_modalities": ["AUDIO"]}
-        async with client.aio.live.connect(
-            model="gemini-2.0-flash-exp",
-            config=config,  # type: ignore
-        ) as session:
-            self.session = session
-            while not self.quit.is_set():
-                turn = self.session.receive()
-                try:
-                    async for response in turn:
-                        if data := response.data:
-                            audio = np.frombuffer(data, dtype=np.int16).reshape(1, -1)
-                            self.audio_queue.put_nowait(audio)
-                except websockets.exceptions.ConnectionClosedOK:
-                    print("connection closed")
-                    break
-
-    async def video_receive(self, frame: np.ndarray):
-        self.video_queue.put_nowait(frame)
-
-        if self.session:
-            # send image every 1 second
-            print(time.time() - self.last_frame_time)
-            if time.time() - self.last_frame_time > 1:
-                self.last_frame_time = time.time()
-                await self.session.send(input=encode_image(frame))
-                if self.latest_args[1] is not None:
-                    await self.session.send(input=encode_image(self.latest_args[1]))
-
-    async def video_emit(self):
-        frame = await wait_for_item(self.video_queue, 0.01)
-        if frame is not None:
-            return frame
-        else:
-            return np.zeros((100, 100, 3), dtype=np.uint8)
-
-    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
-        _, array = frame
-        array = array.squeeze()
-        audio_message = encode_audio(array)
-        if self.session:
-            await self.session.send(input=audio_message)
-
-    async def emit(self):
-        array = await wait_for_item(self.audio_queue, 0.01)
-        if array is not None:
-            return (self.output_sample_rate, array)
-        return array
-
-    async def shutdown(self) -> None:
-        if self.session:
-            self.quit.set()
-            await self.session.close()
-            self.quit.clear()
-
-
-stream = Stream(
-    handler=GeminiHandler(),
-    modality="audio-video",
-    mode="send-receive",
-    rtc_configuration=get_cloudflare_turn_credentials_async,
-    time_limit=180 if get_space() else None,
-    additional_inputs=[
-        gr.Image(label="Image", type="numpy", sources=["upload", "clipboard"])
-    ],
-    ui_args={
-        "icon": "https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
-        "pulse_color": "rgb(255, 255, 255)",
-        "icon_button_color": "rgb(255, 255, 255)",
-        "title": "Gemini Audio Video Chat",
-    },
-)
-
-css = """
-#video-source {max-width: 500px !important; max-height: 500px !important;}
-"""
-
-with gr.Blocks(css=css) as demo:
-    gr.HTML(
-        """
-        <div>
-            <center>
-                <h1>Gen AI Voice Chat</h1>
-                <p>real-time audio + video streaming</p>
-            </center>
-        </div>
-        """
-    )
-    with gr.Row() as row:
-        with gr.Column():
-            webrtc = WebRTC(
-                label="Video Chat",
-                modality="audio-video",
-                mode="send-receive",
-                elem_id="video-source",
-                rtc_configuration=get_cloudflare_turn_credentials_async,
-                icon="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
-                pulse_color="rgb(255, 255, 255)",
-                icon_button_color="rgb(255, 255, 255)",
-            )
-        # with gr.Column():
-        #     image_input = gr.Image(
-        #         label="Image", type="numpy", sources=["upload", "clipboard"]
-        #     )
-
-    webrtc.stream(
-        GeminiHandler(),
-        inputs=[webrtc],
-        outputs=[webrtc],
-        time_limit=180 if get_space() else None,
-        concurrency_limit=2 if get_space() else None,
-    )
-
-stream.ui = demo
-
-
 if __name__ == "__main__":
-    if (mode := os.getenv("MODE")) == "UI":
-        stream.ui.launch(server_port=7860)
-    elif mode == "PHONE":
-        raise ValueError("Phone mode not supported for this demo")
-    else:
-        stream.ui.launch(server_port=7860)
 
 import time
 from io import BytesIO
 
 import numpy as np
 import websockets
+import streamlit as st
 from dotenv import load_dotenv
 from PIL import Image
+from google import genai
 
 load_dotenv()
 
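+# NOTE: assumes GEMINI_API_KEY is supplied via the environment or the .env
+# file read by load_dotenv() above; genai.Client() below fails without it.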
+# Helper Functions
 def encode_audio(data: np.ndarray) -> dict:
+    """Encode raw PCM audio data for the Gemini live API."""
     return {
         "mime_type": "audio/pcm",
         "data": base64.b64encode(data.tobytes()).decode("UTF-8"),
     }
 
 def encode_image(data: np.ndarray) -> dict:
+    """Encode image data as a base64 JPEG payload."""
     with BytesIO() as output_bytes:
         pil_image = Image.fromarray(data)
         pil_image.save(output_bytes, "JPEG")
         bytes_data = output_bytes.getvalue()
+    return {"mime_type": "image/jpeg", "data": base64.b64encode(bytes_data).decode("utf-8")}
+
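+# Both helpers produce the {"mime_type": ..., "data": ...} dict shape that
+# session.send(input=...) accepts in the Gemini live API.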
+# Streamlit UI
+st.title("Gen AI Voice Chat")
+st.subheader("Real-time audio & video streaming")
+
+# Initialize chat history
+if "messages" not in st.session_state:
+    st.session_state.messages = [
+        {"role": "assistant", "content": "Welcome! I'm your AI assistant. I can process images and audio in real-time. How can I help you today?"}
+    ]
+
+# Display chat messages
+for message in st.session_state.messages:
+    with st.chat_message(message["role"]):
+        st.write(message["content"])
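+# st.session_state persists across Streamlit's top-to-bottom script reruns,
+# so the transcript accumulates instead of resetting on each interaction.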
+
+# Sidebar for image upload
+with st.sidebar:
+    st.header("Configuration")
+    uploaded_image = st.file_uploader("Upload an Image", type=["jpg", "png"])
+    if uploaded_image:
+        # Add user message with image
+        st.session_state.messages.append({"role": "user", "content": "Uploaded an image"})
+        # Display image in chat
+        with st.chat_message("user"):
+            st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)
+
+# WebRTC Streaming Placeholder
+with st.expander("🎥 Live Video Stream"):
+    st.write("WebRTC video streaming placeholder - implement your video streaming here")
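+    # One possible way to fill this in (a sketch, assuming the third-party
+    # streamlit-webrtc package is added as a dependency):
+    #     from streamlit_webrtc import webrtc_streamer
+    #     webrtc_streamer(key="gemini-video")  # renders a live camera widget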
+
+# Async Audio Processing
+async def start_audio_processing():
+    client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))
+    async with client.aio.live.connect(
+        model="gemini-2.0-flash-exp", config={"response_modalities": ["AUDIO"]}
+    ) as session:
+        while True:
+            turn = session.receive()
+            try:
+                async for response in turn:
+                    if data := response.data:
+                        audio = np.frombuffer(data, dtype=np.int16).reshape(1, -1)
+                        # Add assistant response to chat
+                        st.session_state.messages.append({"role": "assistant", "content": audio})
+                        with st.chat_message("assistant"):
+                            # Raw PCM is not a WAV container; pass the array with its
+                            # sample rate (Gemini live output is 16-bit, 24 kHz mono).
+                            st.audio(audio, sample_rate=24000)
+            except websockets.exceptions.ConnectionClosedOK:
+                st.error("Connection closed.")
+                break
+
+# Run the Streamlit App
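+# Note: asyncio.run() below blocks this script run until the Gemini session
+# ends, so the UI above renders first and then audio streaming takes over.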
 if __name__ == "__main__":
+    asyncio.run(start_audio_processing())
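+
+# Launch locally with: streamlit run app.py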