Spaces:
Running
Running
Commit
Β·
1dd5469
1
Parent(s):
3a4cb0f
ui.py
CHANGED
@@ -8,334 +8,523 @@ from typing import Dict, Any, Optional
|
|
8 |
import threading
|
9 |
from queue import Queue
|
10 |
import base64
|
|
|
|
|
11 |
|
12 |
# Configure logging
|
13 |
logging.basicConfig(level=logging.INFO)
|
14 |
logger = logging.getLogger(__name__)
|
15 |
|
16 |
-
|
17 |
-
|
|
|
|
|
|
|
|
|
18 |
|
19 |
def __init__(self):
|
20 |
self.connected_clients = set()
|
21 |
-
self.message_queue = Queue()
|
22 |
self.is_running = False
|
23 |
self.websocket_server = None
|
24 |
-
self.current_transcript = ""
|
25 |
self.conversation_history = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
-
async def
|
28 |
-
"""Handle WebSocket
|
29 |
-
|
30 |
-
|
31 |
|
32 |
-
|
33 |
|
34 |
try:
|
35 |
-
# Send connection
|
36 |
await websocket.send(json.dumps({
|
37 |
-
"type": "
|
38 |
"status": "connected",
|
39 |
"timestamp": time.time(),
|
40 |
-
"
|
41 |
}))
|
42 |
|
|
|
43 |
async for message in websocket:
|
44 |
try:
|
45 |
if isinstance(message, bytes):
|
46 |
# Handle binary audio data
|
47 |
-
await self.
|
48 |
else:
|
49 |
-
# Handle text messages
|
50 |
-
|
51 |
-
await self.handle_message(data, websocket)
|
52 |
|
53 |
-
except json.JSONDecodeError:
|
54 |
-
logger.warning(f"Invalid JSON received from client: {message}")
|
55 |
except Exception as e:
|
56 |
logger.error(f"Error processing message: {e}")
|
|
|
57 |
|
58 |
except websockets.exceptions.ConnectionClosed:
|
59 |
-
logger.info(
|
60 |
except Exception as e:
|
61 |
-
logger.error(f"Client
|
62 |
finally:
|
63 |
self.connected_clients.discard(websocket)
|
64 |
-
logger.info(f"Client removed.
|
65 |
-
|
66 |
-
async def
|
67 |
-
"""Process incoming audio data"""
|
68 |
try:
|
69 |
-
|
70 |
-
|
71 |
|
72 |
-
|
73 |
-
result = await process_audio_for_transcription(audio_data)
|
74 |
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
|
83 |
-
|
84 |
-
|
85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
|
87 |
-
|
88 |
-
|
89 |
except Exception as e:
|
90 |
-
logger.error(f"
|
91 |
-
await
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
"""Handle non-audio messages from clients"""
|
99 |
-
message_type = data.get("type", "unknown")
|
100 |
-
|
101 |
-
if message_type == "config":
|
102 |
-
# Handle configuration updates
|
103 |
-
logger.info(f"Configuration update: {data}")
|
104 |
|
105 |
-
|
106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
await websocket.send(json.dumps({
|
108 |
-
"type": "
|
109 |
-
"
|
110 |
"timestamp": time.time()
|
111 |
}))
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
else:
|
124 |
-
logger.warning(f"Unknown message type: {message_type}")
|
125 |
-
|
126 |
-
async def broadcast_result(self, result: Dict[str, Any]):
|
127 |
-
"""Broadcast results to all connected clients"""
|
128 |
-
if not self.connected_clients:
|
129 |
-
return
|
130 |
-
|
131 |
-
message = json.dumps(result)
|
132 |
-
disconnected = set()
|
133 |
|
134 |
-
|
135 |
-
try:
|
136 |
-
await client.send(message)
|
137 |
-
except Exception as e:
|
138 |
-
logger.warning(f"Failed to send to client: {e}")
|
139 |
-
disconnected.add(client)
|
140 |
|
141 |
-
#
|
142 |
-
|
143 |
-
self.
|
144 |
-
|
145 |
-
def
|
146 |
-
"""
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
if len(self.conversation_history) > 100:
|
159 |
-
self.conversation_history = self.conversation_history[-100:]
|
160 |
-
|
161 |
-
async def start_websocket_server(self, host="0.0.0.0", port=7860):
|
162 |
"""Start the WebSocket server"""
|
163 |
try:
|
|
|
164 |
self.websocket_server = await websockets.serve(
|
165 |
-
self.
|
166 |
host,
|
167 |
port,
|
|
|
168 |
path="/ws_inference"
|
169 |
)
|
|
|
170 |
self.is_running = True
|
171 |
-
logger.info(f"WebSocket server started on {host}:{port}")
|
172 |
|
173 |
-
# Keep server running
|
174 |
await self.websocket_server.wait_closed()
|
175 |
|
176 |
except Exception as e:
|
177 |
-
logger.error(f"WebSocket server
|
178 |
self.is_running = False
|
179 |
-
|
180 |
-
def get_status(self):
|
181 |
-
"""Get current status information"""
|
182 |
-
return {
|
183 |
-
"connected_clients": len(self.connected_clients),
|
184 |
-
"is_running": self.is_running,
|
185 |
-
"conversation_entries": len(self.conversation_history),
|
186 |
-
"last_activity": time.time()
|
187 |
-
}
|
188 |
|
189 |
-
# Initialize the
|
190 |
-
|
191 |
|
192 |
def create_gradio_interface():
|
193 |
-
"""Create
|
194 |
|
195 |
def get_server_status():
|
196 |
-
"""Get server status
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
204 |
"""
|
|
|
|
|
205 |
|
206 |
-
def
|
207 |
-
"""Get
|
208 |
-
if not
|
209 |
-
return "No
|
|
|
|
|
210 |
|
211 |
-
|
212 |
-
for entry in
|
213 |
-
timestamp = time.
|
214 |
-
speaker = entry
|
215 |
-
text = entry
|
216 |
-
confidence = entry
|
217 |
|
218 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
219 |
|
220 |
-
return
|
221 |
|
222 |
-
def
|
223 |
"""Clear conversation history"""
|
224 |
-
|
225 |
-
|
226 |
-
return "Conversation history cleared."
|
227 |
|
228 |
# Create Gradio interface
|
229 |
-
with gr.Blocks(
|
230 |
-
|
231 |
-
gr.
|
|
|
232 |
|
233 |
-
|
234 |
-
|
235 |
-
refresh_btn = gr.Button("Refresh Status")
|
236 |
-
refresh_btn.click(get_server_status, outputs=status_display)
|
237 |
|
238 |
-
with gr.Tab("
|
239 |
-
gr.Markdown(
|
240 |
-
conversation_display = gr.Markdown(get_conversation_history())
|
241 |
|
242 |
with gr.Row():
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
|
|
|
|
248 |
|
249 |
-
with gr.Tab("
|
250 |
-
gr.Markdown(
|
251 |
-
### WebSocket Endpoint
|
252 |
-
Connect to this Space's WebSocket endpoint for real-time audio processing:
|
253 |
-
|
254 |
-
**WebSocket URL:** `wss://your-space-name.hf.space/ws_inference`
|
255 |
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
"settings": {
|
265 |
-
"language": "en",
|
266 |
-
"enable_diarization": true
|
267 |
-
}
|
268 |
-
}
|
269 |
-
```
|
270 |
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
283 |
""")
|
284 |
|
285 |
-
with gr.Tab("API Documentation"):
|
286 |
gr.Markdown("""
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
313 |
""")
|
314 |
|
315 |
return demo
|
316 |
|
317 |
def run_websocket_server():
|
318 |
-
"""Run
|
319 |
loop = asyncio.new_event_loop()
|
320 |
asyncio.set_event_loop(loop)
|
321 |
|
322 |
try:
|
323 |
-
|
|
|
324 |
except Exception as e:
|
325 |
-
logger.error(f"WebSocket server
|
326 |
finally:
|
327 |
loop.close()
|
328 |
|
329 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
330 |
websocket_thread = threading.Thread(target=run_websocket_server, daemon=True)
|
331 |
websocket_thread.start()
|
332 |
|
|
|
|
|
|
|
333 |
# Create and launch Gradio interface
|
334 |
if __name__ == "__main__":
|
335 |
demo = create_gradio_interface()
|
336 |
demo.launch(
|
337 |
server_name="0.0.0.0",
|
338 |
server_port=7860,
|
339 |
-
share=
|
340 |
show_error=True
|
341 |
)
|
|
|
8 |
import threading
|
9 |
from queue import Queue
|
10 |
import base64
|
11 |
+
import numpy as np
|
12 |
+
import os
|
13 |
|
14 |
# Configure logging
|
15 |
logging.basicConfig(level=logging.INFO)
|
16 |
logger = logging.getLogger(__name__)
|
17 |
|
18 |
+
# Environment-configurable HF Space URL (matching backend.py)
|
19 |
+
HF_SPACE_URL = os.getenv("HF_SPACE_URL", "https://androidguy-speaker-diarization.hf.space")
|
20 |
+
API_WS = f"wss://{HF_SPACE_URL}/ws_inference"
|
21 |
+
|
22 |
+
class TranscriptionWebSocketServer:
|
23 |
+
"""WebSocket server that receives audio from backend and returns transcription results"""
|
24 |
|
25 |
def __init__(self):
|
26 |
self.connected_clients = set()
|
|
|
27 |
self.is_running = False
|
28 |
self.websocket_server = None
|
|
|
29 |
self.conversation_history = []
|
30 |
+
self.processing_stats = {
|
31 |
+
"total_audio_chunks": 0,
|
32 |
+
"total_transcriptions": 0,
|
33 |
+
"last_audio_received": None,
|
34 |
+
"server_start_time": time.time(),
|
35 |
+
"backend_url": HF_SPACE_URL
|
36 |
+
}
|
37 |
|
38 |
+
async def handle_client_connection(self, websocket, path):
|
39 |
+
"""Handle incoming WebSocket connections from the backend"""
|
40 |
+
client_addr = websocket.remote_address
|
41 |
+
logger.info(f"Backend client connected from {client_addr}")
|
42 |
|
43 |
+
self.connected_clients.add(websocket)
|
44 |
|
45 |
try:
|
46 |
+
# Send initial connection acknowledgment
|
47 |
await websocket.send(json.dumps({
|
48 |
+
"type": "connection_ack",
|
49 |
"status": "connected",
|
50 |
"timestamp": time.time(),
|
51 |
+
"message": "HuggingFace transcription service ready"
|
52 |
}))
|
53 |
|
54 |
+
# Handle incoming messages/audio data
|
55 |
async for message in websocket:
|
56 |
try:
|
57 |
if isinstance(message, bytes):
|
58 |
# Handle binary audio data
|
59 |
+
await self.process_audio_data(message, websocket)
|
60 |
else:
|
61 |
+
# Handle text messages (JSON)
|
62 |
+
await self.handle_text_message(message, websocket)
|
|
|
63 |
|
|
|
|
|
64 |
except Exception as e:
|
65 |
logger.error(f"Error processing message: {e}")
|
66 |
+
await self.send_error(websocket, f"Processing error: {str(e)}")
|
67 |
|
68 |
except websockets.exceptions.ConnectionClosed:
|
69 |
+
logger.info("Backend client disconnected")
|
70 |
except Exception as e:
|
71 |
+
logger.error(f"Client connection error: {e}")
|
72 |
finally:
|
73 |
self.connected_clients.discard(websocket)
|
74 |
+
logger.info(f"Client removed. Active connections: {len(self.connected_clients)}")
|
75 |
+
|
76 |
+
async def process_audio_data(self, audio_data: bytes, websocket):
|
77 |
+
"""Process incoming audio data and return transcription results"""
|
78 |
try:
|
79 |
+
self.processing_stats["total_audio_chunks"] += 1
|
80 |
+
self.processing_stats["last_audio_received"] = time.time()
|
81 |
|
82 |
+
logger.debug(f"Received {len(audio_data)} bytes of audio data")
|
|
|
83 |
|
84 |
+
# Try to import and use your inference functions
|
85 |
+
try:
|
86 |
+
from inference import transcribe_audio, identify_speakers
|
87 |
+
|
88 |
+
# Process the audio for transcription
|
89 |
+
transcription_result = await transcribe_audio(audio_data)
|
90 |
+
|
91 |
+
if transcription_result:
|
92 |
+
# Process for speaker diarization if available
|
93 |
+
try:
|
94 |
+
speaker_info = await identify_speakers(audio_data)
|
95 |
+
transcription_result.update(speaker_info)
|
96 |
+
except Exception as e:
|
97 |
+
logger.warning(f"Speaker diarization failed: {e}")
|
98 |
+
transcription_result["speaker"] = "Unknown"
|
99 |
+
|
100 |
+
# Update conversation history
|
101 |
+
self.update_conversation_history(transcription_result)
|
102 |
+
|
103 |
+
# Send result back to backend
|
104 |
+
response = {
|
105 |
+
"type": "processing_result",
|
106 |
+
"timestamp": time.time(),
|
107 |
+
"data": transcription_result
|
108 |
+
}
|
109 |
+
|
110 |
+
await websocket.send(json.dumps(response))
|
111 |
+
self.processing_stats["total_transcriptions"] += 1
|
112 |
+
|
113 |
+
logger.info(f"Sent transcription result: {transcription_result.get('text', '')[:50]}...")
|
114 |
+
|
115 |
+
except ImportError:
|
116 |
+
# Fallback if inference module is not available
|
117 |
+
logger.warning("Inference module not found, using mock transcription")
|
118 |
+
|
119 |
+
# Try to use shared.py for processing if available
|
120 |
+
try:
|
121 |
+
from shared import RealtimeSpeakerDiarization
|
122 |
+
|
123 |
+
# Initialize if not already initialized
|
124 |
+
if not hasattr(self, 'diarization_system'):
|
125 |
+
self.diarization_system = RealtimeSpeakerDiarization()
|
126 |
+
await asyncio.to_thread(self.diarization_system.initialize_models)
|
127 |
+
await asyncio.to_thread(self.diarization_system.start_recording)
|
128 |
+
|
129 |
+
# Process the audio chunk
|
130 |
+
result = await asyncio.to_thread(self.diarization_system.process_audio_chunk, audio_data)
|
131 |
+
|
132 |
+
# Format result for response
|
133 |
+
if result and result["status"] != "error":
|
134 |
+
mock_result = {
|
135 |
+
"text": result.get("text", f"[Processing {len(audio_data)} bytes]"),
|
136 |
+
"speaker": f"Speaker_{result.get('speaker_id', 0) + 1}",
|
137 |
+
"confidence": result.get("similarity", 0.85),
|
138 |
+
"timestamp": time.time()
|
139 |
+
}
|
140 |
+
else:
|
141 |
+
# Fallback mock result
|
142 |
+
mock_result = {
|
143 |
+
"text": f"[Mock transcription - {len(audio_data)} bytes processed]",
|
144 |
+
"speaker": "Speaker_1",
|
145 |
+
"confidence": 0.85,
|
146 |
+
"timestamp": time.time()
|
147 |
+
}
|
148 |
+
|
149 |
+
# Update conversation history
|
150 |
+
self.update_conversation_history(mock_result)
|
151 |
+
|
152 |
+
response = {
|
153 |
+
"type": "processing_result",
|
154 |
+
"timestamp": time.time(),
|
155 |
+
"data": mock_result
|
156 |
+
}
|
157 |
+
|
158 |
+
await websocket.send(json.dumps(response))
|
159 |
+
self.processing_stats["total_transcriptions"] += 1
|
160 |
|
161 |
+
except Exception as e:
|
162 |
+
logger.warning(f"Failed to use shared module: {e}")
|
163 |
+
|
164 |
+
# Basic mock transcription as last resort
|
165 |
+
mock_result = {
|
166 |
+
"text": f"[Mock transcription - {len(audio_data)} bytes processed]",
|
167 |
+
"speaker": "Speaker_1",
|
168 |
+
"confidence": 0.85,
|
169 |
+
"timestamp": time.time()
|
170 |
+
}
|
171 |
+
|
172 |
+
self.update_conversation_history(mock_result)
|
173 |
+
|
174 |
+
response = {
|
175 |
+
"type": "processing_result",
|
176 |
+
"timestamp": time.time(),
|
177 |
+
"data": mock_result
|
178 |
+
}
|
179 |
|
180 |
+
await websocket.send(json.dumps(response))
|
181 |
+
|
182 |
except Exception as e:
|
183 |
+
logger.error(f"Audio processing error: {e}")
|
184 |
+
await self.send_error(websocket, f"Audio processing failed: {str(e)}")
|
185 |
+
|
186 |
+
async def handle_text_message(self, message: str, websocket):
|
187 |
+
"""Handle text-based messages from backend"""
|
188 |
+
try:
|
189 |
+
data = json.loads(message)
|
190 |
+
message_type = data.get("type", "unknown")
|
|
|
|
|
|
|
|
|
|
|
|
|
191 |
|
192 |
+
logger.info(f"Received message type: {message_type}")
|
193 |
+
|
194 |
+
if message_type == "ping":
|
195 |
+
# Respond to ping with pong
|
196 |
+
await websocket.send(json.dumps({
|
197 |
+
"type": "pong",
|
198 |
+
"timestamp": time.time()
|
199 |
+
}))
|
200 |
+
|
201 |
+
elif message_type == "config":
|
202 |
+
# Handle configuration updates
|
203 |
+
logger.info(f"Configuration update: {data}")
|
204 |
+
|
205 |
+
# Apply configuration settings if available
|
206 |
+
settings = data.get("settings", {})
|
207 |
+
if "max_speakers" in settings:
|
208 |
+
max_speakers = settings.get("max_speakers")
|
209 |
+
logger.info(f"Setting max_speakers to {max_speakers}")
|
210 |
+
|
211 |
+
if "threshold" in settings:
|
212 |
+
threshold = settings.get("threshold")
|
213 |
+
logger.info(f"Setting speaker change threshold to {threshold}")
|
214 |
+
|
215 |
+
# Send acknowledgment
|
216 |
+
await websocket.send(json.dumps({
|
217 |
+
"type": "config_ack",
|
218 |
+
"message": "Configuration received",
|
219 |
+
"timestamp": time.time()
|
220 |
+
}))
|
221 |
+
|
222 |
+
elif message_type == "status_request":
|
223 |
+
# Send status information
|
224 |
+
await websocket.send(json.dumps({
|
225 |
+
"type": "status_response",
|
226 |
+
"data": self.get_processing_stats(),
|
227 |
+
"timestamp": time.time()
|
228 |
+
}))
|
229 |
+
|
230 |
+
else:
|
231 |
+
logger.warning(f"Unknown message type: {message_type}")
|
232 |
+
|
233 |
+
except json.JSONDecodeError:
|
234 |
+
logger.error(f"Invalid JSON received: {message}")
|
235 |
+
await self.send_error(websocket, "Invalid JSON format")
|
236 |
+
|
237 |
+
async def send_error(self, websocket, error_message: str):
|
238 |
+
"""Send error message to client"""
|
239 |
+
try:
|
240 |
await websocket.send(json.dumps({
|
241 |
+
"type": "error",
|
242 |
+
"message": error_message,
|
243 |
"timestamp": time.time()
|
244 |
}))
|
245 |
+
except Exception as e:
|
246 |
+
logger.error(f"Failed to send error message: {e}")
|
247 |
+
|
248 |
+
def update_conversation_history(self, transcription_result: Dict[str, Any]):
|
249 |
+
"""Update conversation history with new transcription"""
|
250 |
+
history_entry = {
|
251 |
+
"timestamp": time.time(),
|
252 |
+
"text": transcription_result.get("text", ""),
|
253 |
+
"speaker": transcription_result.get("speaker", "Unknown"),
|
254 |
+
"confidence": transcription_result.get("confidence", 0.0)
|
255 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
256 |
|
257 |
+
self.conversation_history.append(history_entry)
|
|
|
|
|
|
|
|
|
|
|
258 |
|
259 |
+
# Keep only last 50 entries to prevent memory issues
|
260 |
+
if len(self.conversation_history) > 50:
|
261 |
+
self.conversation_history = self.conversation_history[-50:]
|
262 |
+
|
263 |
+
def get_processing_stats(self):
|
264 |
+
"""Get processing statistics"""
|
265 |
+
return {
|
266 |
+
"connected_clients": len(self.connected_clients),
|
267 |
+
"total_audio_chunks": self.processing_stats["total_audio_chunks"],
|
268 |
+
"total_transcriptions": self.processing_stats["total_transcriptions"],
|
269 |
+
"last_audio_received": self.processing_stats["last_audio_received"],
|
270 |
+
"server_uptime": time.time() - self.processing_stats["server_start_time"],
|
271 |
+
"conversation_entries": len(self.conversation_history),
|
272 |
+
"backend_url": self.processing_stats.get("backend_url", HF_SPACE_URL)
|
273 |
+
}
|
274 |
+
|
275 |
+
async def start_server(self, host="0.0.0.0", port=7860):
|
|
|
|
|
|
|
|
|
276 |
"""Start the WebSocket server"""
|
277 |
try:
|
278 |
+
# Start WebSocket server on /ws_inference endpoint
|
279 |
self.websocket_server = await websockets.serve(
|
280 |
+
self.handle_client_connection,
|
281 |
host,
|
282 |
port,
|
283 |
+
subprotocols=[],
|
284 |
path="/ws_inference"
|
285 |
)
|
286 |
+
|
287 |
self.is_running = True
|
288 |
+
logger.info(f"WebSocket server started on ws://{host}:{port}/ws_inference")
|
289 |
|
290 |
+
# Keep the server running
|
291 |
await self.websocket_server.wait_closed()
|
292 |
|
293 |
except Exception as e:
|
294 |
+
logger.error(f"Failed to start WebSocket server: {e}")
|
295 |
self.is_running = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
296 |
|
297 |
+
# Initialize the WebSocket server
|
298 |
+
ws_server = TranscriptionWebSocketServer()
|
299 |
|
300 |
def create_gradio_interface():
|
301 |
+
"""Create Gradio interface for monitoring and testing"""
|
302 |
|
303 |
def get_server_status():
|
304 |
+
"""Get current server status"""
|
305 |
+
stats = ws_server.get_processing_stats()
|
306 |
+
|
307 |
+
status_text = f"""
|
308 |
+
### Server Status
|
309 |
+
- **WebSocket Server**: {'π’ Running' if ws_server.is_running else 'π΄ Stopped'}
|
310 |
+
- **Connected Clients**: {stats['connected_clients']}
|
311 |
+
- **Server Uptime**: {stats['server_uptime']:.1f} seconds
|
312 |
+
|
313 |
+
### Processing Statistics
|
314 |
+
- **Audio Chunks Processed**: {stats['total_audio_chunks']}
|
315 |
+
- **Transcriptions Generated**: {stats['total_transcriptions']}
|
316 |
+
- **Last Audio Received**: {time.ctime(stats['last_audio_received']) if stats['last_audio_received'] else 'Never'}
|
317 |
+
|
318 |
+
### Conversation
|
319 |
+
- **History Entries**: {stats['conversation_entries']}
|
320 |
"""
|
321 |
+
|
322 |
+
return status_text
|
323 |
|
324 |
+
def get_recent_transcriptions():
|
325 |
+
"""Get recent transcription results"""
|
326 |
+
if not ws_server.conversation_history:
|
327 |
+
return "No transcriptions yet. Waiting for audio data from backend..."
|
328 |
+
|
329 |
+
recent_entries = ws_server.conversation_history[-10:] # Last 10 entries
|
330 |
|
331 |
+
formatted_text = "### Recent Transcriptions\n\n"
|
332 |
+
for entry in recent_entries:
|
333 |
+
timestamp = time.strftime("%H:%M:%S", time.localtime(entry['timestamp']))
|
334 |
+
speaker = entry['speaker']
|
335 |
+
text = entry['text']
|
336 |
+
confidence = entry['confidence']
|
337 |
|
338 |
+
# Extract speaker number for color matching with shared.py
|
339 |
+
speaker_num = 0
|
340 |
+
if speaker.startswith("Speaker_"):
|
341 |
+
try:
|
342 |
+
speaker_num = int(speaker.split("_")[1]) - 1
|
343 |
+
except (ValueError, IndexError):
|
344 |
+
speaker_num = 0
|
345 |
+
|
346 |
+
# Use colors from shared.py if possible
|
347 |
+
try:
|
348 |
+
from shared import SPEAKER_COLORS
|
349 |
+
color = SPEAKER_COLORS[speaker_num % len(SPEAKER_COLORS)]
|
350 |
+
except (ImportError, IndexError):
|
351 |
+
# Fallback colors
|
352 |
+
colors = ["#FF6B6B", "#4ECDC4", "#45B7D1", "#96CEB4", "#FFEAA7", "#DDA0DD", "#98D8C8", "#F7DC6F"]
|
353 |
+
color = colors[speaker_num % len(colors)]
|
354 |
+
|
355 |
+
formatted_text += f"<span style='color:{color};font-weight:bold;'>[{timestamp}] {speaker}</span> (confidence: {confidence:.2f})\n"
|
356 |
+
formatted_text += f"{text}\n\n"
|
357 |
|
358 |
+
return formatted_text
|
359 |
|
360 |
+
def clear_conversation_history():
|
361 |
"""Clear conversation history"""
|
362 |
+
ws_server.conversation_history.clear()
|
363 |
+
return "Conversation history cleared!"
|
|
|
364 |
|
365 |
# Create Gradio interface
|
366 |
+
with gr.Blocks(
|
367 |
+
title="Real-time Audio Transcription Service",
|
368 |
+
theme=gr.themes.Soft()
|
369 |
+
) as demo:
|
370 |
|
371 |
+
gr.Markdown("# π€ Real-time Audio Transcription Service")
|
372 |
+
gr.Markdown("This HuggingFace Space receives audio from your backend and returns transcription results with speaker diarization.")
|
|
|
|
|
373 |
|
374 |
+
with gr.Tab("π Server Status"):
|
375 |
+
status_display = gr.Markdown(get_server_status())
|
|
|
376 |
|
377 |
with gr.Row():
|
378 |
+
refresh_status_btn = gr.Button("π Refresh Status", variant="primary")
|
379 |
+
|
380 |
+
refresh_status_btn.click(
|
381 |
+
fn=get_server_status,
|
382 |
+
outputs=status_display,
|
383 |
+
every=None
|
384 |
+
)
|
385 |
|
386 |
+
with gr.Tab("π Live Transcription"):
|
387 |
+
transcription_display = gr.Markdown(get_recent_transcriptions())
|
|
|
|
|
|
|
|
|
388 |
|
389 |
+
with gr.Row():
|
390 |
+
refresh_transcription_btn = gr.Button("π Refresh Transcriptions", variant="primary")
|
391 |
+
clear_history_btn = gr.Button("ποΈ Clear History", variant="secondary")
|
392 |
|
393 |
+
refresh_transcription_btn.click(
|
394 |
+
fn=get_recent_transcriptions,
|
395 |
+
outputs=transcription_display
|
396 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
397 |
|
398 |
+
clear_history_btn.click(
|
399 |
+
fn=clear_conversation_history,
|
400 |
+
outputs=gr.Markdown()
|
401 |
+
)
|
402 |
+
|
403 |
+
with gr.Tab("π§ Connection Info"):
|
404 |
+
gr.Markdown(f"""
|
405 |
+
### WebSocket Connection Details
|
406 |
+
|
407 |
+
**WebSocket Endpoint**: `wss://{HF_SPACE_URL}/ws_inference`
|
408 |
+
|
409 |
+
### Backend Connection
|
410 |
+
Your backend should connect to this WebSocket endpoint and:
|
411 |
+
|
412 |
+
1. **Send Audio Data**: Stream raw audio bytes to this endpoint
|
413 |
+
2. **Receive Results**: Get JSON responses with transcription results
|
414 |
+
|
415 |
+
### Expected Message Flow
|
416 |
+
|
417 |
+
**Backend β HuggingFace**:
|
418 |
+
- Raw audio bytes (binary data)
|
419 |
+
- Configuration messages (JSON)
|
420 |
+
|
421 |
+
**HuggingFace β Backend**:
|
422 |
+
```json
|
423 |
+
{{
|
424 |
+
"type": "processing_result",
|
425 |
+
"timestamp": 1234567890.123,
|
426 |
+
"data": {{
|
427 |
+
"text": "transcribed text here",
|
428 |
+
"speaker": "Speaker_1",
|
429 |
+
"confidence": 0.95
|
430 |
+
}}
|
431 |
+
}}
|
432 |
+
```
|
433 |
+
|
434 |
+
### Test Connection
|
435 |
+
Your backend is configured to connect to: `{ws_server.processing_stats.get('backend_url', HF_SPACE_URL)}`
|
436 |
""")
|
437 |
|
438 |
+
with gr.Tab("π API Documentation"):
|
439 |
gr.Markdown("""
|
440 |
+
### WebSocket API Reference
|
441 |
+
|
442 |
+
#### Endpoint
|
443 |
+
- **URL**: `/ws_inference`
|
444 |
+
- **Protocol**: WebSocket
|
445 |
+
- **Accepts**: Binary audio data + JSON messages
|
446 |
+
|
447 |
+
#### Message Types
|
448 |
+
|
449 |
+
##### 1. Audio Processing
|
450 |
+
- **Input**: Raw audio bytes (binary)
|
451 |
+
- **Output**: Processing result (JSON)
|
452 |
+
|
453 |
+
##### 2. Configuration
|
454 |
+
- **Input**:
|
455 |
+
```json
|
456 |
+
{
|
457 |
+
"type": "config",
|
458 |
+
"settings": {
|
459 |
+
"language": "en",
|
460 |
+
"enable_diarization": true,
|
461 |
+
"max_speakers": 4,
|
462 |
+
"threshold": 0.65
|
463 |
+
}
|
464 |
+
}
|
465 |
+
```
|
466 |
+
|
467 |
+
##### 3. Status Check
|
468 |
+
- **Input**: `{"type": "status_request"}`
|
469 |
+
- **Output**: Server statistics
|
470 |
+
|
471 |
+
##### 4. Ping/Pong
|
472 |
+
- **Input**: `{"type": "ping"}`
|
473 |
+
- **Output**: `{"type": "pong", "timestamp": 1234567890}`
|
474 |
+
|
475 |
+
#### Error Handling
|
476 |
+
All errors are returned as:
|
477 |
+
```json
|
478 |
+
{
|
479 |
+
"type": "error",
|
480 |
+
"message": "Error description",
|
481 |
+
"timestamp": 1234567890.123
|
482 |
+
}
|
483 |
+
```
|
484 |
""")
|
485 |
|
486 |
return demo
|
487 |
|
488 |
def run_websocket_server():
|
489 |
+
"""Run WebSocket server in background thread"""
|
490 |
loop = asyncio.new_event_loop()
|
491 |
asyncio.set_event_loop(loop)
|
492 |
|
493 |
try:
|
494 |
+
logger.info("Starting WebSocket server thread...")
|
495 |
+
loop.run_until_complete(ws_server.start_server())
|
496 |
except Exception as e:
|
497 |
+
logger.error(f"WebSocket server error: {e}")
|
498 |
finally:
|
499 |
loop.close()
|
500 |
|
501 |
+
# Mount UI to inference.py
|
502 |
+
def mount_ui(app):
|
503 |
+
"""Mount Gradio interface to FastAPI app"""
|
504 |
+
try:
|
505 |
+
demo = create_gradio_interface()
|
506 |
+
# Mount without starting server (FastAPI will handle it)
|
507 |
+
demo.mount_to_app(app)
|
508 |
+
logger.info("Gradio UI mounted to FastAPI app")
|
509 |
+
return True
|
510 |
+
except Exception as e:
|
511 |
+
logger.error(f"Error mounting UI: {e}")
|
512 |
+
return False
|
513 |
+
|
514 |
+
# Start WebSocket server in background
|
515 |
+
logger.info("Initializing WebSocket server...")
|
516 |
websocket_thread = threading.Thread(target=run_websocket_server, daemon=True)
|
517 |
websocket_thread.start()
|
518 |
|
519 |
+
# Give server time to start
|
520 |
+
time.sleep(2)
|
521 |
+
|
522 |
# Create and launch Gradio interface
|
523 |
if __name__ == "__main__":
|
524 |
demo = create_gradio_interface()
|
525 |
demo.launch(
|
526 |
server_name="0.0.0.0",
|
527 |
server_port=7860,
|
528 |
+
share=True,
|
529 |
show_error=True
|
530 |
)
|