Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -22,9 +22,7 @@ from PIL import Image
|
|
22 |
|
23 |
load_dotenv()
|
24 |
|
25 |
-
system_message
|
26 |
-
#system_message = "Du bist ein echzeitübersetzer. übersetze deutsch auf italienisch und italienisch auf deutsch. erkläre nichts, kommentiere nichts, füge nichts hinzu, nur übersetzen."
|
27 |
-
|
28 |
|
29 |
def encode_audio(data: np.ndarray) -> dict:
|
30 |
"""Encode Audio data to send to the server"""
|
@@ -46,6 +44,7 @@ def encode_image(data: np.ndarray) -> dict:
|
|
46 |
class GeminiHandler(AsyncAudioVideoStreamHandler):
|
47 |
def __init__(
|
48 |
self,
|
|
|
49 |
) -> None:
|
50 |
super().__init__(
|
51 |
"mono",
|
@@ -57,9 +56,10 @@ class GeminiHandler(AsyncAudioVideoStreamHandler):
|
|
57 |
self.session = None
|
58 |
self.last_frame_time = 0
|
59 |
self.quit = asyncio.Event()
|
|
|
60 |
|
61 |
def copy(self) -> "GeminiHandler":
|
62 |
-
return GeminiHandler()
|
63 |
|
64 |
async def start_up(self):
|
65 |
client = genai.Client(
|
@@ -72,7 +72,7 @@ class GeminiHandler(AsyncAudioVideoStreamHandler):
|
|
72 |
]
|
73 |
|
74 |
system_instruction = types.Content(
|
75 |
-
parts=[types.Part.from_text(text=f"{system_message}")],
|
76 |
role="user"
|
77 |
)
|
78 |
|
@@ -165,23 +165,6 @@ class GeminiHandler(AsyncAudioVideoStreamHandler):
|
|
165 |
self.quit.clear()
|
166 |
|
167 |
|
168 |
-
stream = Stream(
|
169 |
-
handler=GeminiHandler(),
|
170 |
-
modality="audio",
|
171 |
-
mode="send-receive",
|
172 |
-
rtc_configuration=get_cloudflare_turn_credentials_async,
|
173 |
-
time_limit=180 if get_space() else None,
|
174 |
-
additional_inputs=[
|
175 |
-
gr.Image(label="Image", type="numpy", sources=["upload", "clipboard"])
|
176 |
-
],
|
177 |
-
ui_args={
|
178 |
-
"icon": "https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
|
179 |
-
"pulse_color": "rgb(255, 255, 255)",
|
180 |
-
"icon_button_color": "rgb(255, 255, 255)",
|
181 |
-
"title": "Gemini Audio Video Chat",
|
182 |
-
},
|
183 |
-
)
|
184 |
-
|
185 |
css = """
|
186 |
#video-source {max-width: 500px !important; max-height: 500px !important; background-color: #0f0f11 }
|
187 |
#video-source video {
|
@@ -202,6 +185,9 @@ with gr.Blocks(css=css) as demo:
|
|
202 |
)
|
203 |
with gr.Row() as row:
|
204 |
with gr.Column():
|
|
|
|
|
|
|
205 |
webrtc = WebRTC(
|
206 |
label="Voice Chat",
|
207 |
modality="audio",
|
@@ -212,27 +198,36 @@ with gr.Blocks(css=css) as demo:
|
|
212 |
pulse_color="rgb(255, 255, 255)",
|
213 |
icon_button_color="rgb(255, 255, 255)",
|
214 |
)
|
215 |
-
#with gr.Column():
|
216 |
-
#image_input = gr.Image(
|
217 |
-
#label="Image", type="numpy", sources=["upload", "clipboard"]
|
218 |
-
#)
|
219 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
220 |
webrtc.stream(
|
221 |
-
GeminiHandler(),
|
222 |
inputs=[webrtc],
|
223 |
outputs=[webrtc],
|
224 |
time_limit=180 if get_space() else None,
|
225 |
concurrency_limit=2 if get_space() else None,
|
226 |
)
|
227 |
|
228 |
-
stream.ui = demo
|
229 |
-
|
230 |
|
231 |
if __name__ == "__main__":
|
232 |
if (mode := os.getenv("MODE")) == "UI":
|
233 |
-
|
234 |
elif mode == "PHONE":
|
235 |
raise ValueError("Phone mode not supported for this demo")
|
236 |
else:
|
237 |
-
|
238 |
-
|
|
|
22 |
|
23 |
load_dotenv()
|
24 |
|
25 |
+
# system_message will be set based on the user's selection
|
|
|
|
|
26 |
|
27 |
def encode_audio(data: np.ndarray) -> dict:
|
28 |
"""Encode Audio data to send to the server"""
|
|
|
44 |
class GeminiHandler(AsyncAudioVideoStreamHandler):
|
45 |
def __init__(
|
46 |
self,
|
47 |
+
system_message: str, # Add system_message as an argument
|
48 |
) -> None:
|
49 |
super().__init__(
|
50 |
"mono",
|
|
|
56 |
self.session = None
|
57 |
self.last_frame_time = 0
|
58 |
self.quit = asyncio.Event()
|
59 |
+
self.system_message = system_message # Store the system message
|
60 |
|
61 |
def copy(self) -> "GeminiHandler":
|
62 |
+
return GeminiHandler(self.system_message) # Pass the system message when copying
|
63 |
|
64 |
async def start_up(self):
|
65 |
client = genai.Client(
|
|
|
72 |
]
|
73 |
|
74 |
system_instruction = types.Content(
|
75 |
+
parts=[types.Part.from_text(text=f"{self.system_message}")], # Use the stored system message
|
76 |
role="user"
|
77 |
)
|
78 |
|
|
|
165 |
self.quit.clear()
|
166 |
|
167 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
168 |
css = """
|
169 |
#video-source {max-width: 500px !important; max-height: 500px !important; background-color: #0f0f11 }
|
170 |
#video-source video {
|
|
|
185 |
)
|
186 |
with gr.Row() as row:
|
187 |
with gr.Column():
|
188 |
+
mode_selector = gr.Radio(
|
189 |
+
["Chat", "Translate"], label="Select Mode", value="Chat"
|
190 |
+
)
|
191 |
webrtc = WebRTC(
|
192 |
label="Voice Chat",
|
193 |
modality="audio",
|
|
|
198 |
pulse_color="rgb(255, 255, 255)",
|
199 |
icon_button_color="rgb(255, 255, 255)",
|
200 |
)
|
|
|
|
|
|
|
|
|
201 |
|
202 |
+
def update_handler(mode):
|
203 |
+
if mode == "Chat":
|
204 |
+
system_message = "you are a helpful assistant."
|
205 |
+
elif mode == "Translate":
|
206 |
+
system_message = "Du bist ein echzeitübersetzer. übersetze deutsch auf italienisch und italienisch auf deutsch. erkläre nichts, kommentiere nichts, füge nichts hinzu, nur übersetzen."
|
207 |
+
return GeminiHandler(system_message=system_message)
|
208 |
+
|
209 |
+
mode_selector.change(
|
210 |
+
update_handler,
|
211 |
+
inputs=[mode_selector],
|
212 |
+
outputs=[webrtc], # This will trigger a restart of the WebRTC component with the new handler
|
213 |
+
queue=False # Don't queue this event, it should happen immediately
|
214 |
+
)
|
215 |
+
|
216 |
+
# Initial setup of the handler based on the default mode
|
217 |
+
initial_system_message = "you are a helpful assistant."
|
218 |
webrtc.stream(
|
219 |
+
GeminiHandler(system_message=initial_system_message),
|
220 |
inputs=[webrtc],
|
221 |
outputs=[webrtc],
|
222 |
time_limit=180 if get_space() else None,
|
223 |
concurrency_limit=2 if get_space() else None,
|
224 |
)
|
225 |
|
|
|
|
|
226 |
|
227 |
if __name__ == "__main__":
|
228 |
if (mode := os.getenv("MODE")) == "UI":
|
229 |
+
demo.launch(server_port=7860)
|
230 |
elif mode == "PHONE":
|
231 |
raise ValueError("Phone mode not supported for this demo")
|
232 |
else:
|
233 |
+
demo.launch(server_port=7860)
|
|