letterm committed on
Commit
5722028
·
verified ·
1 Parent(s): b47441f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +269 -396
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from google.protobuf import descriptor as _descriptor
2
  from google.protobuf import descriptor_pool as _descriptor_pool
3
  from google.protobuf import runtime_version as _runtime_version
@@ -21,123 +22,81 @@ _globals = globals()
21
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
22
  _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'zed_pb2', _globals)
23
  if not _descriptor._USE_C_DESCRIPTORS:
24
- DESCRIPTOR._loaded_options = None
25
- _globals['_ERRORCODE']._serialized_start=1085
26
- _globals['_ERRORCODE']._serialized_end=1452
27
- _globals['_PEERID']._serialized_start=27
28
- _globals['_PEERID']._serialized_end=65
29
- _globals['_ENVELOPE']._serialized_start=68
30
- _globals['_ENVELOPE']._serialized_end=806
31
- _globals['_HELLO']._serialized_start=808
32
- _globals['_HELLO']._serialized_end=854
33
- _globals['_PING']._serialized_start=856
34
- _globals['_PING']._serialized_end=862
35
- _globals['_ACK']._serialized_start=864
36
- _globals['_ACK']._serialized_end=869
37
- _globals['_ERROR']._serialized_start=871
38
- _globals['_ERROR']._serialized_end=948
39
- _globals['_ACCEPTTERMSOFSERVICE']._serialized_start=950
40
- _globals['_ACCEPTTERMSOFSERVICE']._serialized_end=972
41
- _globals['_ACCEPTTERMSOFSERVICERESPONSE']._serialized_start=974
42
- _globals['_ACCEPTTERMSOFSERVICERESPONSE']._serialized_end=1029
43
- _globals['_GETLLMTOKEN']._serialized_start=1031
44
- _globals['_GETLLMTOKEN']._serialized_end=1044
45
- _globals['_GETLLMTOKENRESPONSE']._serialized_start=1046
46
- _globals['_GETLLMTOKENRESPONSE']._serialized_end=1082
47
-
48
-
49
  import os
50
  import json
51
  import ssl
52
  import time
53
  import asyncio
 
54
  import aiohttp
55
- import sys
56
- import inspect
57
- from loguru import logger
58
  from aiohttp import web
59
  import zstandard as zstd
 
60
  from websockets.asyncio.client import connect
61
  from websockets.exceptions import ConnectionClosed
 
62
 
63
  from google.protobuf.json_format import MessageToDict
64
 
65
- # 修复
66
  Envelope = _sym_db.GetSymbol('zed.messages.Envelope')
67
 
68
- class Logger:
69
- def __init__(self, level="INFO", colorize=True, format=None):
70
- logger.remove()
71
-
72
- if format is None:
73
- format = (
74
- "<green>{time:YYYY-MM-DD HH:mm:ss}</green> | "
75
- "<level>{level: <8}</level> | "
76
- "<cyan>{extra[filename]}</cyan>:<cyan>{extra[function]}</cyan>:<cyan>{extra[lineno]}</cyan> | "
77
- "<level>{message}</level>"
78
- )
79
-
80
- logger.add(
81
- sys.stderr,
82
- level=level,
83
- format=format,
84
- colorize=colorize,
85
- backtrace=True,
86
- diagnose=True
87
- )
88
-
89
- self.logger = logger
90
-
91
- def _get_caller_info(self):
92
- frame = inspect.currentframe()
93
- try:
94
- caller_frame = frame.f_back.f_back
95
- full_path = caller_frame.f_code.co_filename
96
- function = caller_frame.f_code.co_name
97
- lineno = caller_frame.f_lineno
98
-
99
- filename = os.path.basename(full_path)
100
-
101
- return {
102
- 'filename': filename,
103
- 'function': function,
104
- 'lineno': lineno
105
- }
106
- finally:
107
- del frame
108
-
109
- def info(self, message, source="API"):
110
- caller_info = self._get_caller_info()
111
- self.logger.bind(**caller_info).info(f"[{source}] {message}")
112
-
113
- def error(self, message, source="API"):
114
- caller_info = self._get_caller_info()
115
-
116
- if isinstance(message, Exception):
117
- self.logger.bind(**caller_info).exception(f"[{source}] {str(message)}")
118
- else:
119
- self.logger.bind(**caller_info).error(f"[{source}] {message}")
120
-
121
- def warning(self, message, source="API"):
122
- caller_info = self._get_caller_info()
123
- self.logger.bind(**caller_info).warning(f"[{source}] {message}")
124
-
125
- def debug(self, message, source="API"):
126
- caller_info = self._get_caller_info()
127
- self.logger.bind(**caller_info).debug(f"[{source}] {message}")
128
 
129
- async def request_logger(self, request):
130
- caller_info = self._get_caller_info()
131
- self.logger.bind(**caller_info).info(f"请求: {request.method} {request.path}", "Request")
132
 
133
- logger = Logger(level="INFO")
134
 
135
- BASE_URL = "https://zed.dev"
136
- BASE_API_URL = "https://collab.zed.dev"
137
- WS_URL = "wss://collab.zed.dev/rpc"
138
- LLM_API_URL = "https://llm.zed.dev/completion"
139
- DEFAULT_PROXY_PORT = 7860 # 默认端口,可通过环境变量覆盖
140
- TOKEN_EXPIRY_WARNING_MINUTES = 50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
  highest_message_id = 0
143
  llm_token = None
@@ -146,47 +105,40 @@ server_peer_id = None
146
  active_websocket = None
147
  proxy_server_running = False
148
 
149
- # 从环境变量读取API密码和自定义端口
150
- PROXY_PASSWORD = os.environ.get("ZED_PROXY_PASSWORD")
151
- PROXY_PORT = int(os.environ.get("ZED_PROXY_PORT", DEFAULT_PROXY_PORT))
152
-
153
- # 支持的模型列表
154
- AVAILABLE_MODELS = [
155
- {
156
- "id": "claude-3-5-sonnet-20240620",
157
- "object": "model",
158
- "created": 1719158400,
159
- "owned_by": "anthropic",
160
- "permission": [],
161
- },
162
- {
163
- "id": "claude-3-5-sonnet-20241022",
164
- "object": "model",
165
- "created": 1729344000,
166
- "owned_by": "anthropic",
167
- "permission": [],
168
- },
169
- {
170
- "id": "claude-3-7-sonnet-20250219",
171
- "object": "model",
172
- "created": 1740153600,
173
- "owned_by": "anthropic",
174
- "permission": [],
175
- }
176
- ]
177
-
178
- # 模型映射表(从其他可能的模型名称映射到我们支持的模型)
179
- MODEL_MAPPING = {
180
- # 默认映射到最新版本
181
- "claude-3-5-sonnet": "claude-3-5-sonnet-20241022",
182
- "claude-3.5-sonnet": "claude-3-5-sonnet-20241022",
183
- "claude-3-7-sonnet": "claude-3-7-sonnet-20250219",
184
- "claude-3.7-sonnet": "claude-3-7-sonnet-20250219",
185
- # 确保标准名称也能映射到自身
186
- "claude-3-5-sonnet-20240620": "claude-3-5-sonnet-20240620",
187
- "claude-3-5-sonnet-20241022": "claude-3-5-sonnet-20241022",
188
- "claude-3-7-sonnet-20250219": "claude-3-7-sonnet-20250219",
189
- }
190
 
191
  def decode_envelope(data):
192
  try:
@@ -203,11 +155,11 @@ def decode_envelope(data):
203
  return MessageToDict(envelope, preserving_proto_field_name=True)
204
  except Exception as e:
205
  hex_preview = ' '.join(f'{byte:02x}' for byte in data[:20]) + ('...' if len(data) > 20 else '')
206
- logger.error(f"Unable to decode message: {e}; data preview: {hex_preview}")
207
- return {"error": f"Unable to decode message: {e}"}
208
 
209
  def compress_protobuf(data):
210
- return zstd.ZstdCompressor(level=7).compress(data)
211
 
212
  def create_message(message_type):
213
  global highest_message_id
@@ -224,7 +176,7 @@ async def ping_periodically(websocket):
224
  await websocket.ping()
225
  await asyncio.sleep(1)
226
  except Exception as e:
227
- logger.error(f"Error sending ping: {e}")
228
  break
229
 
230
  async def handle_messages(websocket):
@@ -232,36 +184,36 @@ async def handle_messages(websocket):
232
  active_websocket = websocket
233
  try:
234
  async for message in websocket:
235
- message_bytes = message.encode('utf8') if isinstance(message, str) else message
236
  decoded = decode_envelope(message_bytes)
237
  if "hello" in decoded:
238
  server_peer_id = decoded.get('hello', {}).get('peer_id')
239
  elif "accept_terms_of_service_response" in decoded:
240
  await request_llm_token(websocket)
241
- elif ("get_llm_token_response" in decoded and
242
  'token' in decoded.get('get_llm_token_response', {})):
243
  llm_token = decoded['get_llm_token_response']['token']
244
  token_timestamp = time.time()
245
- logger.info(f"LLM token received at {time.ctime(token_timestamp)}")
246
  if not proxy_server_running:
247
  asyncio.create_task(start_proxy_server())
248
  asyncio.create_task(monitor_token_expiration())
249
- logger.info("Closing WebSocket connection until token refresh is needed")
250
  await websocket.close()
251
  active_websocket = None
252
  return
253
  except ConnectionClosed:
254
- logger.info("Connection closed")
255
  active_websocket = None
256
 
257
  async def request_llm_token(websocket):
258
  message, _ = create_message('get_llm_token')
259
- logger.info("Requesting the LLM token")
260
  await websocket.send(message)
261
 
262
  async def request_accept_terms_of_service(websocket):
263
  message, _ = create_message('accept_terms_of_service')
264
- logger.info("Sending consent for the Zed Terms of Service")
265
  await websocket.send(message)
266
 
267
  def format_content(content):
@@ -269,129 +221,165 @@ def format_content(content):
269
  return [{"type": "text", "text": content}]
270
  return content
271
 
272
- # 密码验证中间件
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
  @web.middleware
274
  async def auth_middleware(request, handler):
275
- if PROXY_PASSWORD:
276
- # 检查各种可能的密码验证方式
277
  auth_header = request.headers.get('Authorization')
278
- password_header = request.headers.get('XPassword') or request.headers.get('Password')
279
 
280
- # 从Authorization头获取密码(如果是Bearer格式)
281
  auth_password = None
282
  if auth_header and auth_header.startswith('Bearer '):
283
  auth_password = auth_header[7:]
284
 
285
- # 检查任何来源的密码是否匹配
286
- if auth_password == PROXY_PASSWORD or password_header == PROXY_PASSWORD:
287
  return await handler(request)
288
  else:
289
  return web.json_response(
290
- {"error": "Unauthorized. Provide valid password in Authorization or XPassword header"},
291
  status=401
292
  )
293
- # 如果未设置密码,直接通过验证
294
- return await handler(request)
295
-
296
- def get_mapped_model(model_name):
297
- """将各种模型名称映射到支持的模型"""
298
- if not model_name:
299
- return "claude-3-5-sonnet-20241022" # 默认模型
300
-
301
- # 如果模型名称在映射表中,返回映射的值
302
- return MODEL_MAPPING.get(model_name, "claude-3-5-sonnet-20241022")
303
-
304
- def convert_openai_to_anthropic(data):
305
- """将OpenAI API请求格式转换为Anthropic格式"""
306
- # 如果已经是Anthropic格式,不做转换
307
- if "messages" in data:
308
- return data
309
-
310
- result = {}
311
-
312
- # 处理模型
313
- if "model" in data:
314
- result["model"] = get_mapped_model(data["model"])
315
- else:
316
- result["model"] = "claude-3-5-sonnet-20241022"
317
-
318
- # 处理系统消息和用户消息
319
- result["messages"] = []
320
- if "system" in data:
321
- result["system"] = data["system"]
322
-
323
- # 处理消息数组
324
- if "messages" in data:
325
- result["messages"] = data["messages"]
326
 
327
- # 复制其他参数
328
- for key in ["temperature", "top_p", "max_tokens", "stream"]:
329
- if key in data:
330
- result[key] = data[key]
331
-
332
- return result
333
 
334
  async def handle_models_request(request):
335
- """处理/v1/models端点的请求,返回支持的模型列表"""
336
  return web.json_response({
337
  "object": "list",
338
- "data": AVAILABLE_MODELS
 
 
 
 
 
 
 
 
339
  })
340
 
341
  async def handle_message_request(request):
342
  global llm_token
343
  if not llm_token:
344
- return web.json_response({"error": "LLM token not available"}, status=500)
345
  try:
346
  body = await request.json()
347
-
348
- # 转换格式(如果需要)
349
- body = convert_openai_to_anthropic(body)
350
-
351
- # 确保模型是支持的
352
- if "model" in body:
353
- body["model"] = get_mapped_model(body["model"])
354
-
355
- # Zed的Anthropic API对消息内容格式有特定要求
356
- if "messages" in body:
357
- for msg in body["messages"]:
358
- if "content" in msg:
359
- msg["content"] = format_content(msg["content"])
360
-
361
- # 处理系统消息格式
362
- if "system" in body:
363
- if isinstance(body["system"], list):
364
- body["system"] = "\n".join([item["text"] for item in body["system"]])
365
 
366
  headers = {"Content-Type": "application/json", "Authorization": f"Bearer {llm_token}"}
367
  payload = {
368
  "provider": "anthropic",
369
- "model": body.get("model", "claude-3-5-sonnet-20241022"),
370
  "provider_request": body
371
  }
372
-
373
-
374
  if body.get("stream", False):
375
- return await handle_streaming_request(request, headers, payload)
376
  else:
377
- return await handle_non_streaming_request(headers, payload)
378
  except Exception as e:
379
- logger.error(f"Error processing request: {e}")
380
  return web.json_response({"error": str(e)}, status=500)
381
 
382
- async def handle_non_streaming_request(headers, payload):
383
  async with aiohttp.ClientSession() as session:
384
- async with session.post(LLM_API_URL, headers=headers, json=payload) as r:
385
  if r.status != 200:
386
  text = await r.text()
387
- logger.error(f"LLM API error: {text}")
388
  return web.json_response({"error": text}, status=r.status)
389
  full_content, message_data = "", {}
390
  async for line in r.content:
391
  if not line:
392
  continue
393
  try:
394
- event = json.loads(line.decode('utf8').strip())
395
  et = event.get('type')
396
  if et == "message_start":
397
  message_data = event.get('message', {})
@@ -403,195 +391,108 @@ async def handle_non_streaming_request(headers, payload):
403
  break
404
  except Exception as e:
405
  logger.error(f"Error processing line: {e}")
 
 
 
 
 
406
 
407
- # 构建响应
408
- message_data['content'] = [{"type": "text", "text": full_content}]
409
-
410
- # 为OpenAI格式添加兼容字段 - 使用完整模型名称
411
- model_name = payload.get("provider_request", {}).get("model", "claude-3-5-sonnet-20241022")
412
- result = {
413
- "id": f"chatcmpl{int(time.time()*1000)}",
414
- "object": "chat.completion",
415
- "created": int(time.time()),
416
- "model": model_name,
417
- "choices": [
418
- {
419
- "index": 0,
420
- "message": {
421
- "role": "assistant",
422
- "content": full_content
423
- },
424
- "finish_reason": "stop"
425
- }
426
- ],
427
- }
428
-
429
- # 如果有usage信息,添加到结果中
430
- if "usage" in message_data:
431
- result["usage"] = message_data["usage"]
432
-
433
- return web.json_response(result)
434
-
435
- async def handle_streaming_request(request, headers, payload):
436
  response = web.StreamResponse()
437
  response.headers['Content-Type'] = 'text/event-stream'
438
  response.headers['Cache-Control'] = 'no-cache'
439
  response.headers['Connection'] = 'keep-alive'
440
  await response.prepare(request)
441
-
442
  async with aiohttp.ClientSession() as session:
443
- async with session.post(LLM_API_URL, headers=headers, json=payload, timeout=60) as api_response:
444
  if api_response.status != 200:
445
  error_text = await api_response.text()
446
- logger.error(f"LLM API (stream) error: {error_text}")
447
  await response.write(f"data: {json.dumps({'error': error_text})}\n\n".encode())
448
  await response.write(b"data: [DONE]\n\n")
449
  return response
450
-
451
- # 为流式响应创建唯一ID
452
- response_id = f"chatcmpl{int(time.time()*1000)}"
453
- content_buffer = ""
454
- # 使用完整模型名称
455
- model_name = payload.get("provider_request", {}).get("model", "claude-3-5-sonnet-20241022")
456
-
457
  async for line in api_response.content:
458
- if line:
459
- line_text = line.decode('utf8')
460
-
461
- # 直接传递原始事件
462
- await response.write(f"data: {line_text}\n\n".encode())
463
-
464
- # 解析事件以构建OpenAI兼容格式
465
- try:
466
- data = json.loads(line_text)
467
- if data.get('type') == 'content_block_delta' and data.get('delta', {}).get('type') == 'text_delta':
468
- delta_text = data['delta'].get('text', '')
469
- content_buffer += delta_text
470
-
471
- # 创建OpenAI兼容的流式响应格式
472
- openai_chunk = {
473
- "id": response_id,
474
- "object": "chat.completion.chunk",
475
- "created": int(time.time()),
476
- "model": model_name,
477
- "choices": [
478
- {
479
- "index": 0,
480
- "delta": {
481
- "content": delta_text
482
- },
483
- "finish_reason": None
484
- }
485
- ]
486
- }
487
- # 发送OpenAI兼容格式
488
- await response.write(f"data: {json.dumps(openai_chunk)}\n\n".encode())
489
-
490
- # 消息结束时发送完成标记
491
- elif data.get('type') == 'message_stop':
492
- final_chunk = {
493
- "id": response_id,
494
- "object": "chat.completion.chunk",
495
- "created": int(time.time()),
496
- "model": model_name,
497
- "choices": [
498
- {
499
- "index": 0,
500
- "delta": {},
501
- "finish_reason": "stop"
502
- }
503
- ]
504
- }
505
- await response.write(f"data: {json.dumps(final_chunk)}\n\n".encode())
506
- except json.JSONDecodeError:
507
- pass
508
-
509
  await response.write(b"data: [DONE]\n\n")
 
510
 
511
- return response
512
 
513
  async def start_proxy_server():
514
  global proxy_server_running
515
  if proxy_server_running:
516
- logger.info("Proxy server already running, skipping startup")
517
  return
518
-
519
  proxy_server_running = True
520
-
521
- # 创建带有中间件的应用
522
  app = web.Application(middlewares=[auth_middleware])
523
-
524
- # 添加API端点
525
  app.router.add_post('/v1/messages', handle_message_request)
526
- app.router.add_post('/v1/chat/completions', handle_message_request) # 添加OpenAI兼容端点
527
- app.router.add_get('/v1/models', handle_models_request) # 添加模型列表端点
528
-
529
- # 添加一个简单的健康检查端点
530
- async def health_check(request):
531
  return web.json_response({
532
  "status": "ok",
533
- "message": "Zed LLM proxy is running",
534
- "available_models": [model["id"] for model in AVAILABLE_MODELS]
535
  })
536
 
537
  app.router.add_get('/', health_check)
538
 
539
  runner = web.AppRunner(app)
540
  await runner.setup()
541
- site = web.TCPSite(runner, 'localhost', PROXY_PORT)
542
  await site.start()
543
-
544
- password_status = "启用" if PROXY_PASSWORD else "未启用"
545
- logger.info(f"代理服务器启动于 http://localhost:{PROXY_PORT}")
546
- logger.info(f"密码保护: {password_status}")
547
- logger.info(f"支持的模型: {', '.join([model['id'] for model in AVAILABLE_MODELS])}")
548
- logger.info("支持的端点:")
549
- logger.info(" /v1/messages (Anthropic API)")
550
- logger.info(" /v1/chat/completions (OpenAI兼容API)")
551
- logger.info(" /v1/models (模型列表)")
552
-
553
  while True:
554
  await asyncio.sleep(3600)
555
 
556
  def is_token_expiring():
557
  if not token_timestamp:
558
  return False
559
- return (time.time() - token_timestamp) / 60 >= TOKEN_EXPIRY_WARNING_MINUTES
560
 
561
  async def monitor_token_expiration():
562
  while True:
563
  await asyncio.sleep(60)
564
  if is_token_expiring():
565
  elapsed = int((time.time() - token_timestamp) / 60)
566
- logger.warning(f"LLM token is approaching expiration (received {elapsed} minutes ago)")
567
  if active_websocket is None:
568
- logger.info("Reconnecting WebSocket for token refresh")
569
  asyncio.create_task(reconnect_for_token_refresh())
570
  return
571
 
572
  async def reconnect_for_token_refresh():
573
  try:
574
- # 使用环境变量而不是从文件读取
575
- user_id = os.environ.get("ZED_USER_ID")
576
- auth_token = os.environ.get("ZED_AUTH_TOKEN").replace("\\", "")
577
-
578
- if not user_id or not auth_token:
579
- logger.error("环境变量ZED_USER_ID或ZED_AUTH_TOKEN未设置")
580
  return
581
-
582
- # 确保header格式一致
583
  headers = {
584
- "authorization": user_id + " " + auth_token,
585
  "x-zed-protocol-version": "68",
586
  "x-zed-app-version": "0.178.0",
587
  "x-zed-release-channel": "stable"
588
  }
589
- print(headers)
590
  ssl_context = ssl.create_default_context()
591
  ssl_context.check_hostname = False
592
  ssl_context.verify_mode = ssl.CERT_NONE
593
-
594
- async for websocket in connect(WS_URL, additional_headers=headers, ssl=ssl_context):
595
  try:
596
  ping_task = asyncio.create_task(ping_periodically(websocket))
597
  await asyncio.sleep(2)
@@ -601,7 +502,7 @@ async def reconnect_for_token_refresh():
601
  except ConnectionClosed:
602
  continue
603
  except Exception as e:
604
- logger.error(f"Error during token refresh: {e}")
605
  await asyncio.sleep(1)
606
  continue
607
  finally:
@@ -611,31 +512,23 @@ async def reconnect_for_token_refresh():
611
  except asyncio.CancelledError:
612
  pass
613
  except Exception as e:
614
- logger.error(f"Failed to reconnect for token refresh: {e}")
615
 
616
  async def async_main():
617
- # 从环境变量获取认证信息
618
- user_id = os.environ.get("ZED_USER_ID")
619
- auth_token = os.environ.get("ZED_AUTH_TOKEN").replace("\\", "")
620
-
621
- if not user_id or not auth_token:
622
- logger.error("环境变量ZED_USER_ID或ZED_AUTH_TOKEN未设置")
623
- logger.error("请先运行auth_script.py获取认证信息,然后设置环境变量")
624
  return
625
-
626
- # 确保header格式一致
627
  headers = {
628
- "authorization": user_id + " " + auth_token,
629
  "x-zed-protocol-version": "68",
630
  "x-zed-app-version": "0.178.0",
631
  "x-zed-release-channel": "stable"
632
  }
633
- print(headers)
634
  ssl_context = ssl.create_default_context()
635
  ssl_context.check_hostname = False
636
  ssl_context.verify_mode = ssl.CERT_NONE
637
- logger.info("Connecting to the WebSocket server")
638
- async for websocket in connect(WS_URL, additional_headers=headers, ssl=ssl_context):
639
  try:
640
  ping_task = asyncio.create_task(ping_periodically(websocket))
641
  token_request_task = asyncio.create_task(delayed_token_request(websocket, delay=2))
@@ -644,7 +537,7 @@ async def async_main():
644
  except ConnectionClosed:
645
  continue
646
  except Exception as e:
647
- logger.error(f"Unexpected error: {e}")
648
  await asyncio.sleep(1)
649
  continue
650
  finally:
@@ -666,25 +559,5 @@ async def delayed_token_request(websocket, delay=2):
666
  await asyncio.sleep(delay)
667
  await request_accept_terms_of_service(websocket)
668
 
669
- def main():
670
- logger.info("启动Zed LLM代理服务")
671
- logger.info(f"配置端口: {PROXY_PORT}")
672
- logger.info(f"密码保护: {'已启用' if PROXY_PASSWORD else '未启用'}")
673
-
674
- # 检查必要的环境变量
675
- if not os.environ.get("ZED_USER_ID") or not os.environ.get("ZED_AUTH_TOKEN"):
676
- logger.error("错误: 环境变量ZED_USER_ID和ZED_AUTH_TOKEN必须设置")
677
- logger.error("请先运行auth_script.py获取认证信息,然后设置环境变量")
678
- return
679
-
680
- try:
681
- asyncio.run(async_main())
682
- except KeyboardInterrupt:
683
- logger.info("收到中断信号,正在退出...")
684
- except Exception as e:
685
- logger.error(f"发生错误: {e}")
686
- finally:
687
- logger.info("代理服务已关闭")
688
-
689
  if __name__ == "__main__":
690
- main()
 
1
+ # Generated from trimmed zed.proto
2
  from google.protobuf import descriptor as _descriptor
3
  from google.protobuf import descriptor_pool as _descriptor_pool
4
  from google.protobuf import runtime_version as _runtime_version
 
22
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
23
  _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'zed_pb2', _globals)
24
  if not _descriptor._USE_C_DESCRIPTORS:
25
+ DESCRIPTOR._loaded_options = None
26
+ _globals['_ERRORCODE']._serialized_start=1085
27
+ _globals['_ERRORCODE']._serialized_end=1452
28
+ _globals['_PEERID']._serialized_start=27
29
+ _globals['_PEERID']._serialized_end=65
30
+ _globals['_ENVELOPE']._serialized_start=68
31
+ _globals['_ENVELOPE']._serialized_end=806
32
+ _globals['_HELLO']._serialized_start=808
33
+ _globals['_HELLO']._serialized_end=854
34
+ _globals['_PING']._serialized_start=856
35
+ _globals['_PING']._serialized_end=862
36
+ _globals['_ACK']._serialized_start=864
37
+ _globals['_ACK']._serialized_end=869
38
+ _globals['_ERROR']._serialized_start=871
39
+ _globals['_ERROR']._serialized_end=948
40
+ _globals['_ACCEPTTERMSOFSERVICE']._serialized_start=950
41
+ _globals['_ACCEPTTERMSOFSERVICE']._serialized_end=972
42
+ _globals['_ACCEPTTERMSOFSERVICERESPONSE']._serialized_start=974
43
+ _globals['_ACCEPTTERMSOFSERVICERESPONSE']._serialized_end=1029
44
+ _globals['_GETLLMTOKEN']._serialized_start=1031
45
+ _globals['_GETLLMTOKEN']._serialized_end=1044
46
+ _globals['_GETLLMTOKENRESPONSE']._serialized_start=1046
47
+ _globals['_GETLLMTOKENRESPONSE']._serialized_end=1082
48
+
49
+ # Start of the actual script
50
  import os
51
  import json
52
  import ssl
53
  import time
54
  import asyncio
55
+ import logging
56
  import aiohttp
 
 
 
57
  from aiohttp import web
58
  import zstandard as zstd
59
+ from dotenv import load_dotenv
60
  from websockets.asyncio.client import connect
61
  from websockets.exceptions import ConnectionClosed
62
+ import uuid
63
 
64
  from google.protobuf.json_format import MessageToDict
65
 
 
66
  Envelope = _sym_db.GetSymbol('zed.messages.Envelope')
67
 
68
# Minimal console logging ("LEVEL: message") for the whole proxy; every
# module-level helper below logs through this shared logger.
logging.basicConfig(
    level=logging.INFO,
    format='%(levelname)s: %(message)s'
)
logger = logging.getLogger(__name__)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
# Resolve the .env file sitting next to this script so the proxy can be
# launched from any working directory.
current_dir = os.path.dirname(os.path.abspath(__file__))
env_path = os.path.join(current_dir, '.env')

load_dotenv(env_path)

# Central runtime configuration.
#   API    - upstream Zed endpoints plus the proxy's own API key.
#   LOGIN  - Zed credentials taken from the environment (may be None).
#   SERVER - local listen port and token-expiry warning threshold (minutes).
#   MODELS - map of client-facing model ids -> upstream Anthropic model ids.
CONFIG = {
    "API": {
        "BASE_URL": "https://zed.dev",
        "API_KEY": os.getenv("API_KEY", "sk-123456"),
        "BASE_API_URL": "https://collab.zed.dev",
        "WS_URL": "wss://collab.zed.dev/rpc",
        "LLM_API_URL": "https://llm.zed.dev/completion",
    },
    "LOGIN": {
        "USER_ID": os.getenv("ZED_USER_ID"),
        "AUTH": os.getenv("ZED_AUTH_TOKEN")
    },
    "SERVER": {
        # Fix: os.getenv returns a *string* when PORT is set in the
        # environment, while the default was an int. Coerce so the value is
        # always an int suitable for binding a TCP port.
        "PORT": int(os.getenv("PORT", 7860)),
        "TOKEN_EXPIRY_WARNING_MINUTES": 50
    },
    "MODELS": {
        "claude-3-5-sonnet-20241022": "claude-3-5-sonnet-latest",
        "claude-3-7-sonnet-20250219": "claude-3-7-sonnet-20250219"
    }
}
100
 
101
  highest_message_id = 0
102
  llm_token = None
 
105
  active_websocket = None
106
  proxy_server_running = False
107
 
108
class MessageProcessor:
    """Builds OpenAI-compatible chat-completion response payloads."""

    @staticmethod
    def create_chat_response(message, model, is_stream=False):
        """Wrap *message* in an OpenAI-style response dict.

        When ``is_stream`` is True the result is shaped as a streaming
        chunk (``chat.completion.chunk`` with a ``delta``); otherwise as
        a complete ``chat.completion`` with a ``finish_reason``.
        """
        # Fields shared by both response shapes.
        response = {
            "id": f"chatcmpl-{uuid.uuid4()}",
            "created": int(time.time()),
            "model": model,
        }

        if is_stream:
            response["object"] = "chat.completion.chunk"
            response["choices"] = [{
                "index": 0,
                "delta": {"content": message},
            }]
            return response

        response["object"] = "chat.completion"
        response["choices"] = [{
            "index": 0,
            "message": {"role": "assistant", "content": message},
            "finish_reason": "stop",
        }]
        # Token accounting is not tracked by this proxy.
        response["usage"] = None
        return response
 
 
 
 
 
 
 
142
 
143
  def decode_envelope(data):
144
  try:
 
155
  return MessageToDict(envelope, preserving_proto_field_name=True)
156
  except Exception as e:
157
  hex_preview = ' '.join(f'{byte:02x}' for byte in data[:20]) + ('...' if len(data) > 20 else '')
158
+ logger.error(f"无法解码消息: {e}; 数据预览: {hex_preview}")
159
+ return {"error": f"无法解码消息: {e}"}
160
 
161
def compress_protobuf(data):
    """Zstandard-compress *data* at fast level -7 (favors speed over ratio)."""
    compressor = zstd.ZstdCompressor(level=-7)
    return compressor.compress(data)
163
 
164
  def create_message(message_type):
165
  global highest_message_id
 
176
  await websocket.ping()
177
  await asyncio.sleep(1)
178
  except Exception as e:
179
+ logger.error(f"发送ping错误: {e}")
180
  break
181
 
182
  async def handle_messages(websocket):
 
184
  active_websocket = websocket
185
  try:
186
  async for message in websocket:
187
+ message_bytes = message.encode('utf-8') if isinstance(message, str) else message
188
  decoded = decode_envelope(message_bytes)
189
  if "hello" in decoded:
190
  server_peer_id = decoded.get('hello', {}).get('peer_id')
191
  elif "accept_terms_of_service_response" in decoded:
192
  await request_llm_token(websocket)
193
+ elif ("get_llm_token_response" in decoded and
194
  'token' in decoded.get('get_llm_token_response', {})):
195
  llm_token = decoded['get_llm_token_response']['token']
196
  token_timestamp = time.time()
197
+ logger.info(f"LLM令牌收到 {time.ctime(token_timestamp)}")
198
  if not proxy_server_running:
199
  asyncio.create_task(start_proxy_server())
200
  asyncio.create_task(monitor_token_expiration())
201
+ logger.info("关闭WebSocket连接,直到需要刷新令牌")
202
  await websocket.close()
203
  active_websocket = None
204
  return
205
  except ConnectionClosed:
206
+ logger.info("连接已关闭")
207
  active_websocket = None
208
 
209
  async def request_llm_token(websocket):
210
  message, _ = create_message('get_llm_token')
211
+ logger.info("请求LLM令牌")
212
  await websocket.send(message)
213
 
214
  async def request_accept_terms_of_service(websocket):
215
  message, _ = create_message('accept_terms_of_service')
216
+ logger.info("发送同意Zed服务条款")
217
  await websocket.send(message)
218
 
219
  def format_content(content):
 
221
  return [{"type": "text", "text": content}]
222
  return content
223
 
224
+
225
+
226
async def process_message_content(content):
    """Normalize a message ``content`` field to a plain string.

    A string passes through unchanged; a list of content blocks is
    joined with newlines (blocks without ``text`` contribute ''); a
    single dict yields its ``text`` entry (or None). Any other type
    yields None.
    """
    if isinstance(content, str):
        return content

    if isinstance(content, list):
        parts = [block.get('text', '') for block in content]
        return '\n'.join(parts)

    if isinstance(content, dict):
        return content.get('text')

    return None
240
+
241
async def transform_messages(request):
    """
    Convert an OpenAI-style request body for the Anthropic backend:
    merge the leading run of system messages into a single system
    prompt, demote any later system messages to user role, coalesce
    consecutive same-role messages, and clamp sampling parameters.
    """
    system_message = ''  # accumulates the merged leading system prompt
    is_collecting_system_message = False  # inside the leading system run
    has_processed_system_messages = False  # leading system run already consumed

    converted_messages = []

    for current in request.get('messages', []):
        role = current.get('role')
        current_content = await process_message_content(current.get('content'))

        if current_content is None:
            # Content we cannot normalize to text is passed through untouched.
            converted_messages.append(current)
            continue

        if role == 'system' and not has_processed_system_messages:
            if not is_collecting_system_message:
                # First system message encountered: start collecting
                system_message = current_content
                is_collecting_system_message = True
            else:
                # Further system messages: merge into the collected prompt
                system_message += '\n' + current_content
            continue

        # A non-system message encountered
        if is_collecting_system_message:
            # Close out the leading system-message run
            is_collecting_system_message = False
            has_processed_system_messages = True

        # System messages after the initial run are converted to user role
        if has_processed_system_messages and role == 'system':
            role = 'user'

        # Coalesce consecutive messages that share a role.
        # NOTE(review): assumes the previous entry has the normalized
        # [{'type': 'text', 'text': ...}] content shape; a raw message
        # passed through on the content-None path above may not — confirm
        # upstream inputs never mix the two.
        if converted_messages and converted_messages[-1].get('role') == role:
            converted_messages[-1]['content'][0]['text'] += '\r\n' + current_content
        else:
            converted_messages.append({
                'role': role,
                'content': [{'type': 'text', 'text': current_content}]
            })

    return {
        'messages': converted_messages,
        'system': system_message,
        # Unknown model ids fall back to the latest 3.5 Sonnet alias
        'model': CONFIG['MODELS'].get(request.get('model'), "claude-3-5-sonnet-latest"),
        'max_tokens': request.get('max_tokens',8192),
        # Clamp sampling params to Anthropic's accepted ranges
        'temperature': max(0, min(request.get('temperature', 0), 1)),
        'top_p': max(0, min(request.get('top_p', 1), 1)),
        'top_k': max(0, min(request.get('top_k', 0), 500)),
        'stream': request.get('stream', False)
    }
298
+
299
@web.middleware
async def auth_middleware(request, handler):
    # Gate every request behind the configured API key; when no key is
    # configured the middleware passes everything through.
    if CONFIG['API']['API_KEY']:
        auth_header = request.headers.get('Authorization')
        # Anthropic-style clients send the key in "x-api-key"
        xapi_key_header = request.headers.get('x-api-key')

        # OpenAI-style clients send "Authorization: Bearer <key>"
        auth_password = None
        if auth_header and auth_header.startswith('Bearer '):
            auth_password = auth_header[7:]

        # Accept the key from either header
        if auth_password == CONFIG['API']['API_KEY'] or xapi_key_header == CONFIG['API']['API_KEY']:
            return await handler(request)
        else:
            return web.json_response(
                {"error": "Unauthorized"},
                status=401
            )

    # No API key configured: authentication disabled
    return await handler(request)
 
 
 
 
 
318
 
319
async def handle_models_request(request):
    """Serve the OpenAI-compatible /v1/models listing built from CONFIG."""
    model_entries = []
    for model_id in CONFIG["MODELS"]:
        model_entries.append({
            "id": model_id,
            "object": "model",
            "created": int(time.time()),
            "owned_by": "zed",
        })
    return web.json_response({
        "object": "list",
        "data": model_entries,
    })
332
 
333
async def handle_message_request(request):
    """Proxy a chat request (Anthropic or OpenAI style) to the Zed LLM API."""
    global llm_token
    if not llm_token:
        return web.json_response({"error": "LLM令牌不可用"}, status=500)
    try:
        body = await request.json()
        is_anthropic_route = request.path == '/v1/messages'
        if is_anthropic_route:
            # Native Anthropic request: normalise message content blocks,
            # flatten a list-style system prompt, and map the model alias.
            for message in body.get("messages", []):
                if "content" in message:
                    message["content"] = format_content(message["content"])
            if "system" in body and isinstance(body["system"], list):
                body["system"] = "\n".join(part["text"] for part in body["system"])
            if "model" in body:
                body["model"] = CONFIG['MODELS'].get(body["model"], "claude-3-5-sonnet-latest")
        else:
            # OpenAI-style request: convert to the Anthropic wire format.
            body = await transform_messages(body)

        request_headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {llm_token}",
        }
        payload = {
            "provider": "anthropic",
            "model": body.get("model", "claude-3-5-sonnet-latest"),
            "provider_request": body,
        }
        if body.get("stream", False):
            return await handle_streaming_request(request, request_headers, payload, is_anthropic_route)
        return await handle_non_streaming_request(request_headers, payload, is_anthropic_route)
    except Exception as e:
        logger.error(f"处理请求时发生错误: {e}")
        return web.json_response({"error": str(e)}, status=500)
369
 
370
# handle_non_streaming_request: POSTs `payload` to the Zed LLM endpoint and
# folds the streamed Anthropic events into a single JSON response.
# NOTE(review): this is a diff fragment — the hunk omits original lines
# 386-390, so the code that accumulates `full_content` from the events
# (presumably on content_block_delta) is not visible here; verify upstream.
+ async def handle_non_streaming_request(headers, payload, isClaudeAI=False):
371
  async with aiohttp.ClientSession() as session:
372
+ async with session.post(CONFIG['API']['LLM_API_URL'], headers=headers, json=payload) as r:
373
  if r.status != 200:
374
  text = await r.text()
375
+ logger.error(f"LLM API错误: {text}")
376
  return web.json_response({"error": text}, status=r.status)
377
  full_content, message_data = "", {}
378
  async for line in r.content:
379
  if not line:
380
  continue
381
  try:
382
+ event = json.loads(line.decode('utf-8').strip())
383
  et = event.get('type')
384
  if et == "message_start":
385
  message_data = event.get('message', {})

391
  break
392
  except Exception as e:
393
  logger.error(f"Error processing line: {e}")
# isClaudeAI=True -> keep the Anthropic message shape; otherwise wrap the
# accumulated text in an OpenAI-style chat completion via MessageProcessor.
394
+ if isClaudeAI:
395
+ message_data['content'] = [{"type": "text", "text": full_content}]
396
+ else:
397
+ message_data = MessageProcessor.create_chat_response(full_content, payload.get("model"), False)
398
+ return web.json_response(message_data)
399
 
400
async def handle_streaming_request(request, headers, payload, isClaudeAI=False):
    """Proxy a streaming LLM request to the client as server-sent events.

    Forwards ``payload`` to the Zed LLM endpoint and relays the response
    chunk by chunk.  With ``isClaudeAI`` the upstream Anthropic event lines
    are passed through verbatim; otherwise each ``text_delta`` event is
    re-wrapped as an OpenAI-style streaming chat chunk.

    Fix: removed a leftover ``print(data)`` debug statement that dumped
    every upstream chunk to stdout on the non-Claude path.
    """
    response = web.StreamResponse()
    response.headers['Content-Type'] = 'text/event-stream'
    response.headers['Cache-Control'] = 'no-cache'
    response.headers['Connection'] = 'keep-alive'
    await response.prepare(request)
    logger.info(f"开始处理流请求")
    async with aiohttp.ClientSession() as session:
        async with session.post(CONFIG['API']['LLM_API_URL'], headers=headers, json=payload) as api_response:
            if api_response.status != 200:
                # Surface the upstream error to the SSE client, then end the
                # stream with the conventional [DONE] sentinel.
                error_text = await api_response.text()
                logger.error(f"LLM API (stream)错误: {error_text}")
                await response.write(f"data: {json.dumps({'error': error_text})}\n\n".encode())
                await response.write(b"data: [DONE]\n\n")
                return response
            async for line in api_response.content:
                try:
                    if not line:
                        continue
                    if isClaudeAI:
                        # Anthropic clients receive the raw upstream event line.
                        await response.write(f"data: {line.decode('utf-8')}\n\n".encode())
                    else:
                        try:
                            data = json.loads(line.decode('utf-8').strip())
                            if data.get('type') == "content_block_delta" and data.get('delta', {}).get('type') == "text_delta":
                                text = data['delta'].get('text', '')
                                message = MessageProcessor.create_chat_response(text, payload.get("model"), True)
                                await response.write(f"data: {json.dumps(message)}\n\n".encode())
                        except Exception as e:
                            # A malformed upstream line must not kill the stream.
                            logger.error(f"Error processing line: {e}")
                except Exception as e:
                    logger.error(f"Error processing line: {e}")
    await response.write(b"data: [DONE]\n\n")
    return response
434
 
 
435
 
436
async def start_proxy_server():
    """Start the aiohttp proxy server (idempotent) and keep it alive.

    Routes:
      POST /v1/messages          Anthropic-style chat endpoint
      POST /v1/chat/completions  OpenAI-style chat endpoint
      GET  /v1/models            model listing
      GET  /                     health check

    Fix: ``health_check`` now accepts the ``request`` argument that aiohttp
    passes to every handler; without it the ``/`` endpoint raised a
    TypeError on each request.
    """
    global proxy_server_running
    if proxy_server_running:
        logger.info("代理服务器已运行,跳过启动")
        return

    proxy_server_running = True

    app = web.Application(middlewares=[auth_middleware])

    app.router.add_post('/v1/messages', handle_message_request)
    app.router.add_post('/v1/chat/completions', handle_message_request)
    app.router.add_get('/v1/models', handle_models_request)

    async def health_check(request):
        # aiohttp invokes GET handlers with the request object.
        return web.json_response({
            "status": "ok",
            "message": "Zed LLM proxy is running"
        })

    app.router.add_get('/', health_check)

    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, '0.0.0.0', CONFIG['SERVER']['PORT'])
    await site.start()
    logger.info(f"代理服务器启动 http://0.0.0.0:{CONFIG['SERVER']['PORT']}")

    # Keep the coroutine alive so the runner is not garbage collected.
    while True:
        await asyncio.sleep(3600)
463
 
464
def is_token_expiring():
    """Return True when the stored LLM token is older than the warning threshold (minutes)."""
    if not token_timestamp:
        # No token received yet: nothing to expire.
        return False
    age_minutes = (time.time() - token_timestamp) / 60
    return age_minutes >= CONFIG['SERVER']['TOKEN_EXPIRY_WARNING_MINUTES']
468
 
469
async def monitor_token_expiration():
    """Poll once a minute and schedule a token refresh when expiry nears.

    Exits after a refresh task has been scheduled (only when no websocket
    is currently active).
    """
    while True:
        await asyncio.sleep(60)
        if not is_token_expiring():
            continue
        minutes_old = int((time.time() - token_timestamp) / 60)
        logger.warning(f"LLM令牌接近过期 (收到 {minutes_old} 分钟前)")
        if active_websocket is None:
            logger.info("重新连接WebSocket以刷新令牌")
            asyncio.create_task(reconnect_for_token_refresh())
            return
479
 
480
# reconnect_for_token_refresh: re-opens the Zed websocket with the stored
# credentials to obtain a fresh LLM token.
# NOTE(review): this is a diff fragment — the hunk omits original lines
# 499-501 (presumably the token request after the 2s sleep) and 509-511
# (the finally-block cleanup); verify against the full file.
  async def reconnect_for_token_refresh():
481
  try:
482
+ if not CONFIG['LOGIN']['USER_ID'] or not CONFIG['LOGIN']['AUTH']:
483
+ logger.error("用户ID或授权令牌未设置")




484
  return


485
  headers = {
486
+ "authorization": f"{CONFIG['LOGIN']['USER_ID']} {CONFIG['LOGIN']['AUTH']}",
487
  "x-zed-protocol-version": "68",
488
  "x-zed-app-version": "0.178.0",
489
  "x-zed-release-channel": "stable"
490
  }

491
# NOTE(review): TLS verification is disabled below (check_hostname=False,
# CERT_NONE) — the client accepts any server certificate.
  ssl_context = ssl.create_default_context()
492
  ssl_context.check_hostname = False
493
  ssl_context.verify_mode = ssl.CERT_NONE
494
+
495
# `connect` yields a new websocket on each reconnect attempt.
+ async for websocket in connect(CONFIG['API']['WS_URL'], additional_headers=headers, ssl=ssl_context):
496
  try:
497
  ping_task = asyncio.create_task(ping_periodically(websocket))
498
  await asyncio.sleep(2)

502
  except ConnectionClosed:
503
  continue
504
  except Exception as e:
505
+ logger.error(f"令牌刷新期间发生错误: {e}")
506
  await asyncio.sleep(1)
507
  continue
508
  finally:

512
  except asyncio.CancelledError:
513
  pass
514
  except Exception as e:
515
+ logger.error(f"令牌刷新失败: {e}")
516
 
517
# async_main: primary websocket client loop — authenticates to the Zed
# server, keeps the connection alive with pings, and schedules the delayed
# LLM-token request.
# NOTE(review): this is a diff fragment — the hunk omits original lines
# 535-536 (inside the try, after the tasks are created) and the body of the
# trailing `finally:`; verify against the full file.
  async def async_main():
518
+ if not CONFIG['LOGIN']['USER_ID'] or not CONFIG['LOGIN']['AUTH']:
519
+ logger.error("用户ID或授权令牌未设置")




520
  return


521
  headers = {
522
+ "authorization": f"{CONFIG['LOGIN']['USER_ID']} {CONFIG['LOGIN']['AUTH']}",
523
  "x-zed-protocol-version": "68",
524
  "x-zed-app-version": "0.178.0",
525
  "x-zed-release-channel": "stable"
526
  }

527
# NOTE(review): TLS verification is disabled below (check_hostname=False,
# CERT_NONE) — the client accepts any server certificate.
  ssl_context = ssl.create_default_context()
528
  ssl_context.check_hostname = False
529
  ssl_context.verify_mode = ssl.CERT_NONE
530
+ logger.info("连接到Websocket服务器")
531
# `connect` yields a new websocket on each reconnect attempt.
+ async for websocket in connect(CONFIG['API']['WS_URL'], additional_headers=headers, ssl=ssl_context):
532
  try:
533
  ping_task = asyncio.create_task(ping_periodically(websocket))
534
  token_request_task = asyncio.create_task(delayed_token_request(websocket, delay=2))

537
  except ConnectionClosed:
538
  continue
539
  except Exception as e:
540
+ logger.error(f"意外错误: {e}")
541
  await asyncio.sleep(1)
542
  continue
543
  finally:
 
559
  await asyncio.sleep(delay)
560
  await request_accept_terms_of_service(websocket)
561
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
562
# Script entry point: run the async websocket/proxy main loop.
  if __name__ == "__main__":
563
+ asyncio.run(async_main())