SalexAI commited on
Commit
60778ba
·
verified ·
1 Parent(s): e5956ee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -25
app.py CHANGED
@@ -1,34 +1,61 @@
1
  import os
2
  import httpx
3
- from fastrtc import ReplyOnPause, Stream, get_stt_model, get_tts_model, StreamHandlerBase
 
 
 
4
  from openai import OpenAI
5
 
6
- # Initialize Sambanova Client
7
  sambanova_client = OpenAI(
8
- api_key=os.getenv("key"), base_url="https://api.deepinfra.com/v1"
 
9
  )
10
-
11
- # Load STT and TTS models
12
  stt_model = get_stt_model()
13
  tts_model = get_tts_model()
14
 
15
- # Create a proper handler subclass
16
- class EchoHandler(StreamHandlerBase):
17
  def __init__(self):
18
- super().__init__()
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- def on_audio(self, audio):
21
- prompt = stt_model.stt(audio)
22
  response = sambanova_client.chat.completions.create(
23
  model="mistralai/Mistral-Small-24B-Instruct-2501",
24
- messages=[{"role": "user", "content": prompt}],
25
  max_tokens=200,
26
  )
27
  reply = response.choices[0].message.content
28
- for audio_chunk in tts_model.stream_tts_sync(reply):
29
- yield audio_chunk
30
 
31
- # Dummy TURN config
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  def get_cloudflare_turn_credentials(
33
  turn_key_id=None,
34
  turn_key_api_token=None,
@@ -36,20 +63,15 @@ def get_cloudflare_turn_credentials(
36
  ttl=600,
37
  client: httpx.AsyncClient | None = None,
38
  ):
39
- return {
40
- "iceServers": [
41
- {
42
- "urls": ["stun:stun.l.google.com:19302"]
43
- }
44
- ]
45
- }
46
-
47
- # Launch stream with correct handler
48
  stream = Stream(
49
  handler=EchoHandler(),
50
- rtc_configuration=get_cloudflare_turn_credentials,
51
  modality="audio",
52
- mode="send-receive"
 
53
  )
54
 
55
  stream.fastphone()
 
1
  import os
2
  import httpx
3
+ import numpy as np
4
+ from queue import Queue, Empty
5
+
6
+ from fastrtc import Stream, StreamHandler, get_stt_model, get_tts_model
7
  from openai import OpenAI
8
 
9
+ # Initialize OpenAI client and on-device models
10
  sambanova_client = OpenAI(
11
+ api_key=os.getenv("key"),
12
+ base_url="https://api.deepinfra.com/v1"
13
  )
 
 
14
  stt_model = get_stt_model()
15
  tts_model = get_tts_model()
16
 
17
+ class EchoHandler(StreamHandler):
 
18
  def __init__(self):
19
+ super().__init__() # uses default sample rates/layouts
20
+ self.queue: Queue[tuple[int, np.ndarray]] = Queue()
21
+
22
+ def start_up(self) -> None:
23
+ # Optional: warm up models or state here
24
+ pass
25
+
26
+ def receive(self, frame: tuple[int, np.ndarray]) -> None:
27
+ # frame is (sample_rate, numpy array)
28
+ sample_rate, audio_array = frame
29
+
30
+ # 1) Transcribe speech → text
31
+ text = stt_model.stt(frame)
32
 
33
+ # 2) Chat completion
 
34
  response = sambanova_client.chat.completions.create(
35
  model="mistralai/Mistral-Small-24B-Instruct-2501",
36
+ messages=[{"role": "user", "content": text}],
37
  max_tokens=200,
38
  )
39
  reply = response.choices[0].message.content
 
 
40
 
41
+ # 3) Generate TTS chunks and enqueue them
42
+ for tts_chunk in tts_model.stream_tts_sync(reply):
43
+ # each tts_chunk is a numpy array of shape (1, N)
44
+ self.queue.put((sample_rate, tts_chunk))
45
+
46
+ def emit(self):
47
+ try:
48
+ return self.queue.get_nowait()
49
+ except Empty:
50
+ return None # no audio to send right now
51
+
52
+ def copy(self) -> "EchoHandler":
53
+ return EchoHandler()
54
+
55
+ def shutdown(self) -> None:
56
+ # Optional cleanup
57
+ pass
58
+
59
  def get_cloudflare_turn_credentials(
60
  turn_key_id=None,
61
  turn_key_api_token=None,
 
63
  ttl=600,
64
  client: httpx.AsyncClient | None = None,
65
  ):
66
+ # Replace with your real TURN creds logic
67
+ return {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
68
+
69
+ # Wire up the stream with the new handler
 
 
 
 
 
70
  stream = Stream(
71
  handler=EchoHandler(),
 
72
  modality="audio",
73
+ mode="send-receive",
74
+ rtc_configuration=get_cloudflare_turn_credentials
75
  )
76
 
77
  stream.fastphone()