isididiidid commited on
Commit
abf4faa
·
verified ·
1 Parent(s): b7975ba

Create ayy.py

Browse files
Files changed (1) hide show
  1. ayy.py +635 -0
ayy.py ADDED
@@ -0,0 +1,635 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import json
3
+ import logging
4
+ import random
5
+ import time
6
+ import os
7
+ from datetime import datetime
8
+ from typing import Dict, List, Optional, Set, Tuple, Any
9
+
10
+ from fastapi import FastAPI, Request, Response, WebSocket, WebSocketDisconnect, HTTPException
11
+ from fastapi.responses import StreamingResponse, HTMLResponse
12
+ from fastapi.staticfiles import StaticFiles
13
+ import uvicorn
14
+ import httpx
15
+
16
+
17
+ # 日志记录器模块
18
+ class LoggingService:
19
+ def __init__(self, service_name: str = "ProxyServer"):
20
+ self.service_name = service_name
21
+ self.logger = logging.getLogger(service_name)
22
+ self.logger.setLevel(logging.DEBUG)
23
+
24
+ # 配置日志格式
25
+ formatter = logging.Formatter(
26
+ "[%(levelname)s] %(asctime)s [%(name)s] - %(message)s"
27
+ )
28
+
29
+ # 添加控制台处理器
30
+ ch = logging.StreamHandler()
31
+ ch.setFormatter(formatter)
32
+ self.logger.addHandler(ch)
33
+
34
+ def _format_message(self, level: str, message: str) -> str:
35
+ timestamp = datetime.now().isoformat()
36
+ return f"[{level}] {timestamp} [{self.service_name}] - {message}"
37
+
38
+ def info(self, message: str):
39
+ self.logger.info(message)
40
+
41
+ def error(self, message: str):
42
+ self.logger.error(message)
43
+
44
+ def warn(self, message: str):
45
+ self.logger.warning(message)
46
+
47
+ def debug(self, message: str):
48
+ self.logger.debug(message)
49
+
50
+
51
+ # 消息队列实现
52
+ class MessageQueue:
53
+ def __init__(self, timeout_ms: int = 600000):
54
+ self.messages: List[dict] = []
55
+ self.waiting_resolvers: List[Tuple[asyncio.Future, asyncio.TimerHandle]] = []
56
+ self.default_timeout = timeout_ms / 1000 # 转换为秒
57
+ self.closed = False
58
+
59
+ def enqueue(self, message: dict):
60
+ if self.closed:
61
+ return
62
+
63
+ if self.waiting_resolvers:
64
+ future, timer = self.waiting_resolvers.pop(0)
65
+ timer.cancel()
66
+ if not future.done():
67
+ future.set_result(message)
68
+ else:
69
+ self.messages.append(message)
70
+
71
+ async def dequeue(self, timeout_ms: Optional[int] = None) -> dict:
72
+ if self.closed:
73
+ raise Exception("Queue is closed")
74
+
75
+ if self.messages:
76
+ return self.messages.pop(0)
77
+
78
+ timeout = self.default_timeout if timeout_ms is None else timeout_ms / 1000
79
+ loop = asyncio.get_running_loop()
80
+ future = loop.create_future()
81
+
82
+ def timeout_callback():
83
+ if not future.done():
84
+ future.set_exception(Exception("Queue timeout"))
85
+
86
+ timer = loop.call_later(timeout, timeout_callback)
87
+ self.waiting_resolvers.append((future, timer))
88
+
89
+ try:
90
+ return await future
91
+ finally:
92
+ if (future, timer) in self.waiting_resolvers:
93
+ self.waiting_resolvers.remove((future, timer))
94
+ timer.cancel()
95
+
96
+ def close(self):
97
+ self.closed = True
98
+ for future, timer in self.waiting_resolvers:
99
+ timer.cancel()
100
+ if not future.done():
101
+ future.set_exception(Exception("Queue closed"))
102
+ self.waiting_resolvers.clear()
103
+ self.messages.clear()
104
+
105
+
106
+ # WebSocket连接管理器
107
+ class ConnectionRegistry:
108
+ def __init__(self, logger: LoggingService):
109
+ self.logger = logger
110
+ self.main_connections: Set[WebSocket] = set() # 主连接集合
111
+ self.request_connections: Dict[str, WebSocket] = {} # 请求ID到专用连接的映射
112
+ self.message_queues: Dict[str, MessageQueue] = {}
113
+ self._connection_added_callbacks = []
114
+ self._connection_removed_callbacks = []
115
+
116
+ def on_connection_added(self, callback):
117
+ self._connection_added_callbacks.append(callback)
118
+
119
+ def on_connection_removed(self, callback):
120
+ self._connection_removed_callbacks.append(callback)
121
+
122
+ async def add_main_connection(self, websocket: WebSocket, client_info: dict):
123
+ """添加主WebSocket连接"""
124
+ await websocket.accept()
125
+ self.main_connections.add(websocket)
126
+ self.logger.info(f"新主连接: {client_info.get('address')}")
127
+
128
+ # 触发连接添加事件
129
+ for callback in self._connection_added_callbacks:
130
+ callback(websocket)
131
+
132
+ async def add_request_connection(self, websocket: WebSocket, request_id: str, client_info: dict):
133
+ """添加请求专用WebSocket连接"""
134
+ await websocket.accept()
135
+ self.request_connections[request_id] = websocket
136
+ self.logger.info(f"新请求连接 [ID: {request_id}]: {client_info.get('address')}")
137
+
138
+ async def remove_main_connection(self, websocket: WebSocket):
139
+ """移除主WebSocket连接"""
140
+ if websocket in self.main_connections:
141
+ self.main_connections.remove(websocket)
142
+ self.logger.info("主连接断开")
143
+
144
+ # 触发连接移除事件
145
+ for callback in self._connection_removed_callbacks:
146
+ callback(websocket)
147
+
148
+ async def remove_request_connection(self, websocket: WebSocket, request_id: str):
149
+ """移除请求专用WebSocket连接"""
150
+ if request_id in self.request_connections and self.request_connections[request_id] == websocket:
151
+ del self.request_connections[request_id]
152
+ self.logger.info(f"请求连接断开 [ID: {request_id}]")
153
+
154
+ # 关闭相关的消息队列
155
+ queue = self.message_queues.get(request_id)
156
+ if queue:
157
+ queue.close()
158
+ del self.message_queues[request_id]
159
+
160
+ async def handle_main_message(self, message_data: str):
161
+ """处理来自主连接的消息"""
162
+ try:
163
+ parsed_message = json.loads(message_data)
164
+ request_id = parsed_message.get("request_id")
165
+
166
+ if not request_id:
167
+ self.logger.warn("收到无效消息:缺少request_id")
168
+ return
169
+
170
+ # 主连接只接收初始请求,不处理响应
171
+ self.logger.info(f"收到主连接请求 [ID: {request_id}]")
172
+ except Exception as error:
173
+ self.logger.error(f"解析主连接WebSocket消息失败: {str(error)}")
174
+
175
+ async def handle_request_message(self, message_data: str, request_id: str):
176
+ """处理来自请求专用连接的消息"""
177
+ try:
178
+ parsed_message = json.loads(message_data)
179
+ message_request_id = parsed_message.get("request_id")
180
+
181
+ if not message_request_id:
182
+ self.logger.warn("收到无效消息:缺少request_id")
183
+ return
184
+
185
+ if message_request_id != request_id:
186
+ self.logger.warn(f"请求ID不匹配: 预期 {request_id}, 实际 {message_request_id}")
187
+ return
188
+
189
+ queue = self.message_queues.get(request_id)
190
+ if queue:
191
+ await self._route_message(parsed_message, queue)
192
+ else:
193
+ self.logger.warn(f"收到未知请求ID的消息: {request_id}")
194
+ except Exception as error:
195
+ self.logger.error(f"解析请求连接WebSocket消息失败: {str(error)}")
196
+
197
+ async def _route_message(self, message: dict, queue: MessageQueue):
198
+ event_type = message.get("event_type")
199
+
200
+ if event_type in ["response_headers", "chunk", "error"]:
201
+ queue.enqueue(message)
202
+ elif event_type == "stream_close":
203
+ queue.enqueue({"type": "STREAM_END"})
204
+ else:
205
+ self.logger.warn(f"未知的事件类型: {event_type}")
206
+
207
+ def has_active_main_connections(self) -> bool:
208
+ """检查是否有活跃的主连接"""
209
+ return len(self.main_connections) > 0
210
+
211
+ def get_random_main_connection(self) -> Optional[WebSocket]:
212
+ """随机获取一个主连接"""
213
+ if not self.main_connections:
214
+ return None
215
+
216
+ connections = list(self.main_connections)
217
+ random_index = random.randint(0, len(connections) - 1)
218
+ self.logger.info(f"随机选择主连接 {random_index + 1}/{len(connections)}")
219
+ return connections[random_index]
220
+
221
+ def get_request_connection(self, request_id: str) -> Optional[WebSocket]:
222
+ """获取指定请求ID的专用连接"""
223
+ return self.request_connections.get(request_id)
224
+
225
+ def create_message_queue(self, request_id: str) -> MessageQueue:
226
+ """为请求创建消息队列"""
227
+ queue = MessageQueue()
228
+ self.message_queues[request_id] = queue
229
+ return queue
230
+
231
+ def remove_message_queue(self, request_id: str):
232
+ """移除请求的消息队列"""
233
+ queue = self.message_queues.get(request_id)
234
+ if queue:
235
+ queue.close()
236
+ del self.message_queues[request_id]
237
+
238
+
239
+ # 请求处理器
240
+ class RequestHandler:
241
+ def __init__(self, connection_registry: ConnectionRegistry, logger: LoggingService):
242
+ self.connection_registry = connection_registry
243
+ self.logger = logger
244
+
245
+ async def process_request(self, request: Request) -> StreamingResponse:
246
+ self.logger.info(f"处理请求: {request.method} {request.url.path}")
247
+
248
+ if not self.connection_registry.has_active_main_connections():
249
+ raise HTTPException(status_code=503, detail="没有可用的浏览器连接")
250
+
251
+ request_id = self._generate_request_id()
252
+ proxy_request = await self._build_proxy_request(request, request_id)
253
+
254
+ message_queue = self.connection_registry.create_message_queue(request_id)
255
+
256
+ try:
257
+ # 通过主连接发送请求信息,包含请求专用WebSocket的URL
258
+ await self._notify_main_connection(proxy_request, request_id)
259
+
260
+ # 等待请求专用连接建立
261
+ await self._wait_for_request_connection(request_id)
262
+
263
+ # 通过请求专用连接转发请求
264
+ await self._forward_request(proxy_request, request_id)
265
+
266
+ return await self._handle_response(request, message_queue, request_id)
267
+ except Exception as error:
268
+ # 只在出错时清理队列
269
+ self.connection_registry.remove_message_queue(request_id)
270
+ if str(error) == "Queue timeout":
271
+ raise HTTPException(status_code=504, detail="请求超时")
272
+ elif str(error) == "请求连接建立超时":
273
+ raise HTTPException(status_code=504, detail="请求连接建立超时")
274
+ else:
275
+ # 检查是否是 HTTP 400 INVALID_ARGUMENT 错误,如果是则打印完整请求
276
+ if "HTTP 400" in str(error) and "INVALID_ARGUMENT" in str(error):
277
+ self.logger.error(f"[RequestProcessor] 请求执行失败: {str(error)}")
278
+ self.logger.error(f"完整请求信息:")
279
+ self.logger.error(f" 请求ID: {request_id}")
280
+ self.logger.error(f" 方法: {proxy_request['method']}")
281
+ self.logger.error(f" 路径: {proxy_request['path']}")
282
+ self.logger.error(f" 请求头: {json.dumps(proxy_request['headers'], indent=2, ensure_ascii=False)}")
283
+ self.logger.error(f" 查询参数: {json.dumps(proxy_request['query_params'], indent=2, ensure_ascii=False)}")
284
+ self.logger.error(f" 请求体: {proxy_request['body']}")
285
+ else:
286
+ self.logger.error(f"请求处理错误: {str(error)}")
287
+ raise HTTPException(status_code=500, detail=f"代理错误: {str(error)}")
288
+
289
+ def _generate_request_id(self) -> str:
290
+ return f"{int(time.time() * 1000)}_{random.getrandbits(32):08x}"
291
+
292
+ async def _build_proxy_request(self, request: Request, request_id: str) -> dict:
293
+ body = ""
294
+ body_data = await request.body()
295
+ if body_data:
296
+ try:
297
+ body = body_data.decode('utf-8')
298
+ except UnicodeDecodeError:
299
+ body = str(body_data)
300
+
301
+ return {
302
+ "path": request.url.path,
303
+ "method": request.method,
304
+ "headers": dict(request.headers),
305
+ "query_params": dict(request.query_params),
306
+ "body": body,
307
+ "request_id": request_id,
308
+ }
309
+
310
+ async def _notify_main_connection(self, proxy_request: dict, request_id: str):
311
+ """通知主连接有新请求"""
312
+ connection = self.connection_registry.get_random_main_connection()
313
+ if not connection:
314
+ raise Exception("没有可用的主连接")
315
+
316
+ # 发送完整的请求信息到主连接
317
+ await connection.send_text(json.dumps(proxy_request))
318
+ self.logger.info(f"已通知主连接新请求 [ID: {request_id}]")
319
+
320
+ async def _wait_for_request_connection(self, request_id: str, timeout: int = 30):
321
+ """等待请求专用连接建立"""
322
+ start_time = time.time()
323
+ while time.time() - start_time < timeout:
324
+ if self.connection_registry.get_request_connection(request_id):
325
+ self.logger.info(f"请求连接已建立 [ID: {request_id}]")
326
+ return
327
+ await asyncio.sleep(0.1)
328
+
329
+ self.logger.error(f"请求连接建立超时 [ID: {request_id}],已等待 {timeout} 秒")
330
+ raise Exception("请求连接建立超时")
331
+
332
+ async def _forward_request(self, proxy_request: dict, request_id: str):
333
+ """通过请求专用连接转发请求"""
334
+ connection = self.connection_registry.get_request_connection(request_id)
335
+ if not connection:
336
+ raise Exception(f"请求连接不存在 [ID: {request_id}]")
337
+
338
+ await connection.send_text(json.dumps(proxy_request))
339
+ self.logger.info(f"请求已转发到专用连接 [ID: {request_id}]")
340
+
341
+ async def _handle_response(self, request: Request, message_queue: MessageQueue, request_id: str) -> StreamingResponse:
342
+ # 等待响应头
343
+ try:
344
+ header_message = await message_queue.dequeue()
345
+ except Exception as e:
346
+ raise HTTPException(status_code=500, detail=f"获取响应头失败: {str(e)}")
347
+
348
+ if header_message.get("event_type") == "error":
349
+ error_status = header_message.get("status", 500)
350
+ error_message = header_message.get("message", "未知错误")
351
+
352
+ # 检查是否是 HTTP 400 INVALID_ARGUMENT 错误,如果是则打印完整请求
353
+ if error_status == 400 and "INVALID_ARGUMENT" in error_message:
354
+ self.logger.error(f"[RequestProcessor] 请求执行失败: HTTP {error_status}: {error_message}")
355
+ self.logger.error(f"完整请求信息:")
356
+ self.logger.error(f" 请求ID: {request_id}")
357
+ # 需要从请求中重新构建 proxy_request 信息
358
+ proxy_request = await self._build_proxy_request(request, request_id)
359
+ self.logger.error(f" 方法: {proxy_request['method']}")
360
+ self.logger.error(f" 路径: {proxy_request['path']}")
361
+ self.logger.error(f" 请求头: {json.dumps(proxy_request['headers'], indent=2, ensure_ascii=False)}")
362
+ self.logger.error(f" 查询参数: {json.dumps(proxy_request['query_params'], indent=2, ensure_ascii=False)}")
363
+ self.logger.error(f" 请求体: {proxy_request['body']}")
364
+
365
+ raise HTTPException(
366
+ status_code=error_status,
367
+ detail=error_message
368
+ )
369
+
370
+ # 设置响应头
371
+ headers = header_message.get("headers", {})
372
+ status_code = header_message.get("status", 200)
373
+
374
+ # 创建流式响应
375
+ return StreamingResponse(
376
+ self._stream_response_generator(message_queue, headers, request_id),
377
+ status_code=status_code,
378
+ headers=headers
379
+ )
380
+
381
+ async def _stream_response_generator(self, message_queue: MessageQueue, headers: dict, request_id: str):
382
+ try:
383
+ while True:
384
+ try:
385
+ data_message = await message_queue.dequeue()
386
+
387
+ if data_message.get("type") == "STREAM_END":
388
+ self.logger.debug(f"收到流结束信号 [ID: {request_id}]")
389
+ break
390
+
391
+ if data_message.get("event_type") == "error":
392
+ self.logger.error(f"收到错误信号 [ID: {request_id}]: {data_message.get('message', '未知错误')}")
393
+ break
394
+
395
+ if data := data_message.get("data"):
396
+ if isinstance(data, str):
397
+ yield data.encode('utf-8')
398
+ else:
399
+ yield data
400
+
401
+ except Exception as error:
402
+ if str(error) == "Queue timeout":
403
+ content_type = headers.get("Content-Type", "")
404
+ if "text/event-stream" in content_type:
405
+ yield b": keepalive\n\n"
406
+ else:
407
+ self.logger.debug(f"队列超时,结束流式响应 [ID: {request_id}]")
408
+ break
409
+ elif str(error) in ["Queue closed", "Queue is closed"]:
410
+ self.logger.info(f"队列已关闭,结束流式响应 [ID: {request_id}]")
411
+ break
412
+ else:
413
+ self.logger.error(f"流式响应处理错误 [ID: {request_id}]: {str(error)}")
414
+ raise error
415
+ except Exception as e:
416
+ self.logger.error(f"流式响应生成错误 [ID: {request_id}]: {str(e)}")
417
+ finally:
418
+ # 流式响应结束后清理资源
419
+ self.logger.debug(f"流式响应结束,开始清理资源 [ID: {request_id}]")
420
+
421
+ # 清理消息队列
422
+ self.connection_registry.remove_message_queue(request_id)
423
+
424
+ # 清理请求专用连接
425
+ connection = self.connection_registry.get_request_connection(request_id)
426
+ if connection:
427
+ try:
428
+ await connection.close()
429
+ self.logger.debug(f"请求连接已关闭 [ID: {request_id}]")
430
+ except Exception as e:
431
+ self.logger.error(f"关闭请求连接失败 [ID: {request_id}]: {str(e)}")
432
+
433
+
434
+ # 主服务器类
435
+ class ProxyServerSystem:
436
+ def __init__(self, config: dict = None):
437
+ if config is None:
438
+ config = {}
439
+
440
+ # 从环境变量获取端口,Hugging Face Spaces 使用 PORT 环境变量
441
+ port = int(os.environ.get("PORT", 7860)) # Hugging Face Spaces 默认端口
442
+ host = os.environ.get("HOST", "0.0.0.0")
443
+
444
+ self.config = {
445
+ "http_port": port,
446
+ "ws_port": port, # 使用同一个端口
447
+ "host": host,
448
+ **config
449
+ }
450
+
451
+ self.logger = LoggingService("ProxyServer")
452
+ self.connection_registry = ConnectionRegistry(self.logger)
453
+ self.request_handler = RequestHandler(self.connection_registry, self.logger)
454
+
455
+ self.app = FastAPI(
456
+ title="WebSocket Proxy Server",
457
+ description="A proxy server with WebSocket support for Hugging Face Spaces",
458
+ version="1.0.0"
459
+ )
460
+ self._setup_routes()
461
+ self._started_callbacks = []
462
+ self._error_callbacks = []
463
+
464
+ def on_started(self, callback):
465
+ self._started_callbacks.append(callback)
466
+
467
+ def on_error(self, callback):
468
+ self._error_callbacks.append(callback)
469
+
470
+ def _setup_routes(self):
471
+ # 健康检查端点
472
+ @self.app.get("/health")
473
+ async def health_check():
474
+ return {
475
+ "status": "healthy",
476
+ "timestamp": datetime.now().isoformat(),
477
+ "connections": {
478
+ "main": len(self.connection_registry.main_connections),
479
+ "requests": len(self.connection_registry.request_connections)
480
+ }
481
+ }
482
+
483
+ # 根路径返回简单的 HTML 页面
484
+ @self.app.get("/", response_class=HTMLResponse)
485
+ async def root():
486
+ html_content = """
487
+ <!DOCTYPE html>
488
+ <html>
489
+ <head>
490
+ <title>WebSocket Proxy Server</title>
491
+ <meta charset="utf-8">
492
+ <style>
493
+ body { font-family: Arial, sans-serif; margin: 40px; }
494
+ .container { max-width: 800px; margin: 0 auto; }
495
+ .status { padding: 10px; border-radius: 5px; margin: 10px 0; }
496
+ .success { background-color: #d4edda; color: #155724; }
497
+ .info { background-color: #d1ecf1; color: #0c5460; }
498
+ code { background-color: #f8f9fa; padding: 2px 4px; border-radius: 3px; }
499
+ </style>
500
+ </head>
501
+ <body>
502
+ <div class="container">
503
+ <h1>WebSocket Proxy Server</h1>
504
+ <div class="status success">
505
+ ✅ 服务器运行正常
506
+ </div>
507
+
508
+ <h2>连接信息</h2>
509
+ <div class="info">
510
+ <p><strong>主 WebSocket 连接:</strong> <code>ws://your-space-url/ws</code></p>
511
+ <p><strong>请求专用连接:</strong> <code>ws://your-space-url/ws/request/{request_id}</code></p>
512
+ <p><strong>健康检查:</strong> <code>/health</code></p>
513
+ </div>
514
+
515
+ <h2>使用说明</h2>
516
+ <ol>
517
+ <li>首先建立主 WebSocket 连接到 <code>/ws</code></li>
518
+ <li>发送 HTTP 请求到任意路径</li>
519
+ <li>服务器会通过主连接通知新请求</li>
520
+ <li>客户端需要建立请求专用连接到 <code>/ws/request/{request_id}</code></li>
521
+ <li>通过专用连接处理请求和响应</li>
522
+ </ol>
523
+
524
+ <h2>环境信息</h2>
525
+ <p>运行在 Hugging Face Spaces 环境中</p>
526
+ </div>
527
+ </body>
528
+ </html>
529
+ """
530
+ return HTMLResponse(content=html_content)
531
+
532
+ # 主WebSocket路由 - 接收初始请求
533
+ @self.app.websocket("/ws")
534
+ async def main_websocket_endpoint(websocket: WebSocket):
535
+ client_info = {
536
+ "address": websocket.client.host if websocket.client else "unknown"
537
+ }
538
+
539
+ await self.connection_registry.add_main_connection(websocket, client_info)
540
+
541
+ try:
542
+ while True:
543
+ message = await websocket.receive_text()
544
+ await self.connection_registry.handle_main_message(message)
545
+ except WebSocketDisconnect:
546
+ self.logger.info("主WebSocket连接已关闭")
547
+ except Exception as e:
548
+ self.logger.error(f"主WebSocket处理错误: {str(e)}")
549
+ finally:
550
+ await self.connection_registry.remove_main_connection(websocket)
551
+
552
+ # 请求专用WebSocket路由 - 处理独立请求
553
+ @self.app.websocket("/ws/request/{request_id}")
554
+ async def request_websocket_endpoint(websocket: WebSocket, request_id: str):
555
+ client_info = {
556
+ "address": websocket.client.host if websocket.client else "unknown"
557
+ }
558
+
559
+ await self.connection_registry.add_request_connection(websocket, request_id, client_info)
560
+
561
+ try:
562
+ while True:
563
+ message = await websocket.receive_text()
564
+ await self.connection_registry.handle_request_message(message, request_id)
565
+ except WebSocketDisconnect:
566
+ self.logger.info(f"请求WebSocket连接已关闭 [ID: {request_id}]")
567
+ except Exception as e:
568
+ self.logger.error(f"请求WebSocket处理错误 [ID: {request_id}]: {str(e)}")
569
+ finally:
570
+ await self.connection_registry.remove_request_connection(websocket, request_id)
571
+
572
+ # API 路由前缀,避免与根路径冲突
573
+ @self.app.api_route("/api/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"])
574
+ async def api_proxy(request: Request, path: str):
575
+ return await self.request_handler.process_request(request)
576
+
577
+ # 通配符路由处理其他HTTP请求(排除根路径和健康检查)
578
+ @self.app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"])
579
+ async def catch_all(request: Request, path: str):
580
+ # 排除特殊路径
581
+ if path in ["", "health", "ws"] or path.startswith("ws/"):
582
+ raise HTTPException(status_code=404, detail="Not Found")
583
+ return await self.request_handler.process_request(request)
584
+
585
+ async def start(self):
586
+ try:
587
+ # 启动HTTP服务器
588
+ config = uvicorn.Config(
589
+ app=self.app,
590
+ host=self.config["host"],
591
+ port=self.config["http_port"],
592
+ log_level="info",
593
+ access_log=True
594
+ )
595
+ server = uvicorn.Server(config)
596
+
597
+ self.logger.info(f"HTTP服务器启动: http://{self.config['host']}:{self.config['http_port']}")
598
+ self.logger.info(f"主WebSocket服务器启动: ws://{self.config['host']}:{self.config['http_port']}/ws")
599
+ self.logger.info(f"请求WebSocket服务器启动: ws://{self.config['host']}:{self.config['http_port']}/ws/request/{{request_id}}")
600
+ self.logger.info("代理服务器系统启动完成 - 适配 Hugging Face Spaces")
601
+
602
+ # 触发启动事件
603
+ for callback in self._started_callbacks:
604
+ callback()
605
+
606
+ # 启动服务器
607
+ await server.serve()
608
+
609
+ except Exception as error:
610
+ self.logger.error(f"启动失败: {str(error)}")
611
+
612
+ # 触发错误事件
613
+ for callback in self._error_callbacks:
614
+ callback(error)
615
+
616
+ raise error
617
+
618
+
619
+ # 启动函数
620
+ async def initialize_server():
621
+ server_system = ProxyServerSystem()
622
+
623
+ try:
624
+ await server_system.start()
625
+ except Exception as error:
626
+ print(f"服务器启动失败: {str(error)}")
627
+ raise
628
+
629
+
630
+ # 主程序入口
631
+ if __name__ == "__main__":
632
+ try:
633
+ asyncio.run(initialize_server())
634
+ except KeyboardInterrupt:
635
+ print("服务器已停止")