mgokg commited on
Commit
aa93a81
·
verified ·
1 Parent(s): ed2fba6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -32
app.py CHANGED
@@ -22,9 +22,7 @@ from PIL import Image
22
 
23
  load_dotenv()
24
 
25
- system_message = "you are a helpful assistant."
26
- #system_message = "Du bist ein echzeitübersetzer. übersetze deutsch auf italienisch und italienisch auf deutsch. erkläre nichts, kommentiere nichts, füge nichts hinzu, nur übersetzen."
27
-
28
 
29
  def encode_audio(data: np.ndarray) -> dict:
30
  """Encode Audio data to send to the server"""
@@ -46,6 +44,7 @@ def encode_image(data: np.ndarray) -> dict:
46
  class GeminiHandler(AsyncAudioVideoStreamHandler):
47
  def __init__(
48
  self,
 
49
  ) -> None:
50
  super().__init__(
51
  "mono",
@@ -57,9 +56,10 @@ class GeminiHandler(AsyncAudioVideoStreamHandler):
57
  self.session = None
58
  self.last_frame_time = 0
59
  self.quit = asyncio.Event()
 
60
 
61
  def copy(self) -> "GeminiHandler":
62
- return GeminiHandler()
63
 
64
  async def start_up(self):
65
  client = genai.Client(
@@ -72,7 +72,7 @@ class GeminiHandler(AsyncAudioVideoStreamHandler):
72
  ]
73
 
74
  system_instruction = types.Content(
75
- parts=[types.Part.from_text(text=f"{system_message}")],
76
  role="user"
77
  )
78
 
@@ -165,23 +165,6 @@ class GeminiHandler(AsyncAudioVideoStreamHandler):
165
  self.quit.clear()
166
 
167
 
168
- stream = Stream(
169
- handler=GeminiHandler(),
170
- modality="audio",
171
- mode="send-receive",
172
- rtc_configuration=get_cloudflare_turn_credentials_async,
173
- time_limit=180 if get_space() else None,
174
- additional_inputs=[
175
- gr.Image(label="Image", type="numpy", sources=["upload", "clipboard"])
176
- ],
177
- ui_args={
178
- "icon": "https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
179
- "pulse_color": "rgb(255, 255, 255)",
180
- "icon_button_color": "rgb(255, 255, 255)",
181
- "title": "Gemini Audio Video Chat",
182
- },
183
- )
184
-
185
  css = """
186
  #video-source {max-width: 500px !important; max-height: 500px !important; background-color: #0f0f11 }
187
  #video-source video {
@@ -202,6 +185,9 @@ with gr.Blocks(css=css) as demo:
202
  )
203
  with gr.Row() as row:
204
  with gr.Column():
 
 
 
205
  webrtc = WebRTC(
206
  label="Voice Chat",
207
  modality="audio",
@@ -212,27 +198,36 @@ with gr.Blocks(css=css) as demo:
212
  pulse_color="rgb(255, 255, 255)",
213
  icon_button_color="rgb(255, 255, 255)",
214
  )
215
- #with gr.Column():
216
- #image_input = gr.Image(
217
- #label="Image", type="numpy", sources=["upload", "clipboard"]
218
- #)
219
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  webrtc.stream(
221
- GeminiHandler(),
222
  inputs=[webrtc],
223
  outputs=[webrtc],
224
  time_limit=180 if get_space() else None,
225
  concurrency_limit=2 if get_space() else None,
226
  )
227
 
228
- stream.ui = demo
229
-
230
 
231
  if __name__ == "__main__":
232
  if (mode := os.getenv("MODE")) == "UI":
233
- stream.ui.launch(server_port=7860)
234
  elif mode == "PHONE":
235
  raise ValueError("Phone mode not supported for this demo")
236
  else:
237
- stream.ui.launch(server_port=7860)
238
-
 
22
 
23
  load_dotenv()
24
 
25
+ # system_message will be set based on the user's selection
 
 
26
 
27
  def encode_audio(data: np.ndarray) -> dict:
28
  """Encode Audio data to send to the server"""
 
44
  class GeminiHandler(AsyncAudioVideoStreamHandler):
45
  def __init__(
46
  self,
47
+ system_message: str, # Add system_message as an argument
48
  ) -> None:
49
  super().__init__(
50
  "mono",
 
56
  self.session = None
57
  self.last_frame_time = 0
58
  self.quit = asyncio.Event()
59
+ self.system_message = system_message # Store the system message
60
 
61
  def copy(self) -> "GeminiHandler":
62
+ return GeminiHandler(self.system_message) # Pass the system message when copying
63
 
64
  async def start_up(self):
65
  client = genai.Client(
 
72
  ]
73
 
74
  system_instruction = types.Content(
75
+ parts=[types.Part.from_text(text=f"{self.system_message}")], # Use the stored system message
76
  role="user"
77
  )
78
 
 
165
  self.quit.clear()
166
 
167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  css = """
169
  #video-source {max-width: 500px !important; max-height: 500px !important; background-color: #0f0f11 }
170
  #video-source video {
 
185
  )
186
  with gr.Row() as row:
187
  with gr.Column():
188
+ mode_selector = gr.Radio(
189
+ ["Chat", "Translate"], label="Select Mode", value="Chat"
190
+ )
191
  webrtc = WebRTC(
192
  label="Voice Chat",
193
  modality="audio",
 
198
  pulse_color="rgb(255, 255, 255)",
199
  icon_button_color="rgb(255, 255, 255)",
200
  )
 
 
 
 
201
 
202
+ def update_handler(mode):
203
+ if mode == "Chat":
204
+ system_message = "you are a helpful assistant."
205
+ elif mode == "Translate":
206
+ system_message = "Du bist ein echzeitübersetzer. übersetze deutsch auf italienisch und italienisch auf deutsch. erkläre nichts, kommentiere nichts, füge nichts hinzu, nur übersetzen."
207
+ return GeminiHandler(system_message=system_message)
208
+
209
+ mode_selector.change(
210
+ update_handler,
211
+ inputs=[mode_selector],
212
+ outputs=[webrtc], # This will trigger a restart of the WebRTC component with the new handler
213
+ queue=False # Don't queue this event, it should happen immediately
214
+ )
215
+
216
+ # Initial setup of the handler based on the default mode
217
+ initial_system_message = "you are a helpful assistant."
218
  webrtc.stream(
219
+ GeminiHandler(system_message=initial_system_message),
220
  inputs=[webrtc],
221
  outputs=[webrtc],
222
  time_limit=180 if get_space() else None,
223
  concurrency_limit=2 if get_space() else None,
224
  )
225
 
 
 
226
 
227
  if __name__ == "__main__":
228
  if (mode := os.getenv("MODE")) == "UI":
229
+ demo.launch(server_port=7860)
230
  elif mode == "PHONE":
231
  raise ValueError("Phone mode not supported for this demo")
232
  else:
233
+ demo.launch(server_port=7860)