seawolf2357 committed
Commit 868c0a3 · verified · 1 Parent(s): c909595

Update app.py

Files changed (1):
  1. app.py +63 -30
app.py CHANGED
@@ -18,9 +18,15 @@ from fastrtc import (
 )
 from gradio.utils import get_space
 from openai.types.beta.realtime import ResponseAudioTranscriptDoneEvent
+import websockets.exceptions
+import logging
 
 load_dotenv()
 
+# Logging setup
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
 cur_dir = Path(__file__).parent
 
 SAMPLE_RATE = 24000
@@ -49,44 +55,71 @@ class OpenAIHandler(AsyncStreamHandler):
     ):
         """Connect to realtime API. Run forever in separate thread to keep connection open."""
         self.client = openai.AsyncOpenAI()
-        async with self.client.beta.realtime.connect(
-            model="gpt-4o-mini-realtime-preview-2024-12-17"
-        ) as conn:
-            await conn.session.update(
-                session={
-                    "turn_detection": {"type": "server_vad"},
-                    "system_prompt": self.system_prompt
-                }
-            )
-            self.connection = conn
-            async for event in self.connection:
-                if event.type == "response.audio_transcript.done":
-                    await self.output_queue.put(AdditionalOutputs(event))
-                if event.type == "response.audio.delta":
-                    await self.output_queue.put(
-                        (
-                            self.output_sample_rate,
-                            np.frombuffer(
-                                base64.b64decode(event.delta), dtype=np.int16
-                            ).reshape(1, -1),
-                        ),
-                    )
+        try:
+            async with self.client.beta.realtime.connect(
+                model="gpt-4o-mini-realtime-preview-2024-12-17"
+            ) as conn:
+                await conn.session.update(
+                    session={
+                        "turn_detection": {"type": "server_vad"},
+                        "system_prompt": self.system_prompt
+                    }
+                )
+                self.connection = conn
+                async for event in self.connection:
+                    if event.type == "response.audio_transcript.done":
+                        await self.output_queue.put(AdditionalOutputs(event))
+                    if event.type == "response.audio.delta":
+                        await self.output_queue.put(
+                            (
+                                self.output_sample_rate,
+                                np.frombuffer(
+                                    base64.b64decode(event.delta), dtype=np.int16
+                                ).reshape(1, -1),
+                            ),
+                        )
+        except Exception as e:
+            logger.error(f"Error in start_up: {e}")
+            if self.connection:
+                await self.connection.close()
+            self.connection = None
 
     async def receive(self, frame: tuple[int, np.ndarray]) -> None:
         if not self.connection:
+            logger.warning("No connection available")
             return
-        _, array = frame
-        array = array.squeeze()
-        audio_message = base64.b64encode(array.tobytes()).decode("utf-8")
-        await self.connection.input_audio_buffer.append(audio=audio_message) # type: ignore
+        try:
+            _, array = frame
+            array = array.squeeze()
+            audio_message = base64.b64encode(array.tobytes()).decode("utf-8")
+            await self.connection.input_audio_buffer.append(audio=audio_message) # type: ignore
+        except websockets.exceptions.ConnectionClosedOK:
+            logger.info("WebSocket connection closed normally")
+            # Normal closure: ignore and continue
+        except Exception as e:
+            logger.error(f"Error in receive: {e}")
+            if self.connection:
+                try:
+                    await self.connection.close()
+                except:
+                    pass
+            self.connection = None
 
     async def emit(self) -> tuple[int, np.ndarray] | AdditionalOutputs | None:
-        return await wait_for_item(self.output_queue)
+        try:
+            return await wait_for_item(self.output_queue)
+        except Exception as e:
+            logger.error(f"Error in emit: {e}")
+            return None
 
     async def shutdown(self) -> None:
         if self.connection:
-            await self.connection.close()
-            self.connection = None
+            try:
+                await self.connection.close()
+            except Exception as e:
+                logger.error(f"Error closing connection: {e}")
+            finally:
+                self.connection = None
 
 
 def update_chatbot(chatbot: list[dict], response: ResponseAudioTranscriptDoneEvent):
@@ -97,7 +130,7 @@ def update_chatbot(chatbot: list[dict], response: ResponseAudioTranscriptDoneEvent):
 chatbot = gr.Chatbot(type="messages")
 latest_message = gr.Textbox(type="text", visible=False)
 stream = Stream(
-    OpenAIHandler(system_prompt="당신은 친절한 한국어 AI 비서입니다. 너의 이름은 '비드래프��'입니다. 모든 질문에 한국어로 간결하고 명확하게, 항상 존댓말로 답변하세요."),
+    OpenAIHandler(system_prompt="당신은 친절한 한국어 AI 비서 '마우스'입니다. 모든 질문에 한국어로 간결하고 명확하게, 항상 존댓말로 답변하세요."),
     mode="send-receive",
     modality="audio",
     additional_inputs=[chatbot],
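
For reference, a minimal, self-contained sketch of the error-handling pattern the new receive() uses: a clean websockets.exceptions.ConnectionClosedOK close is logged and ignored, while any other failure closes and clears the connection so it can be rebuilt. safe_append and send_frame are hypothetical names for illustration only, not fastrtc or openai APIs.

import logging

import websockets.exceptions

logger = logging.getLogger(__name__)


async def safe_append(connection, send_frame):
    """Await a send coroutine, treating a clean WebSocket close as non-fatal.

    Returns the connection if it is still usable, or None if it was torn down.
    (Illustrative helper; not part of app.py.)
    """
    if connection is None:
        logger.warning("No connection available")
        return None
    try:
        await send_frame()
    except websockets.exceptions.ConnectionClosedOK:
        # Normal closure: drop this frame and keep going, as in receive() above.
        logger.info("WebSocket connection closed normally")
    except Exception:
        # Any other failure: close and clear the connection so it can be rebuilt.
        logger.exception("Error while sending audio frame")
        try:
            await connection.close()
        except Exception:
            pass
        return None
    return connection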