freddyaboulton HF staff commited on
Commit
b141d5b
·
verified ·
1 Parent(s): 93cb0e6

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. app.py +37 -32
  2. index.html +66 -26
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import json
2
  import os
 
3
  from pathlib import Path
4
 
5
  import anthropic
@@ -13,6 +14,7 @@ from fastrtc import (
13
  AdditionalOutputs,
14
  ReplyOnPause,
15
  Stream,
 
16
  get_tts_model,
17
  get_twilio_turn_credentials,
18
  )
@@ -36,38 +38,41 @@ def response(
36
  audio: tuple[int, np.ndarray],
37
  chatbot: list[dict] | None = None,
38
  ):
39
- chatbot = chatbot or []
40
- messages = [{"role": d["role"], "content": d["content"]} for d in chatbot]
41
- prompt = groq_client.audio.transcriptions.create(
42
- file=("audio-file.mp3", audio_to_bytes(audio)),
43
- model="whisper-large-v3-turbo",
44
- response_format="verbose_json",
45
- ).text
46
- print("prompt", prompt)
47
- chatbot.append({"role": "user", "content": prompt})
48
- yield AdditionalOutputs(chatbot)
49
- messages.append({"role": "user", "content": prompt})
50
- response = claude_client.messages.create(
51
- model="claude-3-5-haiku-20241022",
52
- max_tokens=512,
53
- messages=messages, # type: ignore
54
- )
55
- response_text = " ".join(
56
- block.text # type: ignore
57
- for block in response.content
58
- if getattr(block, "type", None) == "text"
59
- )
60
- chatbot.append({"role": "assistant", "content": response_text})
61
- import time
62
-
63
- start = time.time()
64
-
65
- print("starting tts", start)
66
- for i, chunk in enumerate(tts_model.stream_tts_sync(response_text)):
67
- print("chunk", i, time.time() - start)
68
- yield chunk
69
- print("finished tts", time.time() - start)
70
- yield AdditionalOutputs(chatbot)
 
 
 
71
 
72
 
73
  chatbot = gr.Chatbot(type="messages")
 
1
  import json
2
  import os
3
+ import time
4
  from pathlib import Path
5
 
6
  import anthropic
 
14
  AdditionalOutputs,
15
  ReplyOnPause,
16
  Stream,
17
+ WebRTCError,
18
  get_tts_model,
19
  get_twilio_turn_credentials,
20
  )
 
38
  audio: tuple[int, np.ndarray],
39
  chatbot: list[dict] | None = None,
40
  ):
41
+ try:
42
+ chatbot = chatbot or []
43
+ messages = [{"role": d["role"], "content": d["content"]} for d in chatbot]
44
+ prompt = groq_client.audio.transcriptions.create(
45
+ file=("audio-file.mp3", audio_to_bytes(audio)),
46
+ model="whisper-large-v3-turbo",
47
+ response_format="verbose_json",
48
+ ).text
49
+
50
+ print("prompt", prompt)
51
+ chatbot.append({"role": "user", "content": prompt})
52
+ yield AdditionalOutputs(chatbot)
53
+ messages.append({"role": "user", "content": prompt})
54
+ response = claude_client.messages.create(
55
+ model="claude-3-5-haiku-20241022",
56
+ max_tokens=512,
57
+ messages=messages, # type: ignore
58
+ )
59
+ response_text = " ".join(
60
+ block.text # type: ignore
61
+ for block in response.content
62
+ if getattr(block, "type", None) == "text"
63
+ )
64
+ chatbot.append({"role": "assistant", "content": response_text})
65
+
66
+ start = time.time()
67
+
68
+ print("starting tts", start)
69
+ for i, chunk in enumerate(tts_model.stream_tts_sync(response_text)):
70
+ print("chunk", i, time.time() - start)
71
+ yield chunk
72
+ print("finished tts", time.time() - start)
73
+ yield AdditionalOutputs(chatbot)
74
+ except Exception as e:
75
+ raise WebRTCError(str(e))
76
 
77
 
78
  chatbot = gr.Chatbot(type="messages")
index.html CHANGED
@@ -210,10 +210,28 @@
210
  transform: scale(1.2);
211
  }
212
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
  </style>
214
  </head>
215
 
216
  <body>
 
 
217
  <div class="container">
218
  <div class="chat-container">
219
  <div class="chat-messages" id="chat-messages"></div>
@@ -270,6 +288,17 @@
270
  }
271
  }
272
 
 
 
 
 
 
 
 
 
 
 
 
273
  async function setupWebRTC() {
274
  const config = __RTC_CONFIGURATION__;
275
  peerConnection = new RTCPeerConnection(config);
@@ -329,7 +358,32 @@
329
 
330
  // Create data channel for messages
331
  const dataChannel = peerConnection.createDataChannel('text');
332
- dataChannel.onmessage = handleMessage;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333
 
334
  // Create and send offer
335
  const offer = await peerConnection.createOffer();
@@ -362,6 +416,15 @@
362
  });
363
 
364
  const serverResponse = await response.json();
 
 
 
 
 
 
 
 
 
365
  await peerConnection.setRemoteDescription(serverResponse);
366
 
367
  // Start visualization
@@ -375,31 +438,8 @@
375
  });
376
  } catch (err) {
377
  console.error('Error setting up WebRTC:', err);
378
- }
379
- }
380
-
381
- function handleMessage(event) {
382
- const eventJson = JSON.parse(event.data);
383
- const typingIndicator = document.getElementById('typing-indicator');
384
-
385
- if (eventJson.type === "send_input") {
386
- fetch('/input_hook', {
387
- method: 'POST',
388
- headers: {
389
- 'Content-Type': 'application/json',
390
- },
391
- body: JSON.stringify({
392
- webrtc_id: webrtc_id,
393
- chatbot: chatHistory
394
- })
395
- });
396
- } else if (eventJson.type === "log") {
397
- if (eventJson.data === "pause_detected") {
398
- typingIndicator.style.display = 'block';
399
- chatMessages.scrollTop = chatMessages.scrollHeight;
400
- } else if (eventJson.data === "response_starting") {
401
- typingIndicator.style.display = 'none';
402
- }
403
  }
404
  }
405
 
 
210
  transform: scale(1.2);
211
  }
212
  }
213
+
214
+ /* Add styles for toast notifications */
215
+ .toast {
216
+ position: fixed;
217
+ top: 20px;
218
+ left: 50%;
219
+ transform: translateX(-50%);
220
+ background-color: #f44336;
221
+ color: white;
222
+ padding: 16px 24px;
223
+ border-radius: 4px;
224
+ font-size: 14px;
225
+ z-index: 1000;
226
+ display: none;
227
+ box-shadow: 0 2px 5px rgba(0, 0, 0, 0.2);
228
+ }
229
  </style>
230
  </head>
231
 
232
  <body>
233
+ <!-- Add toast element after body opening tag -->
234
+ <div id="error-toast" class="toast"></div>
235
  <div class="container">
236
  <div class="chat-container">
237
  <div class="chat-messages" id="chat-messages"></div>
 
288
  }
289
  }
290
 
291
+ function showError(message) {
292
+ const toast = document.getElementById('error-toast');
293
+ toast.textContent = message;
294
+ toast.style.display = 'block';
295
+
296
+ // Hide toast after 5 seconds
297
+ setTimeout(() => {
298
+ toast.style.display = 'none';
299
+ }, 5000);
300
+ }
301
+
302
  async function setupWebRTC() {
303
  const config = __RTC_CONFIGURATION__;
304
  peerConnection = new RTCPeerConnection(config);
 
358
 
359
  // Create data channel for messages
360
  const dataChannel = peerConnection.createDataChannel('text');
361
+ dataChannel.onmessage = (event) => {
362
+ const eventJson = JSON.parse(event.data);
363
+ const typingIndicator = document.getElementById('typing-indicator');
364
+
365
+ if (eventJson.type === "error") {
366
+ showError(eventJson.message);
367
+ } else if (eventJson.type === "send_input") {
368
+ fetch('/input_hook', {
369
+ method: 'POST',
370
+ headers: {
371
+ 'Content-Type': 'application/json',
372
+ },
373
+ body: JSON.stringify({
374
+ webrtc_id: webrtc_id,
375
+ chatbot: chatHistory
376
+ })
377
+ });
378
+ } else if (eventJson.type === "log") {
379
+ if (eventJson.data === "pause_detected") {
380
+ typingIndicator.style.display = 'block';
381
+ chatMessages.scrollTop = chatMessages.scrollHeight;
382
+ } else if (eventJson.data === "response_starting") {
383
+ typingIndicator.style.display = 'none';
384
+ }
385
+ }
386
+ };
387
 
388
  // Create and send offer
389
  const offer = await peerConnection.createOffer();
 
416
  });
417
 
418
  const serverResponse = await response.json();
419
+
420
+ if (serverResponse.status === 'failed') {
421
+ showError(serverResponse.meta.error === 'concurrency_limit_reached'
422
+ ? `Too many connections. Maximum limit is ${serverResponse.meta.limit}`
423
+ : serverResponse.meta.error);
424
+ stop();
425
+ return;
426
+ }
427
+
428
  await peerConnection.setRemoteDescription(serverResponse);
429
 
430
  // Start visualization
 
438
  });
439
  } catch (err) {
440
  console.error('Error setting up WebRTC:', err);
441
+ showError('Failed to establish connection. Please try again.');
442
+ stop();
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
443
  }
444
  }
445