letterm commited on
Commit
6a5a959
·
verified ·
1 Parent(s): 194d709

Update model_mapper.py

Browse files
Files changed (1) hide show
  1. model_mapper.py +47 -305
model_mapper.py CHANGED
@@ -1,309 +1,51 @@
1
  """
2
- API服务模块
3
- 处理所有API请求逻辑,包括聊天完成和Token管理
4
  """
5
- import json
6
- import time
7
- from datetime import datetime
8
- from typing import List, Dict, Any, Generator
9
- from loguru import logger
10
 
11
- from config import Config
12
- from utils import Utils
13
- from token_manager import MultiTokenManager
14
- from warp_client import WarpClient
15
- from request_converter import RequestConverter
16
- from model_mapper import ModelMapper
17
 
18
-
19
- class ApiService:
20
- """API服务类,处理所有业务逻辑"""
21
-
22
- def __init__(self):
23
- # 初始化Token管理器
24
- self.token_manager = MultiTokenManager()
25
-
26
- # 初始化Warp客户端
27
- self.warp_client = WarpClient(self.token_manager)
28
-
29
- logger.info("🚀 ApiService初始化完成")
30
-
31
- def authenticate_request(self, auth_header: str) -> bool:
32
- """验证API请求"""
33
- if not auth_header:
34
- return False
35
-
36
- token = Utils.extract_bearer_token(auth_header)
37
- if not token:
38
- return False
39
-
40
- return Utils.validate_api_key(token)
41
-
42
- def get_models(self) -> Dict[str, Any]:
43
- """获取支持的模型列表"""
44
- models = []
45
- for model_name in ModelMapper.get_available_models():
46
- models.append({
47
- "id": model_name,
48
- "object": "model",
49
- "created": Utils.get_current_timestamp(),
50
- "owned_by": "warp-proxy"
51
- })
52
-
53
- return {
54
- "object": "list",
55
- "data": models
56
- }
57
-
58
- def chat_completion(self, request_data: Dict[str, Any], stream: bool = False) -> Generator[str, None, None]:
59
- """处理聊天完成请求"""
60
- request_id = Utils.generate_request_id()
61
- # 解析请求
62
- openai_request = RequestConverter.parse_openai_request(request_data)
63
- model = openai_request.model
64
- messages = openai_request.messages
65
-
66
- logger.info(f"🎯 开始处理聊天请求 [ID: {request_id[:8]}] [模型: {model}] [流式: {stream}]")
67
- start_time = time.time()
68
-
69
- try:
70
- # 创建protobuf数据
71
- protobuf_data = self.warp_client.create_protobuf_data(messages, model)
72
- if not protobuf_data:
73
- error_msg = "创建请求数据失败"
74
- logger.error(f"❌ {error_msg} [ID: {request_id[:8]}]")
75
- yield self._create_error_response(error_msg, request_id)
76
- return
77
-
78
- # 发送请求并处理响应
79
- response_chunks = 0
80
- total_content = ""
81
-
82
- logger.success(f"🚀 开始接收响应 [ID: {request_id[:8]}]")
83
-
84
- for chunk_text in self.warp_client.send_request(protobuf_data):
85
- if chunk_text:
86
- response_chunks += 1
87
- total_content += chunk_text
88
-
89
- logger.debug(f"📦 响应块 #{response_chunks} [ID: {request_id[:8]}] [长度: {len(chunk_text)}]")
90
-
91
- if stream:
92
- # 流式响应
93
- chunk_response = self._create_stream_chunk(chunk_text, request_id)
94
- yield f"data: {json.dumps(chunk_response)}\n\n"
95
- else:
96
- # 非流式响应 - 等待完整内容
97
- continue
98
-
99
- # 处理响应结束
100
- end_time = time.time()
101
- duration = end_time - start_time
102
-
103
- if stream:
104
- # 发送结束标记
105
- final_chunk = self._create_stream_end_chunk(request_id)
106
- yield f"data: {json.dumps(final_chunk)}\n\n"
107
- yield "data: [DONE]\n\n"
108
-
109
- logger.success(f"✅ 流式响应完成 [ID: {request_id[:8]}] [块数: {response_chunks}] [耗时: {duration:.2f}s]")
110
- else:
111
- # 返回完整响应
112
- response = self._create_complete_response(total_content, request_id)
113
- yield response
114
-
115
- logger.success(f"✅ 完整响应完成 [ID: {request_id[:8]}] [长度: {len(total_content)}] [耗时: {duration:.2f}s]")
116
-
117
- except Exception as e:
118
- logger.error(f"❌ 聊天请求处理失败 [ID: {request_id[:8]}]: {e}")
119
- yield self._create_error_response(f"服务器内部错误: {str(e)}", request_id)
120
-
121
- def _create_stream_chunk(self, content: str, request_id: str) -> Dict[str, Any]:
122
- """创建流式响应块"""
123
- return {
124
- "id": f"chatcmpl-{request_id}",
125
- "object": "chat.completion.chunk",
126
- "created": int(time.time()),
127
- "model": "gemini-2.0-flash",
128
- "choices": [{
129
- "index": 0,
130
- "delta": {"content": content},
131
- "finish_reason": None
132
- }]
133
- }
134
-
135
- def _create_stream_end_chunk(self, request_id: str) -> Dict[str, Any]:
136
- """创建流式响应结束块"""
137
- return {
138
- "id": f"chatcmpl-{request_id}",
139
- "object": "chat.completion.chunk",
140
- "created": int(time.time()),
141
- "model": "gemini-2.0-flash",
142
- "choices": [{
143
- "index": 0,
144
- "delta": {},
145
- "finish_reason": "stop"
146
- }]
147
- }
148
-
149
- def _create_complete_response(self, content: str, request_id: str) -> Dict[str, Any]:
150
- """创建完整响应"""
151
- return {
152
- "id": f"chatcmpl-{request_id}",
153
- "object": "chat.completion",
154
- "created": int(time.time()),
155
- "model": "gemini-2.0-flash",
156
- "choices": [{
157
- "index": 0,
158
- "message": {
159
- "role": "assistant",
160
- "content": content
161
- },
162
- "finish_reason": "stop"
163
- }],
164
- "usage": {
165
- "prompt_tokens": 0,
166
- "completion_tokens": 0,
167
- "total_tokens": 0
168
- }
169
- }
170
-
171
- def _create_error_response(self, error_message: str, request_id: str) -> Dict[str, Any]:
172
- """创建错误响应"""
173
- return {
174
- "error": {
175
- "message": error_message,
176
- "type": "api_error",
177
- "code": "internal_error"
178
- },
179
- "id": request_id
180
- }
181
-
182
- def get_token_status(self) -> Dict[str, Any]:
183
- """获取Token状态"""
184
- try:
185
- status = self.token_manager.get_token_status()
186
- return {"success": True, **status}
187
- except Exception as e:
188
- logger.error(f"❌ 获取Token状态失败: {e}")
189
- return {"success": False, "message": str(e)}
190
-
191
- def add_tokens(self, tokens: List[str]) -> Dict[str, Any]:
192
- """添加Token"""
193
- try:
194
- success = self.token_manager.add_refresh_tokens(tokens)
195
- if success:
196
- valid_tokens = [t for t in tokens if Utils.validate_refresh_token_format(t)]
197
- return {
198
- "success": True,
199
- "message": "Token添加成功",
200
- "added_tokens": len(valid_tokens)
201
- }
202
- else:
203
- return {"success": False, "message": "没有有效的Token可添加"}
204
- except Exception as e:
205
- logger.error(f"❌ 添加Token失败: {e}")
206
- return {"success": False, "message": str(e)}
207
-
208
- def remove_refresh_token(self, refresh_token: str) -> Dict[str, Any]:
209
- """删除refresh token"""
210
- try:
211
- success = self.token_manager.remove_refresh_token(refresh_token)
212
- if success:
213
- return {"success": True, "message": "Token删除成功"}
214
- else:
215
- return {"success": False, "message": "Token不存在"}
216
- except Exception as e:
217
- logger.error(f"❌ 删除Token失败: {e}")
218
- return {"success": False, "message": str(e)}
219
-
220
- def refresh_all_tokens(self) -> Dict[str, Any]:
221
- """刷新所有Token"""
222
- try:
223
- self.token_manager.refresh_all_tokens()
224
- return {"success": True, "message": "Token刷新已开始"}
225
- except Exception as e:
226
- logger.error(f"❌ 刷新Token失败: {e}")
227
- return {"success": False, "message": str(e)}
228
-
229
- def export_refresh_tokens(self, super_admin_key: str) -> Dict[str, Any]:
230
- """导出refresh token内容(需要超级管理员密钥验证)"""
231
- try:
232
- # 验证超级管理员密钥
233
- if Config.require_super_admin_auth():
234
- if not super_admin_key or super_admin_key != Config.get_super_admin_key():
235
- return {"success": False, "message": "超级管理员密钥验证失败"}
236
-
237
- # 获取所有refresh token
238
- with self.token_manager.token_lock:
239
- refresh_tokens = list(self.token_manager.tokens.keys())
240
-
241
- if not refresh_tokens:
242
- return {"success": False, "message": "没有可导出的token"}
243
-
244
- # 创建分号分割的token字符串
245
- token_string = ";".join(refresh_tokens)
246
-
247
- # 生成建议的文件名(带时间戳)
248
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
249
- suggested_filename = f"refresh_tokens_export_{timestamp}.txt"
250
-
251
- logger.info(f"🔒 超级管理员请求导出 {len(refresh_tokens)} 个refresh token")
252
-
253
- return {
254
- "success": True,
255
- "message": f"准备导出 {len(refresh_tokens)} 个token",
256
- "content": token_string,
257
- "suggested_filename": suggested_filename,
258
- "token_count": len(refresh_tokens)
259
- }
260
-
261
- except Exception as e:
262
- logger.error(f"❌ 准备导出refresh token失败: {e}")
263
- return {"success": False, "message": f"导出失败: {str(e)}"}
264
-
265
- def batch_get_refresh_tokens(self, email_url_dict: Dict[str, str], max_workers: int = 5) -> Dict[str, Any]:
266
- """批量获取refresh token并自动创建用户"""
267
- try:
268
- from login_client import LoginClient
269
- login_client = LoginClient()
270
-
271
- # 传递token_manager参数,这样在获取refresh_token后会立即尝试创建用户
272
- results = login_client.batch_process_emails(email_url_dict, max_workers, self.token_manager)
273
-
274
- # 提取有效的token并添加到管理器
275
- valid_tokens = []
276
- for email, result in results.items():
277
- if result.get('refresh_token'):
278
- valid_tokens.append(result['refresh_token'])
279
-
280
- if valid_tokens:
281
- self.token_manager.add_refresh_tokens(valid_tokens)
282
- logger.info(f"✅ 批量获取并添加了 {len(valid_tokens)} 个有效token")
283
-
284
- return {
285
- 'success': True,
286
- 'results': results,
287
- 'total_count': len(email_url_dict),
288
- 'success_count': len(valid_tokens)
289
- }
290
-
291
- except Exception as e:
292
- logger.error(f"❌ 批量获取refresh token失败: {e}")
293
- return {'success': False, 'message': str(e)}
294
-
295
- def start_services(self):
296
- """启动后台服务"""
297
- try:
298
- self.token_manager.start_auto_refresh()
299
- logger.success("✅ 后台服务启动成功")
300
- except Exception as e:
301
- logger.error(f"❌ 启动后台服务失败: {e}")
302
-
303
- def stop_services(self):
304
- """停止后台服务"""
305
- try:
306
- self.token_manager.stop_auto_refresh()
307
- logger.info("⏹️ 后台服务已停止")
308
- except Exception as e:
309
- logger.error(f"❌ 停止后台服务失败: {e}")
 
1
  """
2
+ 模型映射模块
3
+ 管理OpenAI模型名称到Warp模型名称的映射
4
  """
5
+ from typing import List
 
 
 
 
6
 
 
 
 
 
 
 
7
 
8
+ class ModelMapper:
9
+ """模型名称映射管理"""
10
+
11
+ MODEL_MAPPING = {
12
+ "claude-sonnet-4-20250514": "claude-4-sonnet",
13
+ "claude-3-7-sonnet-20250219": "claude-3-7-sonnet",
14
+ "claude-3-5-sonnet-20241022": "claude-3-5-sonnet",
15
+ "claude-3-5-haiku-20241022": "claude-3-5-haiku",
16
+ "gpt-4o": "gpt-4o",
17
+ "gpt-4.1": "gpt-4.1",
18
+ "o4-mini": "o4-mini",
19
+ "o3": "o3",
20
+ "o3-mini": "o3-mini",
21
+ "gemini-2.0-flash": "gemini-2.0-flash",
22
+ "gemini-2.5-pro": "gemini-2.5-pro"
23
+ }
24
+
25
+ DEFAULT_MODEL = "gemini-2.0-flash"
26
+
27
+ @classmethod
28
+ def get_warp_model(cls, openai_model: str) -> str:
29
+ """将OpenAI模型名转换为Warp模型名"""
30
+ return cls.MODEL_MAPPING.get(openai_model, cls.DEFAULT_MODEL)
31
+
32
+ @classmethod
33
+ def get_available_models(cls) -> List[str]:
34
+ """获取所有可用模型列表"""
35
+ return list(cls.MODEL_MAPPING.keys())
36
+
37
+ @classmethod
38
+ def is_valid_model(cls, model: str) -> bool:
39
+ """检查模型是否有效"""
40
+ return model in cls.MODEL_MAPPING
41
+
42
+ @classmethod
43
+ def add_model_mapping(cls, openai_model: str, warp_model: str):
44
+ """添加新的模型映射"""
45
+ cls.MODEL_MAPPING[openai_model] = warp_model
46
+
47
+ @classmethod
48
+ def remove_model_mapping(cls, openai_model: str):
49
+ """移除模型映射"""
50
+ if openai_model in cls.MODEL_MAPPING:
51
+ del cls.MODEL_MAPPING[openai_model]