mgokg commited on
Commit
f43c677
·
verified ·
1 Parent(s): 6904a70

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +131 -168
app.py CHANGED
@@ -1,240 +1,203 @@
1
  import asyncio
2
  import base64
 
3
  import os
4
- import time
5
- from io import BytesIO
6
 
7
  import gradio as gr
8
  import numpy as np
9
- import websockets
10
  from dotenv import load_dotenv
 
 
11
  from fastrtc import (
12
- AsyncAudioVideoStreamHandler,
13
  Stream,
14
- WebRTC,
15
  get_cloudflare_turn_credentials_async,
16
  wait_for_item,
17
  )
18
  from google import genai
19
- from google.genai import types # Import the types module
 
 
 
 
 
 
 
 
 
 
20
  from gradio.utils import get_space
21
- from PIL import Image
22
-
23
- load_dotenv()
24
 
 
25
 
26
- system_message = "you are a helpful assistant."
27
- #system_message = "Du bist ein echzeitübersetzer. übersetze deutsch auf italienisch und italienisch auf deutsch. erkläre nichts, kommentiere nichts, füge nichts hinzu, nur übersetzen."
28
 
29
 
30
- def encode_audio(data: np.ndarray) -> dict:
31
  """Encode Audio data to send to the server"""
32
- return {
33
- "mime_type": "audio/pcm",
34
- "data": base64.b64encode(data.tobytes()).decode("UTF-8"),
35
- }
36
 
37
 
38
- def encode_image(data: np.ndarray) -> dict:
39
- with BytesIO() as output_bytes:
40
- pil_image = Image.fromarray(data)
41
- pil_image.save(output_bytes, "JPEG")
42
- bytes_data = output_bytes.getvalue()
43
- base64_str = str(base64.b64encode(bytes_data), "utf-8")
44
- return {"mime_type": "image/jpeg", "data": base64_str}
45
 
46
-
47
- class GeminiHandler(AsyncAudioVideoStreamHandler):
48
  def __init__(
49
  self,
 
 
50
  ) -> None:
51
  super().__init__(
52
- "mono",
53
- output_sample_rate=24000,
54
  input_sample_rate=16000,
55
  )
56
- self.audio_queue = asyncio.Queue()
57
- self.video_queue = asyncio.Queue()
58
- self.session = None
59
- self.last_frame_time = 0
60
- self.quit = asyncio.Event()
61
 
62
  def copy(self) -> "GeminiHandler":
63
- return GeminiHandler()
 
 
 
64
 
65
  async def start_up(self):
 
 
 
 
 
 
66
  client = genai.Client(
67
- api_key=os.getenv("GEMINI_API_KEY"), http_options={"api_version": "v1alpha"}
 
68
  )
69
 
70
  # Define the tools and system instruction
71
  tools = [
72
- types.Tool(google_search=types.GoogleSearch()),
73
  ]
74
-
75
- system_instruction = types.Content(
76
- parts=[types.Part.from_text(text=f"{system_message}")],
77
  role="user"
78
  )
79
 
80
- # Update the config to include tools and system_instruction
81
- config = types.LiveConnectConfig(
82
- response_modalities=["AUDIO"],
83
- speech_config=types.SpeechConfig(
84
  language_code="de-DE",
85
- voice_config=types.VoiceConfig(
86
- prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name="Kore")
 
 
87
  )
88
  ),
89
  tools=tools,
90
  system_instruction=system_instruction,
91
  )
92
-
93
  async with client.aio.live.connect(
94
- #model="gemini-2.0-flash-exp",
95
- model = "models/gemini-2.5-flash-preview-native-audio-dialog",
96
- config=config, # type: ignore
97
  ) as session:
98
- self.session = session
99
- while not self.quit.is_set():
100
- turn = self.session.receive()
101
- try:
102
- async for response in turn:
103
- # Check if data exists before trying to process it as audio
104
- if data := response.data:
105
- audio = np.frombuffer(data, dtype=np.int16).reshape(1, -1)
106
- self.audio_queue.put_nowait(audio) # Only put if audio was created
107
- # You might want to handle other parts of the response here
108
- # e.g., response.text, response.tool_code, etc.
109
- # For now, we just ensure we don't crash if data is None.
110
-
111
- except websockets.exceptions.ConnectionClosedOK:
112
- print("connection closed")
113
- break
114
- except Exception as e:
115
- # Catch other potential errors during response processing
116
- print(f"Error processing response: {e}")
117
- # Depending on the error, you might want to break or continue
118
- # For now, let's break to prevent infinite loops on persistent errors
119
- break
120
-
121
-
122
- async def video_receive(self, frame: np.ndarray):
123
- self.video_queue.put_nowait(frame)
124
-
125
- if self.session:
126
- # send image every 1 second
127
- print(time.time() - self.last_frame_time)
128
- if time.time() - self.last_frame_time > 1:
129
- self.last_frame_time = time.time()
130
- await self.session.send(input=encode_image(frame))
131
- if self.latest_args[1] is not None:
132
- await self.session.send(input=encode_image(self.latest_args[1]))
133
-
134
- async def video_emit(self):
135
- frame = await wait_for_item(self.video_queue, 0.01)
136
- if frame is not None:
137
- return frame
138
- else:
139
- return np.zeros((100, 100, 3), dtype=np.uint8)
140
 
141
  async def receive(self, frame: tuple[int, np.ndarray]) -> None:
142
  _, array = frame
143
  array = array.squeeze()
144
  audio_message = encode_audio(array)
145
- # Add a check to ensure the session is still active before sending
146
- if self.session:# and not self.session._ws.close: # Check if session exists and websocket is not closed
147
- try:
148
- await self.session.send(input=audio_message)
149
- except websockets.exceptions.ConnectionClosedOK:
150
- print("Attempted to send on a closed connection.")
151
- except Exception as e:
152
- print(f"Error sending audio message: {e}")
153
- else:
154
- print("Session not active, cannot send audio message.")
155
 
 
 
156
 
157
- async def emit(self):
158
- array = await wait_for_item(self.audio_queue, 0.01)
159
- if array is not None:
160
- return (self.output_sample_rate, array)
161
- return array
162
-
163
- async def shutdown(self) -> None:
164
- if self.session:
165
- self.quit.set()
166
- await self.session.close()
167
- self.quit.clear()
168
 
169
 
170
  stream = Stream(
171
- handler=GeminiHandler(),
172
- modality="audio-video",
173
  mode="send-receive",
174
- rtc_configuration=get_cloudflare_turn_credentials_async,
175
- time_limit=1800 if get_space() else None,
 
 
176
  additional_inputs=[
177
- gr.Image(label="Image", type="numpy", sources=["upload", "clipboard"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  ],
179
- ui_args={
180
- "icon": "https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
181
- "pulse_color": "rgb(255, 255, 255)",
182
- "icon_button_color": "rgb(255, 255, 255)",
183
- "title": "Gemini Audio Video Chat",
184
- },
185
  )
186
 
187
- css = """
188
- #video-source {max-width: 500px !important; max-height: 500px !important; background-color: #0f0f11 }
189
- #video-source video {
190
- background-color: black !important;
191
- }
192
- """
193
-
194
- with gr.Blocks(css=css) as demo:
195
- gr.HTML(
196
- """
197
- <div>
198
- <center>
199
-
200
- </center>
201
- </div>
202
- """
203
- )
204
-
205
- with gr.Row() as row:
206
- with gr.Column():
207
- webrtc = WebRTC(
208
- label="Voice Chat",
209
- modality="audio",
210
- mode="send-receive",
211
- elem_id="video-source",
212
- rtc_configuration=get_cloudflare_turn_credentials_async,
213
- icon="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
214
- pulse_color="rgb(255, 255, 255)",
215
- icon_button_color="rgb(255, 255, 255)",
216
- )
217
- #with gr.Column():
218
- #image_input = gr.Image(
219
- #label="Image", type="numpy", sources=["upload", "clipboard"]
220
- #)
221
-
222
- webrtc.stream(
223
- GeminiHandler(),
224
- inputs=[webrtc],
225
- outputs=[webrtc],
226
- time_limit=1800 if get_space() else None,
227
- concurrency_limit=2 if get_space() else None,
228
- )
229
 
230
- stream.ui = demo
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
 
232
 
233
  if __name__ == "__main__":
 
 
234
  if (mode := os.getenv("MODE")) == "UI":
235
  stream.ui.launch(server_port=7860)
236
  elif mode == "PHONE":
237
- raise ValueError("Phone mode not supported for this demo")
238
  else:
239
- stream.ui.launch(server_port=7860)
240
-
 
 
1
  import asyncio
2
  import base64
3
+ import json
4
  import os
5
+ import pathlib
6
+ from typing import AsyncGenerator, Literal
7
 
8
  import gradio as gr
9
  import numpy as np
 
10
  from dotenv import load_dotenv
11
+ from fastapi import FastAPI
12
+ from fastapi.responses import HTMLResponse
13
  from fastrtc import (
14
+ AsyncStreamHandler,
15
  Stream,
 
16
  get_cloudflare_turn_credentials_async,
17
  wait_for_item,
18
  )
19
  from google import genai
20
+ from google.genai import types
21
+ from google.genai.types import (
22
+ LiveConnectConfig,
23
+ PrebuiltVoiceConfig,
24
+ SpeechConfig,
25
+ VoiceConfig,
26
+ Tool,
27
+ GoogleSearch,
28
+ Content,
29
+ Part,
30
+ )
31
  from gradio.utils import get_space
32
+ from pydantic import BaseModel
 
 
33
 
34
+ current_dir = pathlib.Path(__file__).parent
35
 
36
+ load_dotenv()
 
37
 
38
 
39
def encode_audio(data: np.ndarray) -> str:
    """Base64-encode raw PCM samples for transmission to the Gemini live API."""
    pcm_bytes = data.tobytes()
    return base64.b64encode(pcm_bytes).decode("UTF-8")
 
 
 
42
 
43
 
44
class GeminiHandler(AsyncStreamHandler):
    """fastrtc handler bridging browser microphone audio to the Gemini live API.

    Incoming frames are base64-encoded and queued; ``stream()`` feeds them to
    the live session, and decoded model audio is queued for ``emit()``.
    """

    def __init__(
        self,
        expected_layout: Literal["mono"] = "mono",
        output_sample_rate: int = 24000,
    ) -> None:
        super().__init__(
            expected_layout,
            output_sample_rate,
            # Microphone input is resampled to 16 kHz before reaching receive().
            input_sample_rate=16000,
        )
        # Base64-encoded PCM chunks waiting to be sent to Gemini.
        self.input_queue: asyncio.Queue = asyncio.Queue()
        # (sample_rate, int16 array) tuples of model audio awaiting playback.
        self.output_queue: asyncio.Queue = asyncio.Queue()
        # Set by shutdown() to terminate the stream() generator.
        self.quit: asyncio.Event = asyncio.Event()

    def copy(self) -> "GeminiHandler":
        # fastrtc clones the handler per connection; queues must not be shared.
        return GeminiHandler(
            expected_layout="mono",
            output_sample_rate=self.output_sample_rate,
        )

    async def start_up(self):
        """Open the Gemini live session and pump model audio into the output queue."""
        if not self.phone_mode:
            # Skip latest_args[0] (the audio stream); the remaining entries are
            # the three additional_inputs: api key, voice name, system message.
            await self.wait_for_args()
            api_key, voice_name, system_message = self.latest_args[1:]
        else:
            # Phone mode has no UI inputs; fall back to the env key and defaults.
            api_key, voice_name, system_message = None, "Kore", "Du bist ein hilfsamer Assistent."

        client = genai.Client(
            api_key=api_key or os.getenv("GEMINI_API_KEY"),
            http_options={"api_version": "v1alpha"},
        )

        # Allow the model to ground answers with Google Search.
        tools = [
            Tool(google_search=GoogleSearch()),
        ]
        system_instruction = Content(
            parts=[Part.from_text(text=f"{system_message}")],
            role="user"
        )

        config = LiveConnectConfig(
            response_modalities=["AUDIO"],  # type: ignore
            speech_config=SpeechConfig(
                language_code="de-DE",
                voice_config=VoiceConfig(
                    prebuilt_voice_config=PrebuiltVoiceConfig(
                        voice_name=voice_name,
                    )
                )
            ),
            tools=tools,
            system_instruction=system_instruction,
        )

        async with client.aio.live.connect(
            model="gemini-2.0-flash-exp", config=config
        ) as session:
            # start_stream() sends our queued microphone chunks and yields
            # the model's streamed responses.
            async for audio in session.start_stream(
                stream=self.stream(), mime_type="audio/pcm"
            ):
                if audio.data:
                    array = np.frombuffer(audio.data, dtype=np.int16)
                    self.output_queue.put_nowait((self.output_sample_rate, array))

    async def stream(self) -> AsyncGenerator[str, None]:
        # NOTE(review): yields the base64 *strings* produced by encode_audio
        # (queued in receive()), so the element type is str, not bytes.
        while not self.quit.is_set():
            try:
                # Short timeout so the quit flag is re-checked regularly.
                audio = await asyncio.wait_for(self.input_queue.get(), 0.1)
                yield audio
            except (asyncio.TimeoutError, TimeoutError):
                pass

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Queue one microphone frame ``(sample_rate, samples)`` for Gemini."""
        _, array = frame
        array = array.squeeze()
        audio_message = encode_audio(array)
        self.input_queue.put_nowait(audio_message)

    async def emit(self) -> tuple[int, np.ndarray] | None:
        """Return the next chunk of model audio, or None if none is ready."""
        return await wait_for_item(self.output_queue)

    def shutdown(self) -> None:
        # Stops the stream() generator; the live session context then exits.
        self.quit.set()
 
 
 
 
 
 
 
 
 
131
 
132
 
133
# Gradio/fastrtc stream wiring: one GeminiHandler clone per connection, with
# per-connection settings supplied through additional_inputs.
stream = Stream(
    handler=GeminiHandler(),
    modality="audio",
    mode="send-receive",
    rtc_configuration=get_cloudflare_turn_credentials_async if get_space() else None,
    concurrency_limit=5 if get_space() else None,
    time_limit=90 if get_space() else None,
    additional_inputs=[
        # Forwarded to GeminiHandler.start_up via latest_args, in this order.
        gr.Textbox(
            label="API Key",
            type="password",
            value=os.getenv("GEMINI_API_KEY") if not get_space() else "",
        ),
        gr.Dropdown(
            label="Voice",
            choices=["Puck", "Charon", "Kore", "Fenrir", "Aoede"],
            value="Kore",
        ),
        gr.Textbox(
            label="System Message",
            placeholder="Enter system instructions for the AI...",
            value="Du bist ein hilfsamer Assistent, der Fragen beantwortet und bei verschiedenen Aufgaben hilft. Du kannst bei Bedarf auch im Internet suchen, um aktuelle Informationen zu finden.",
            lines=3,
        ),
    ],
)
165
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
class InputData(BaseModel):
    """Payload POSTed by the frontend to bind per-connection settings."""

    # WebRTC peer id identifying which active stream these settings apply to.
    webrtc_id: str
    # Prebuilt Gemini voice name (e.g. "Kore").
    voice_name: str
    # Gemini API key; may be empty to fall back to the GEMINI_API_KEY env var.
    api_key: str
    # System instruction text passed to the model.
    system_message: str
172
+
173
+
174
app = FastAPI()

# Mount the fastrtc endpoints (WebRTC signalling etc.) onto the FastAPI app.
stream.mount(app)


@app.post("/input_hook")
async def _(body: InputData):
    """Receive per-connection settings from the custom frontend.

    fastrtc exposes these values to the handler as ``latest_args``; the
    positional order here must match the unpacking in GeminiHandler.start_up
    (api_key, voice_name, system_message).
    """
    stream.set_input(body.webrtc_id, body.api_key, body.voice_name, body.system_message)
    return {"status": "ok"}
183
+
184
+
185
@app.get("/")
async def index():
    """Serve the custom frontend, injecting the WebRTC TURN configuration.

    On Spaces, short-lived Cloudflare TURN credentials are fetched per request;
    locally the placeholder is replaced with ``null`` (JSON for None) so the
    browser falls back to direct/host candidates.
    """
    rtc_config = await get_cloudflare_turn_credentials_async() if get_space() else None
    # Explicit encoding: read_text() otherwise uses the platform default codec,
    # which breaks on non-UTF-8 locales (e.g. Windows cp1252).
    html_content = (current_dir / "index.html").read_text(encoding="utf-8")
    html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
    return HTMLResponse(content=html_content)
191
 
192
 
193
if __name__ == "__main__":
    # MODE selects how the app is served:
    #   UI    -> the auto-generated Gradio UI
    #   PHONE -> fastrtc's telephone gateway
    #   else  -> the FastAPI app serving the custom index.html frontend
    # (the redundant local `import os` was removed; os is imported at module top)
    if (mode := os.getenv("MODE")) == "UI":
        stream.ui.launch(server_port=7860)
    elif mode == "PHONE":
        stream.fastphone(host="0.0.0.0", port=7860)
    else:
        import uvicorn

        uvicorn.run(app, host="0.0.0.0", port=7860)