letterm commited on
Commit
6133e11
·
verified ·
1 Parent(s): cc65e0c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +431 -277
app.py CHANGED
@@ -1,4 +1,3 @@
1
- # Generated from trimmed zed.proto
2
  from google.protobuf import descriptor as _descriptor
3
  from google.protobuf import descriptor_pool as _descriptor_pool
4
  from google.protobuf import runtime_version as _runtime_version
@@ -22,77 +21,126 @@ _globals = globals()
22
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
23
  _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'zed_pb2', _globals)
24
  if not _descriptor._USE_C_DESCRIPTORS:
25
- DESCRIPTOR._loaded_options = None
26
- _globals['_ERRORCODE']._serialized_start=1085
27
- _globals['_ERRORCODE']._serialized_end=1452
28
- _globals['_PEERID']._serialized_start=27
29
- _globals['_PEERID']._serialized_end=65
30
- _globals['_ENVELOPE']._serialized_start=68
31
- _globals['_ENVELOPE']._serialized_end=806
32
- _globals['_HELLO']._serialized_start=808
33
- _globals['_HELLO']._serialized_end=854
34
- _globals['_PING']._serialized_start=856
35
- _globals['_PING']._serialized_end=862
36
- _globals['_ACK']._serialized_start=864
37
- _globals['_ACK']._serialized_end=869
38
- _globals['_ERROR']._serialized_start=871
39
- _globals['_ERROR']._serialized_end=948
40
- _globals['_ACCEPTTERMSOFSERVICE']._serialized_start=950
41
- _globals['_ACCEPTTERMSOFSERVICE']._serialized_end=972
42
- _globals['_ACCEPTTERMSOFSERVICERESPONSE']._serialized_start=974
43
- _globals['_ACCEPTTERMSOFSERVICERESPONSE']._serialized_end=1029
44
- _globals['_GETLLMTOKEN']._serialized_start=1031
45
- _globals['_GETLLMTOKEN']._serialized_end=1044
46
- _globals['_GETLLMTOKENRESPONSE']._serialized_start=1046
47
- _globals['_GETLLMTOKENRESPONSE']._serialized_end=1082
48
-
49
- # Start of the actual script
50
  import os
51
  import json
52
  import ssl
53
  import time
54
  import asyncio
55
- import logging
56
  import aiohttp
 
 
 
 
 
 
57
  from aiohttp import web
58
  import zstandard as zstd
59
  from websockets.asyncio.client import connect
60
  from websockets.exceptions import ConnectionClosed
61
- import uuid
62
 
63
  from google.protobuf.json_format import MessageToDict
64
 
 
65
  Envelope = _sym_db.GetSymbol('zed.messages.Envelope')
66
 
67
- logging.basicConfig(
68
- level=logging.INFO,
69
- format='%(levelname)s: %(message)s'
70
- )
71
- logger = logging.getLogger(__name__)
72
 
 
 
 
 
 
 
 
73
 
 
 
 
 
 
 
 
 
74
 
75
- CONFIG = {
76
- "API":{
77
- "BASE_URL": "https://zed.dev",
78
- "API_KEY": os.getenv("API_KEY","sk-123456"),
79
- "BASE_API_URL": "https://collab.zed.dev",
80
- "WS_URL": "wss://collab.zed.dev/rpc",
81
- "LLM_API_URL": "https://llm.zed.dev/completion",
82
- },
83
- "LOGIN":{
84
- "USER_ID": os.getenv("ZED_USER_ID"),
85
- "AUTH": os.getenv("ZED_AUTH_TOKEN")
86
- },
87
- "SERVER":{
88
- "PORT": os.getenv("PORT",5200),
89
- "TOKEN_EXPIRY_WARNING_MINUTES": 50
90
- },
91
- "MODELS":{
92
- "claude-3-5-sonnet-20241022":"claude-3-5-sonnet-latest",
93
- "claude-3-7-sonnet-20250219":"claude-3-7-sonnet-20250219"
94
- }
95
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  highest_message_id = 0
98
  llm_token = None
@@ -101,40 +149,47 @@ server_peer_id = None
101
  active_websocket = None
102
  proxy_server_running = False
103
 
104
- class MessageProcessor:
105
- @staticmethod
106
- def create_chat_response(message, model, is_stream=False):
107
- base_response = {
108
- "id": f"chatcmpl-{uuid.uuid4()}",
109
- "created": int(time.time()),
110
- "model": model
111
- }
112
-
113
- if is_stream:
114
- return {
115
- **base_response,
116
- "object": "chat.completion.chunk",
117
- "choices": [{
118
- "index": 0,
119
- "delta": {
120
- "content": message
121
- }
122
- }]
123
- }
124
-
125
- return {
126
- **base_response,
127
- "object": "chat.completion",
128
- "choices": [{
129
- "index": 0,
130
- "message": {
131
- "role": "assistant",
132
- "content": message
133
- },
134
- "finish_reason": "stop"
135
- }],
136
- "usage": None
137
- }
 
 
 
 
 
 
 
138
 
139
  def decode_envelope(data):
140
  try:
@@ -151,11 +206,11 @@ def decode_envelope(data):
151
  return MessageToDict(envelope, preserving_proto_field_name=True)
152
  except Exception as e:
153
  hex_preview = ' '.join(f'{byte:02x}' for byte in data[:20]) + ('...' if len(data) > 20 else '')
154
- logger.error(f"无法解码消息: {e}; 数据预览: {hex_preview}")
155
- return {"error": f"无法解码消息: {e}"}
156
 
157
  def compress_protobuf(data):
158
- return zstd.ZstdCompressor(level=-7).compress(data)
159
 
160
  def create_message(message_type):
161
  global highest_message_id
@@ -172,7 +227,7 @@ async def ping_periodically(websocket):
172
  await websocket.ping()
173
  await asyncio.sleep(1)
174
  except Exception as e:
175
- logger.error(f"发送ping错误: {e}")
176
  break
177
 
178
  async def handle_messages(websocket):
@@ -180,36 +235,36 @@ async def handle_messages(websocket):
180
  active_websocket = websocket
181
  try:
182
  async for message in websocket:
183
- message_bytes = message.encode('utf-8') if isinstance(message, str) else message
184
  decoded = decode_envelope(message_bytes)
185
  if "hello" in decoded:
186
  server_peer_id = decoded.get('hello', {}).get('peer_id')
187
  elif "accept_terms_of_service_response" in decoded:
188
  await request_llm_token(websocket)
189
- elif ("get_llm_token_response" in decoded and
190
  'token' in decoded.get('get_llm_token_response', {})):
191
  llm_token = decoded['get_llm_token_response']['token']
192
  token_timestamp = time.time()
193
- logger.info(f"LLM令牌收到 {time.ctime(token_timestamp)}")
194
  if not proxy_server_running:
195
  asyncio.create_task(start_proxy_server())
196
  asyncio.create_task(monitor_token_expiration())
197
- logger.info("关闭WebSocket连接,直到需要刷新令牌")
198
  await websocket.close()
199
  active_websocket = None
200
  return
201
  except ConnectionClosed:
202
- logger.info("连接已关闭")
203
  active_websocket = None
204
 
205
  async def request_llm_token(websocket):
206
  message, _ = create_message('get_llm_token')
207
- logger.info("请求LLM令牌")
208
  await websocket.send(message)
209
 
210
  async def request_accept_terms_of_service(websocket):
211
  message, _ = create_message('accept_terms_of_service')
212
- logger.info("发送同意Zed服务条款")
213
  await websocket.send(message)
214
 
215
  def format_content(content):
@@ -217,168 +272,129 @@ def format_content(content):
217
  return [{"type": "text", "text": content}]
218
  return content
219
 
220
-
221
-
222
- async def process_message_content(content):
223
- """
224
- 处理消息内容,将不同类型的内容转换为字符串
225
- """
226
- if isinstance(content, str):
227
- return content
228
-
229
- if isinstance(content, list):
230
- return '\n'.join([item.get('text', '') for item in content])
231
-
232
- if isinstance(content, dict):
233
- return content.get('text', None)
234
-
235
- return None
236
-
237
- async def transform_messages(request):
238
- """
239
- 转换消息格式,合并系统消息并处理消息结构
240
- """
241
- system_message = '' # 存储系统消息的变量
242
- is_collecting_system_message = False # 是否正在收集系统消息
243
- has_processed_system_messages = False # 是否已处理初始系统消息
244
-
245
- converted_messages = []
246
-
247
- for current in request.get('messages', []):
248
- role = current.get('role')
249
- current_content = await process_message_content(current.get('content'))
250
-
251
- if current_content is None:
252
- converted_messages.append(current)
253
- continue
254
-
255
- if role == 'system' and not has_processed_system_messages:
256
- if not is_collecting_system_message:
257
- # 第一次遇到system,开启收集
258
- system_message = current_content
259
- is_collecting_system_message = True
260
- else:
261
- # 继续遇到system,合并system消息
262
- system_message += '\n' + current_content
263
- continue
264
-
265
- # 遇到非system消息
266
- if is_collecting_system_message:
267
- # 结束系统消息收集
268
- is_collecting_system_message = False
269
- has_processed_system_messages = True
270
-
271
- # 如果已处理初始消息序列且再次遇到system,则转换role为user
272
- if has_processed_system_messages and role == 'system':
273
- role = 'user'
274
-
275
- # 检查是否可以合并消息
276
- if converted_messages and converted_messages[-1].get('role') == role:
277
- converted_messages[-1]['content'][0]['text'] += '\r\n' + current_content
278
- else:
279
- converted_messages.append({
280
- 'role': role,
281
- 'content': [{'type': 'text', 'text': current_content}]
282
- })
283
-
284
- return {
285
- 'messages': converted_messages,
286
- 'system': system_message,
287
- 'model': CONFIG['MODELS'].get(request.get('model'), "claude-3-5-sonnet-latest"),
288
- 'max_tokens': request.get('max_tokens',8192),
289
- 'temperature': max(0, min(request.get('temperature', 0), 1)),
290
- 'top_p': max(0, min(request.get('top_p', 1), 1)),
291
- 'top_k': max(0, min(request.get('top_k', 0), 500)),
292
- 'stream': True
293
- }
294
-
295
  @web.middleware
296
  async def auth_middleware(request, handler):
297
- if CONFIG['API']['API_KEY']:
 
298
  auth_header = request.headers.get('Authorization')
299
- xapi_key_header = request.headers.get('x-api-key')
300
 
 
301
  auth_password = None
302
  if auth_header and auth_header.startswith('Bearer '):
303
  auth_password = auth_header[7:]
304
 
305
- if auth_password == CONFIG['API']['API_KEY'] or xapi_key_header == CONFIG['API']['API_KEY']:
 
306
  return await handler(request)
307
  else:
308
  return web.json_response(
309
- {"error": "Unauthorized"},
310
  status=401
311
  )
312
-
313
  return await handler(request)
314
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
315
  async def handle_models_request(request):
 
316
  return web.json_response({
317
  "object": "list",
318
- "data": [
319
- {
320
- "id": model,
321
- "object": "model",
322
- "created": int(time.time()),
323
- "owned_by": "zed"
324
- }
325
- for model in CONFIG["MODELS"].keys()
326
- ]
327
  })
328
 
329
  async def handle_message_request(request):
330
  global llm_token
331
  if not llm_token:
332
- return web.json_response({"error": "LLM令牌不可用"}, status=500)
333
  try:
334
  body = await request.json()
335
- isClaudeAI = False
336
- if request.path == '/v1/messages':
337
- isClaudeAI = True
338
- if "messages" in body:
339
- for msg in body["messages"]:
340
- if "content" in msg:
341
- msg["content"] = format_content(msg["content"])
342
- if "system" in body:
343
- if isinstance(body["system"], list):
344
- body["system"] = "\n".join([item["text"] for item in body["system"]])
345
- if "model" in body:
346
- body["model"] = CONFIG['MODELS'].get(body["model"], "claude-3-5-sonnet-latest")
347
- else:
348
- body = await transform_messages(body)
349
- with open('request_payload222.json', 'w', encoding='utf-8') as f:
350
- json.dump(body, f, ensure_ascii=False, indent=2)
 
 
 
351
  headers = {"Content-Type": "application/json", "Authorization": f"Bearer {llm_token}"}
352
- with open('ceshi.txt', 'w', encoding='utf-8') as f:
353
- f.write(llm_token + '\n')
354
  payload = {
355
  "provider": "anthropic",
356
- "model": body.get("model", "claude-3-5-sonnet-latest"),
357
  "provider_request": body
358
  }
359
- # with open('ceshi.txt', 'w', encoding='utf-8') as f:
360
- # f.write(json.dumps(body,ensure_ascii=False) + '\n')
361
  if body.get("stream", False):
362
- return await handle_streaming_request(request, headers, payload, isClaudeAI)
363
  else:
364
- return await handle_non_streaming_request(headers, payload, isClaudeAI)
365
  except Exception as e:
366
- logger.error(f"处理请求时发生错误: {e}")
367
  return web.json_response({"error": str(e)}, status=500)
368
 
369
- async def handle_non_streaming_request(headers, payload, isClaudeAI=False):
370
  async with aiohttp.ClientSession() as session:
371
- async with session.post(CONFIG['API']['LLM_API_URL'], headers=headers, json=payload) as r:
372
  if r.status != 200:
373
  text = await r.text()
374
- logger.error(f"LLM API错误: {text}")
375
  return web.json_response({"error": text}, status=r.status)
376
  full_content, message_data = "", {}
377
  async for line in r.content:
378
  if not line:
379
  continue
380
  try:
381
- event = json.loads(line.decode('utf-8').strip())
382
  et = event.get('type')
383
  if et == "message_start":
384
  message_data = event.get('message', {})
@@ -390,107 +406,195 @@ async def handle_non_streaming_request(headers, payload, isClaudeAI=False):
390
  break
391
  except Exception as e:
392
  logger.error(f"Error processing line: {e}")
393
- if isClaudeAI:
394
- message_data['content'] = [{"type": "text", "text": full_content}]
395
- else:
396
- message_data = MessageProcessor.create_chat_response(full_content, payload.get("model"), False)
397
- return web.json_response(message_data)
398
 
399
- async def handle_streaming_request(request, headers, payload, isClaudeAI=False):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
400
  response = web.StreamResponse()
401
  response.headers['Content-Type'] = 'text/event-stream'
402
  response.headers['Cache-Control'] = 'no-cache'
403
  response.headers['Connection'] = 'keep-alive'
404
  await response.prepare(request)
405
- logger.info(f"开始处理流请求")
406
  async with aiohttp.ClientSession() as session:
407
- async with session.post(CONFIG['API']['LLM_API_URL'], headers=headers, json=payload) as api_response:
408
  if api_response.status != 200:
409
  error_text = await api_response.text()
410
- logger.error(f"LLM API (stream)错误: {error_text}")
411
  await response.write(f"data: {json.dumps({'error': error_text})}\n\n".encode())
412
  await response.write(b"data: [DONE]\n\n")
413
  return response
 
 
 
 
 
 
 
414
  async for line in api_response.content:
415
- try:
416
- if line:
417
- if isClaudeAI:
418
- await response.write(f"data: {line.decode('utf-8')}\n\n".encode())
419
- else:
420
- try:
421
- data = json.loads(line.decode('utf-8').strip())
422
- if data.get('type') == "content_block_delta" and data.get('delta', {}).get('type') == "text_delta":
423
- text = data['delta'].get('text', '')
424
- message = MessageProcessor.create_chat_response(text, payload.get("model"), True)
425
- await response.write(f"data: {json.dumps(message)}\n\n".encode())
426
- except Exception as e:
427
- logger.error(f"Error processing line: {e}")
428
- except Exception as e:
429
- logger.error(f"Error processing line: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
430
  await response.write(b"data: [DONE]\n\n")
431
- return response
432
 
 
433
 
434
  async def start_proxy_server():
435
  global proxy_server_running
436
  if proxy_server_running:
437
- logger.info("代理服务器已运行,跳过启动")
438
  return
439
-
440
  proxy_server_running = True
 
 
441
  app = web.Application(middlewares=[auth_middleware])
 
 
442
  app.router.add_post('/v1/messages', handle_message_request)
443
- app.router.add_post('/v1/chat/completions', handle_message_request)
444
- app.router.add_get('/v1/models', handle_models_request)
445
-
446
- async def health_check():
 
447
  return web.json_response({
448
  "status": "ok",
449
- "message": "Zed LLM proxy is running"
 
450
  })
451
 
452
  app.router.add_get('/', health_check)
453
 
454
  runner = web.AppRunner(app)
455
  await runner.setup()
456
- site = web.TCPSite(runner, 'localhost', CONFIG['SERVER']['PORT'])
457
  await site.start()
458
- logger.info(f"代理服务器启动 http://localhost:{CONFIG['SERVER']['PORT']}")
 
 
 
 
 
 
 
 
 
459
  while True:
460
  await asyncio.sleep(3600)
461
 
462
  def is_token_expiring():
463
  if not token_timestamp:
464
  return False
465
- return (time.time() - token_timestamp) / 60 >= CONFIG['SERVER']['TOKEN_EXPIRY_WARNING_MINUTES']
466
 
467
  async def monitor_token_expiration():
468
  while True:
469
  await asyncio.sleep(60)
470
  if is_token_expiring():
471
  elapsed = int((time.time() - token_timestamp) / 60)
472
- logger.warning(f"LLM令牌接近过期 (收到 {elapsed} 分钟前)")
473
  if active_websocket is None:
474
- logger.info("重新连接WebSocket以刷新令牌")
475
  asyncio.create_task(reconnect_for_token_refresh())
476
  return
477
 
478
  async def reconnect_for_token_refresh():
479
  try:
480
- if not CONFIG['LOGIN']['USER_ID'] or not CONFIG['LOGIN']['AUTH']:
481
- logger.error("用户ID或授权令牌未设置")
 
 
 
 
482
  return
 
 
483
  headers = {
484
- "authorization": f"{CONFIG['LOGIN']['USER_ID']} {CONFIG['LOGIN']['AUTH']}",
485
  "x-zed-protocol-version": "68",
486
  "x-zed-app-version": "0.178.0",
487
  "x-zed-release-channel": "stable"
488
  }
 
489
  ssl_context = ssl.create_default_context()
490
  ssl_context.check_hostname = False
491
  ssl_context.verify_mode = ssl.CERT_NONE
492
-
493
- async for websocket in connect(CONFIG['API']['WS_URL'], additional_headers=headers, ssl=ssl_context):
494
  try:
495
  ping_task = asyncio.create_task(ping_periodically(websocket))
496
  await asyncio.sleep(2)
@@ -500,7 +604,7 @@ async def reconnect_for_token_refresh():
500
  except ConnectionClosed:
501
  continue
502
  except Exception as e:
503
- logger.error(f"令牌刷新期间发生错误: {e}")
504
  await asyncio.sleep(1)
505
  continue
506
  finally:
@@ -510,30 +614,80 @@ async def reconnect_for_token_refresh():
510
  except asyncio.CancelledError:
511
  pass
512
  except Exception as e:
513
- logger.error(f"令牌刷新失败: {e}")
514
 
515
  async def async_main():
516
- app = web.Application()
517
- app.router.add_post('/v1/messages', handle_message_request)
518
- app.router.add_post('/v1/chat/completions', handle_message_request)
519
- app.router.add_get('/v1/models', handle_models_request)
520
-
521
- async def health_check():
522
- return web.json_response({
523
- "status": "ok",
524
- "message": "Zed LLM proxy is running"
525
- })
526
 
527
- app.router.add_get('/', health_check)
 
 
 
528
 
529
- runner = web.AppRunner(app)
530
- await runner.setup()
531
- site = web.TCPSite(runner, 'localhost', CONFIG['SERVER']['PORT'])
532
- await site.start()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
533
 
534
  async def delayed_token_request(websocket, delay=2):
535
  await asyncio.sleep(delay)
536
  await request_accept_terms_of_service(websocket)
537
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
538
  if __name__ == "__main__":
539
- asyncio.run(async_main())
 
 
1
  from google.protobuf import descriptor as _descriptor
2
  from google.protobuf import descriptor_pool as _descriptor_pool
3
  from google.protobuf import runtime_version as _runtime_version
 
21
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
22
  _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'zed_pb2', _globals)
23
  if not _descriptor._USE_C_DESCRIPTORS:
24
+ DESCRIPTOR._loaded_options = None
25
+ _globals['_ERRORCODE']._serialized_start=1085
26
+ _globals['_ERRORCODE']._serialized_end=1452
27
+ _globals['_PEERID']._serialized_start=27
28
+ _globals['_PEERID']._serialized_end=65
29
+ _globals['_ENVELOPE']._serialized_start=68
30
+ _globals['_ENVELOPE']._serialized_end=806
31
+ _globals['_HELLO']._serialized_start=808
32
+ _globals['_HELLO']._serialized_end=854
33
+ _globals['_PING']._serialized_start=856
34
+ _globals['_PING']._serialized_end=862
35
+ _globals['_ACK']._serialized_start=864
36
+ _globals['_ACK']._serialized_end=869
37
+ _globals['_ERROR']._serialized_start=871
38
+ _globals['_ERROR']._serialized_end=948
39
+ _globals['_ACCEPTTERMSOFSERVICE']._serialized_start=950
40
+ _globals['_ACCEPTTERMSOFSERVICE']._serialized_end=972
41
+ _globals['_ACCEPTTERMSOFSERVICERESPONSE']._serialized_start=974
42
+ _globals['_ACCEPTTERMSOFSERVICERESPONSE']._serialized_end=1029
43
+ _globals['_GETLLMTOKEN']._serialized_start=1031
44
+ _globals['_GETLLMTOKEN']._serialized_end=1044
45
+ _globals['_GETLLMTOKENRESPONSE']._serialized_start=1046
46
+ _globals['_GETLLMTOKENRESPONSE']._serialized_end=1082
47
+
48
+
49
  import os
50
  import json
51
  import ssl
52
  import time
53
  import asyncio
 
54
  import aiohttp
55
+ import sys
56
+ import inspect
57
+ from loguru import logger
58
+ from flask import Flask, request, jsonify, Response, stream_with_context
59
+ from flask_cors import CORS
60
+ from dotenv import load_dotenv
61
  from aiohttp import web
62
  import zstandard as zstd
63
  from websockets.asyncio.client import connect
64
  from websockets.exceptions import ConnectionClosed
 
65
 
66
  from google.protobuf.json_format import MessageToDict
67
 
68
+ # 修复
69
  Envelope = _sym_db.GetSymbol('zed.messages.Envelope')
70
 
71
+ class Logger:
72
+ def __init__(self, level="INFO", colorize=True, format=None):
73
+ logger.remove()
 
 
74
 
75
+ if format is None:
76
+ format = (
77
+ "<green>{time:YYYY-MM-DD HH:mm:ss}</green> | "
78
+ "<level>{level: <8}</level> | "
79
+ "<cyan>{extra[filename]}</cyan>:<cyan>{extra[function]}</cyan>:<cyan>{extra[lineno]}</cyan> | "
80
+ "<level>{message}</level>"
81
+ )
82
 
83
+ logger.add(
84
+ sys.stderr,
85
+ level=level,
86
+ format=format,
87
+ colorize=colorize,
88
+ backtrace=True,
89
+ diagnose=True
90
+ )
91
 
92
+ self.logger = logger
93
+
94
+ def _get_caller_info(self):
95
+ frame = inspect.currentframe()
96
+ try:
97
+ caller_frame = frame.f_back.f_back
98
+ full_path = caller_frame.f_code.co_filename
99
+ function = caller_frame.f_code.co_name
100
+ lineno = caller_frame.f_lineno
101
+
102
+ filename = os.path.basename(full_path)
103
+
104
+ return {
105
+ 'filename': filename,
106
+ 'function': function,
107
+ 'lineno': lineno
108
+ }
109
+ finally:
110
+ del frame
111
+
112
+ def info(self, message, source="API"):
113
+ caller_info = self._get_caller_info()
114
+ self.logger.bind(**caller_info).info(f"[{source}] {message}")
115
+
116
+ def error(self, message, source="API"):
117
+ caller_info = self._get_caller_info()
118
+
119
+ if isinstance(message, Exception):
120
+ self.logger.bind(**caller_info).exception(f"[{source}] {str(message)}")
121
+ else:
122
+ self.logger.bind(**caller_info).error(f"[{source}] {message}")
123
+
124
+ def warning(self, message, source="API"):
125
+ caller_info = self._get_caller_info()
126
+ self.logger.bind(**caller_info).warning(f"[{source}] {message}")
127
+
128
+ def debug(self, message, source="API"):
129
+ caller_info = self._get_caller_info()
130
+ self.logger.bind(**caller_info).debug(f"[{source}] {message}")
131
+
132
+ async def request_logger(self, request):
133
+ caller_info = self._get_caller_info()
134
+ self.logger.bind(**caller_info).info(f"请求: {request.method} {request.path}", "Request")
135
+
136
+ logger = Logger(level="INFO")
137
+
138
+ BASE_URL = "https://zed.dev"
139
+ BASE_API_URL = "https://collab.zed.dev"
140
+ WS_URL = "wss://collab.zed.dev/rpc"
141
+ LLM_API_URL = "https://llm.zed.dev/completion"
142
+ DEFAULT_PROXY_PORT = 5200 # 默认端口,可通过环境变量覆盖
143
+ TOKEN_EXPIRY_WARNING_MINUTES = 50
144
 
145
  highest_message_id = 0
146
  llm_token = None
 
149
  active_websocket = None
150
  proxy_server_running = False
151
 
152
+ # 从环境变量读取API密码和自定义端口
153
+ PROXY_PASSWORD = os.environ.get("ZED_PROXY_PASSWORD")
154
+ PROXY_PORT = int(os.environ.get("ZED_PROXY_PORT", DEFAULT_PROXY_PORT))
155
+
156
+ # 支持的模型列表
157
+ AVAILABLE_MODELS = [
158
+ {
159
+ "id": "claude-3-5-sonnet-20240620",
160
+ "object": "model",
161
+ "created": 1719158400,
162
+ "owned_by": "anthropic",
163
+ "permission": [],
164
+ },
165
+ {
166
+ "id": "claude-3-5-sonnet-20241022",
167
+ "object": "model",
168
+ "created": 1729344000,
169
+ "owned_by": "anthropic",
170
+ "permission": [],
171
+ },
172
+ {
173
+ "id": "claude-3-7-sonnet-20250219",
174
+ "object": "model",
175
+ "created": 1740153600,
176
+ "owned_by": "anthropic",
177
+ "permission": [],
178
+ }
179
+ ]
180
+
181
+ # 模型映射表(从其他可能的模型名称映射到我们支持的模型)
182
+ MODEL_MAPPING = {
183
+ # 默认映射到最新版本
184
+ "claude-3-5-sonnet": "claude-3-5-sonnet-20241022",
185
+ "claude-3.5-sonnet": "claude-3-5-sonnet-20241022",
186
+ "claude-3-7-sonnet": "claude-3-7-sonnet-20250219",
187
+ "claude-3.7-sonnet": "claude-3-7-sonnet-20250219",
188
+ # 确保标准名称也能映射到自身
189
+ "claude-3-5-sonnet-20240620": "claude-3-5-sonnet-20240620",
190
+ "claude-3-5-sonnet-20241022": "claude-3-5-sonnet-20241022",
191
+ "claude-3-7-sonnet-20250219": "claude-3-7-sonnet-20250219",
192
+ }
193
 
194
  def decode_envelope(data):
195
  try:
 
206
  return MessageToDict(envelope, preserving_proto_field_name=True)
207
  except Exception as e:
208
  hex_preview = ' '.join(f'{byte:02x}' for byte in data[:20]) + ('...' if len(data) > 20 else '')
209
+ logger.error(f"Unable to decode message: {e}; data preview: {hex_preview}")
210
+ return {"error": f"Unable to decode message: {e}"}
211
 
212
  def compress_protobuf(data):
213
+ return zstd.ZstdCompressor(level=7).compress(data)
214
 
215
  def create_message(message_type):
216
  global highest_message_id
 
227
  await websocket.ping()
228
  await asyncio.sleep(1)
229
  except Exception as e:
230
+ logger.error(f"Error sending ping: {e}")
231
  break
232
 
233
  async def handle_messages(websocket):
 
235
  active_websocket = websocket
236
  try:
237
  async for message in websocket:
238
+ message_bytes = message.encode('utf8') if isinstance(message, str) else message
239
  decoded = decode_envelope(message_bytes)
240
  if "hello" in decoded:
241
  server_peer_id = decoded.get('hello', {}).get('peer_id')
242
  elif "accept_terms_of_service_response" in decoded:
243
  await request_llm_token(websocket)
244
+ elif ("get_llm_token_response" in decoded and
245
  'token' in decoded.get('get_llm_token_response', {})):
246
  llm_token = decoded['get_llm_token_response']['token']
247
  token_timestamp = time.time()
248
+ logger.info(f"LLM token received at {time.ctime(token_timestamp)}")
249
  if not proxy_server_running:
250
  asyncio.create_task(start_proxy_server())
251
  asyncio.create_task(monitor_token_expiration())
252
+ logger.info("Closing WebSocket connection until token refresh is needed")
253
  await websocket.close()
254
  active_websocket = None
255
  return
256
  except ConnectionClosed:
257
+ logger.info("Connection closed")
258
  active_websocket = None
259
 
260
  async def request_llm_token(websocket):
261
  message, _ = create_message('get_llm_token')
262
+ logger.info("Requesting the LLM token")
263
  await websocket.send(message)
264
 
265
  async def request_accept_terms_of_service(websocket):
266
  message, _ = create_message('accept_terms_of_service')
267
+ logger.info("Sending consent for the Zed Terms of Service")
268
  await websocket.send(message)
269
 
270
  def format_content(content):
 
272
  return [{"type": "text", "text": content}]
273
  return content
274
 
275
+ # 密码验证中间件
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
  @web.middleware
277
  async def auth_middleware(request, handler):
278
+ if PROXY_PASSWORD:
279
+ # 检查各种可能的密码验证方式
280
  auth_header = request.headers.get('Authorization')
281
+ password_header = request.headers.get('XPassword') or request.headers.get('Password')
282
 
283
+ # 从Authorization头获取密码(如果是Bearer格式)
284
  auth_password = None
285
  if auth_header and auth_header.startswith('Bearer '):
286
  auth_password = auth_header[7:]
287
 
288
+ # 检查任何来源的密码是否匹配
289
+ if auth_password == PROXY_PASSWORD or password_header == PROXY_PASSWORD:
290
  return await handler(request)
291
  else:
292
  return web.json_response(
293
+ {"error": "Unauthorized. Provide valid password in Authorization or XPassword header"},
294
  status=401
295
  )
296
+ # 如果未设置密码,直接通过验证
297
  return await handler(request)
298
 
299
+ def get_mapped_model(model_name):
300
+ """将各种模型名称映射到支持的模型"""
301
+ if not model_name:
302
+ return "claude-3-5-sonnet-20241022" # 默认模型
303
+
304
+ # 如果模型名称在映射表中,返回映射的值
305
+ return MODEL_MAPPING.get(model_name, "claude-3-5-sonnet-20241022")
306
+
307
+ def convert_openai_to_anthropic(data):
308
+ """将OpenAI API请求格式转换为Anthropic格式"""
309
+ # 如果已经是Anthropic格式,不做转换
310
+ if "messages" in data:
311
+ return data
312
+
313
+ result = {}
314
+
315
+ # 处理模型
316
+ if "model" in data:
317
+ result["model"] = get_mapped_model(data["model"])
318
+ else:
319
+ result["model"] = "claude-3-5-sonnet-20241022"
320
+
321
+ # 处理系统消息和用户消息
322
+ result["messages"] = []
323
+ if "system" in data:
324
+ result["system"] = data["system"]
325
+
326
+ # 处理消息数组
327
+ if "messages" in data:
328
+ result["messages"] = data["messages"]
329
+
330
+ # 复制其他参数
331
+ for key in ["temperature", "top_p", "max_tokens", "stream"]:
332
+ if key in data:
333
+ result[key] = data[key]
334
+
335
+ return result
336
+
337
  async def handle_models_request(request):
338
+ """处理/v1/models端点的请求,返回支持的模型列表"""
339
  return web.json_response({
340
  "object": "list",
341
+ "data": AVAILABLE_MODELS
 
 
 
 
 
 
 
 
342
  })
343
 
344
  async def handle_message_request(request):
345
  global llm_token
346
  if not llm_token:
347
+ return web.json_response({"error": "LLM token not available"}, status=500)
348
  try:
349
  body = await request.json()
350
+
351
+ # 转换格式(如果需要)
352
+ body = convert_openai_to_anthropic(body)
353
+
354
+ # 确保模型是支持的
355
+ if "model" in body:
356
+ body["model"] = get_mapped_model(body["model"])
357
+
358
+ # Zed的Anthropic API对消息内容格式有特定要求
359
+ if "messages" in body:
360
+ for msg in body["messages"]:
361
+ if "content" in msg:
362
+ msg["content"] = format_content(msg["content"])
363
+
364
+ # 处理系统消息格式
365
+ if "system" in body:
366
+ if isinstance(body["system"], list):
367
+ body["system"] = "\n".join([item["text"] for item in body["system"]])
368
+
369
  headers = {"Content-Type": "application/json", "Authorization": f"Bearer {llm_token}"}
 
 
370
  payload = {
371
  "provider": "anthropic",
372
+ "model": body.get("model", "claude-3-5-sonnet-20241022"),
373
  "provider_request": body
374
  }
375
+
376
+
377
  if body.get("stream", False):
378
+ return await handle_streaming_request(request, headers, payload)
379
  else:
380
+ return await handle_non_streaming_request(headers, payload)
381
  except Exception as e:
382
+ logger.error(f"Error processing request: {e}")
383
  return web.json_response({"error": str(e)}, status=500)
384
 
385
+ async def handle_non_streaming_request(headers, payload):
386
  async with aiohttp.ClientSession() as session:
387
+ async with session.post(LLM_API_URL, headers=headers, json=payload) as r:
388
  if r.status != 200:
389
  text = await r.text()
390
+ logger.error(f"LLM API error: {text}")
391
  return web.json_response({"error": text}, status=r.status)
392
  full_content, message_data = "", {}
393
  async for line in r.content:
394
  if not line:
395
  continue
396
  try:
397
+ event = json.loads(line.decode('utf8').strip())
398
  et = event.get('type')
399
  if et == "message_start":
400
  message_data = event.get('message', {})
 
406
  break
407
  except Exception as e:
408
  logger.error(f"Error processing line: {e}")
 
 
 
 
 
409
 
410
+ # 构建响应
411
+ message_data['content'] = [{"type": "text", "text": full_content}]
412
+
413
+ # 为OpenAI格式添加兼容字段 - 使用完整模型名称
414
+ model_name = payload.get("provider_request", {}).get("model", "claude-3-5-sonnet-20241022")
415
+ result = {
416
+ "id": f"chatcmpl{int(time.time()*1000)}",
417
+ "object": "chat.completion",
418
+ "created": int(time.time()),
419
+ "model": model_name,
420
+ "choices": [
421
+ {
422
+ "index": 0,
423
+ "message": {
424
+ "role": "assistant",
425
+ "content": full_content
426
+ },
427
+ "finish_reason": "stop"
428
+ }
429
+ ],
430
+ }
431
+
432
+ # 如果有usage信息,添加到结果中
433
+ if "usage" in message_data:
434
+ result["usage"] = message_data["usage"]
435
+
436
+ return web.json_response(result)
437
+
438
async def handle_streaming_request(request, headers, payload):
    """Proxy a streaming chat request to the upstream LLM API as SSE.

    Forwards every upstream event verbatim and additionally emits an
    OpenAI-compatible ``chat.completion.chunk`` for each text delta, so
    both Anthropic-style and OpenAI-style clients can consume the stream.

    Args:
        request: Incoming aiohttp request (used to prepare the stream).
        headers: Headers forwarded to the upstream LLM API.
        payload: JSON body for the upstream call; ``provider_request.model``
            selects the model name echoed in the OpenAI-format chunks.

    Returns:
        The ``web.StreamResponse`` after the upstream stream completes.
    """
    response = web.StreamResponse()
    response.headers['Content-Type'] = 'text/event-stream'
    response.headers['Cache-Control'] = 'no-cache'
    response.headers['Connection'] = 'keep-alive'
    await response.prepare(request)

    async with aiohttp.ClientSession() as session:
        # ClientTimeout is the non-deprecated way to bound the whole request
        # (a bare number for `timeout=` is deprecated in aiohttp).
        timeout = aiohttp.ClientTimeout(total=60)
        async with session.post(LLM_API_URL, headers=headers, json=payload,
                                timeout=timeout) as api_response:
            if api_response.status != 200:
                error_text = await api_response.text()
                logger.error(f"LLM API (stream) error: {error_text}")
                await response.write(f"data: {json.dumps({'error': error_text})}\n\n".encode())
                await response.write(b"data: [DONE]\n\n")
                return response

            # One ID shared by every chunk of this completion.
            response_id = f"chatcmpl{int(time.time()*1000)}"
            content_buffer = ""
            # Use the full model name from the request, with a default.
            model_name = payload.get("provider_request", {}).get("model", "claude-3-5-sonnet-20241022")

            async for line in api_response.content:
                if not line:
                    continue
                # Strip the trailing newline BEFORE embedding in the SSE
                # data field: forwarding it verbatim would inject an extra
                # newline and break the "data: ...\n\n" framing downstream.
                line_text = line.decode('utf8').strip()
                if not line_text:
                    continue

                # Pass the raw upstream event through unchanged.
                await response.write(f"data: {line_text}\n\n".encode())

                # Parse the event to build the OpenAI-compatible mirror.
                try:
                    data = json.loads(line_text)
                except json.JSONDecodeError:
                    # Non-JSON keep-alive/garbage line: forwarded above, skip.
                    continue

                event_type = data.get('type')
                if event_type == 'content_block_delta' and data.get('delta', {}).get('type') == 'text_delta':
                    delta_text = data['delta'].get('text', '')
                    content_buffer += delta_text
                    openai_chunk = {
                        "id": response_id,
                        "object": "chat.completion.chunk",
                        "created": int(time.time()),
                        "model": model_name,
                        "choices": [
                            {
                                "index": 0,
                                "delta": {"content": delta_text},
                                "finish_reason": None
                            }
                        ]
                    }
                    # Emit the OpenAI-compatible chunk alongside the raw one.
                    await response.write(f"data: {json.dumps(openai_chunk)}\n\n".encode())
                elif event_type == 'message_stop':
                    # End of message: emit the OpenAI-style terminal chunk.
                    final_chunk = {
                        "id": response_id,
                        "object": "chat.completion.chunk",
                        "created": int(time.time()),
                        "model": model_name,
                        "choices": [
                            {
                                "index": 0,
                                "delta": {},
                                "finish_reason": "stop"
                            }
                        ]
                    }
                    await response.write(f"data: {json.dumps(final_chunk)}\n\n".encode())

            await response.write(b"data: [DONE]\n\n")

    return response
515
 
516
async def start_proxy_server():
    """Spin up the HTTP proxy once and keep it alive indefinitely.

    Exposes the Anthropic-style endpoint (/v1/messages), OpenAI-compatible
    aliases (/v1/chat/completions, /v1/models) and a root health check,
    all behind the auth middleware. Idempotent: a second call is a no-op.
    """
    global proxy_server_running
    if proxy_server_running:
        logger.info("Proxy server already running, skipping startup")
        return
    proxy_server_running = True

    # Every route goes through the auth middleware.
    app = web.Application(middlewares=[auth_middleware])

    # Lightweight liveness endpoint served at the root path.
    async def health_check(request):
        body = {
            "status": "ok",
            "message": "Zed LLM proxy is running",
            "available_models": [model["id"] for model in AVAILABLE_MODELS]
        }
        return web.json_response(body)

    # API endpoints: native Anthropic plus OpenAI-compatible aliases.
    app.router.add_post('/v1/messages', handle_message_request)
    app.router.add_post('/v1/chat/completions', handle_message_request)
    app.router.add_get('/v1/models', handle_models_request)
    app.router.add_get('/', health_check)

    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, 'localhost', PROXY_PORT)
    await site.start()

    password_status = "启用" if PROXY_PASSWORD else "未启用"
    logger.info(f"代理服务器启动于 http://localhost:{PROXY_PORT}")
    logger.info(f"密码保护: {password_status}")
    model_ids = [model['id'] for model in AVAILABLE_MODELS]
    logger.info(f"支持的模型: {', '.join(model_ids)}")
    logger.info("支持的端点:")
    logger.info(" /v1/messages (Anthropic API)")
    logger.info(" /v1/chat/completions (OpenAI兼容API)")
    logger.info(" /v1/models (模型列表)")

    # Park forever so the server task never returns.
    while True:
        await asyncio.sleep(3600)
558
 
559
def is_token_expiring():
    """Return True once the token's age reaches the warning threshold.

    A falsy ``token_timestamp`` means no token has been received yet, in
    which case there is nothing that can expire.
    """
    if not token_timestamp:
        return False
    age_minutes = (time.time() - token_timestamp) / 60
    return age_minutes >= TOKEN_EXPIRY_WARNING_MINUTES
563
 
564
async def monitor_token_expiration():
    """Background watchdog for LLM token freshness.

    Checks once a minute; when the token nears expiry it logs a warning,
    triggers a WebSocket reconnect if no connection is active, and exits.
    """
    while True:
        await asyncio.sleep(60)
        if not is_token_expiring():
            continue
        elapsed = int((time.time() - token_timestamp) / 60)
        logger.warning(f"LLM token is approaching expiration (received {elapsed} minutes ago)")
        if active_websocket is None:
            logger.info("Reconnecting WebSocket for token refresh")
            asyncio.create_task(reconnect_for_token_refresh())
        return
574
 
575
  async def reconnect_for_token_refresh():
576
  try:
577
+ # 使用环境变量而不是从文件读取
578
+ user_id = os.environ.get("ZED_USER_ID")
579
+ auth_token = os.environ.get("ZED_AUTH_TOKEN").replace("\\", "")
580
+
581
+ if not user_id or not auth_token:
582
+ logger.error("环境变量ZED_USER_ID或ZED_AUTH_TOKEN未设置")
583
  return
584
+
585
+ # 确保header格式一致
586
  headers = {
587
+ "authorization": user_id + " " + auth_token,
588
  "x-zed-protocol-version": "68",
589
  "x-zed-app-version": "0.178.0",
590
  "x-zed-release-channel": "stable"
591
  }
592
+ print(headers)
593
  ssl_context = ssl.create_default_context()
594
  ssl_context.check_hostname = False
595
  ssl_context.verify_mode = ssl.CERT_NONE
596
+
597
+ async for websocket in connect(WS_URL, additional_headers=headers, ssl=ssl_context):
598
  try:
599
  ping_task = asyncio.create_task(ping_periodically(websocket))
600
  await asyncio.sleep(2)
 
604
  except ConnectionClosed:
605
  continue
606
  except Exception as e:
607
+ logger.error(f"Error during token refresh: {e}")
608
  await asyncio.sleep(1)
609
  continue
610
  finally:
 
614
  except asyncio.CancelledError:
615
  pass
616
  except Exception as e:
617
+ logger.error(f"Failed to reconnect for token refresh: {e}")
618
 
619
async def async_main():
    """Connect to the Zed WebSocket with credentials from the environment
    and service it until the session ends, then idle forever.

    Requires the ZED_USER_ID and ZED_AUTH_TOKEN environment variables;
    logs an error and returns early when either is missing.
    """
    # Credentials come from environment variables, not from a file.
    user_id = os.environ.get("ZED_USER_ID")
    auth_token = os.environ.get("ZED_AUTH_TOKEN")

    # Validate BEFORE touching auth_token: calling .replace() on a missing
    # (None) value would raise AttributeError instead of this clear error.
    if not user_id or not auth_token:
        logger.error("环境变量ZED_USER_ID或ZED_AUTH_TOKEN未设置")
        logger.error("请先运行auth_script.py获取认证信息,然后设置环境变量")
        return

    # Strip escape backslashes that shells sometimes leave in the token.
    auth_token = auth_token.replace("\\", "")

    # Keep this header shape in sync with reconnect_for_token_refresh().
    headers = {
        "authorization": user_id + " " + auth_token,
        "x-zed-protocol-version": "68",
        "x-zed-app-version": "0.178.0",
        "x-zed-release-channel": "stable"
    }
    # NOTE(review): intentionally NOT printing `headers` — it contains the
    # auth token and would leak credentials to stdout/logs.
    ssl_context = ssl.create_default_context()
    # NOTE(review): TLS verification is disabled here — confirm this is
    # intended for the target deployment.
    ssl_context.check_hostname = False
    ssl_context.verify_mode = ssl.CERT_NONE
    logger.info("Connecting to the WebSocket server")
    async for websocket in connect(WS_URL, additional_headers=headers, ssl=ssl_context):
        # Pre-bind so the finally block is safe even if task creation fails.
        ping_task = None
        token_request_task = None
        try:
            ping_task = asyncio.create_task(ping_periodically(websocket))
            token_request_task = asyncio.create_task(delayed_token_request(websocket, delay=2))
            await handle_messages(websocket)
            break
        except ConnectionClosed:
            continue
        except Exception as e:
            logger.error(f"Unexpected error: {e}")
            await asyncio.sleep(1)
            continue
        finally:
            # Always tear the helper tasks down, including on reconnect.
            for task in (ping_task, token_request_task):
                if task is not None:
                    task.cancel()
                    try:
                        await task
                    except asyncio.CancelledError:
                        pass

    # Keep the coroutine alive so background tasks keep running.
    while True:
        await asyncio.sleep(3600)
667
 
668
async def delayed_token_request(websocket, delay=2):
    """Wait ``delay`` seconds, then request terms-of-service acceptance
    over the given websocket."""
    await asyncio.sleep(delay)
    await request_accept_terms_of_service(websocket)
671
 
672
def main():
    """CLI entry point: validate configuration, then run the async proxy."""
    logger.info("启动Zed LLM代理服务")
    logger.info(f"配置端口: {PROXY_PORT}")
    logger.info(f"密码保护: {'已启用' if PROXY_PASSWORD else '未启用'}")

    # Bail out early when the required credentials are absent.
    have_creds = os.environ.get("ZED_USER_ID") and os.environ.get("ZED_AUTH_TOKEN")
    if not have_creds:
        logger.error("错误: 环境变量ZED_USER_ID和ZED_AUTH_TOKEN必须设置")
        logger.error("请先运行auth_script.py获取认证信息,然后设置环境变量")
        return

    try:
        asyncio.run(async_main())
    except KeyboardInterrupt:
        logger.info("收到中断信号,正在退出...")
    except Exception as e:
        logger.error(f"发生错误: {e}")
    finally:
        logger.info("代理服务已关闭")


if __name__ == "__main__":
    main()