letterm commited on
Commit
194d709
·
verified ·
1 Parent(s): 197434d

Update model_mapper.py

Browse files
Files changed (1) hide show
  1. model_mapper.py +305 -47
model_mapper.py CHANGED
@@ -1,51 +1,309 @@
1
  """
2
- 模型映射模块
3
- 管理OpenAI模型名称到Warp模型名称的映射
4
  """
5
- from typing import List
 
 
 
 
6
 
 
 
 
 
 
 
7
 
8
- class ModelMapper:
9
- """模型名称映射管理"""
10
-
11
- MODEL_MAPPING = {
12
- "claude-sonnet-4-20250514": "claude-4-sonnet",
13
- "claude-3-7-sonnet-20250219": "claude-3-7-sonnet",
14
- "claude-3-5-sonnet-20241022": "claude-3-5-sonnet",
15
- "claude-3-5-haiku-20241022": "claude-3-5-haiku",
16
- "gpt-4o": "gpt-4o",
17
- "gpt-4.1": "gpt-4.1",
18
- "o4-mini": "o4-mini",
19
- "o3": "o3",
20
- "o3-mini": "o3-mini",
21
- "gemini-2.0-flash": "gemini-2.0-flash",
22
- "gemini-2.5-pro": "gemini-2.5-pro"
23
- }
24
-
25
- DEFAULT_MODEL = "gemini-2.0-flash"
26
-
27
- @classmethod
28
- def get_warp_model(cls, openai_model: str) -> str:
29
- """将OpenAI模型名转换为Warp模型名"""
30
- return cls.MODEL_MAPPING.get(openai_model, cls.DEFAULT_MODEL)
31
-
32
- @classmethod
33
- def get_available_models(cls) -> List[str]:
34
- """获取所有可用模型列表"""
35
- return list(cls.MODEL_MAPPING.keys())
36
-
37
- @classmethod
38
- def is_valid_model(cls, model: str) -> bool:
39
- """检查模型是否有效"""
40
- return model in cls.MODEL_MAPPING
41
-
42
- @classmethod
43
- def add_model_mapping(cls, openai_model: str, warp_model: str):
44
- """添加新的模型映射"""
45
- cls.MODEL_MAPPING[openai_model] = warp_model
46
-
47
- @classmethod
48
- def remove_model_mapping(cls, openai_model: str):
49
- """移除模型映射"""
50
- if openai_model in cls.MODEL_MAPPING:
51
- del cls.MODEL_MAPPING[openai_model]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
+ API服务模块
3
+ 处理所有API请求逻辑,包括聊天完成和Token管理
4
  """
5
+ import json
6
+ import time
7
+ from datetime import datetime
8
+ from typing import List, Dict, Any, Generator
9
+ from loguru import logger
10
 
11
+ from config import Config
12
+ from utils import Utils
13
+ from token_manager import MultiTokenManager
14
+ from warp_client import WarpClient
15
+ from request_converter import RequestConverter
16
+ from model_mapper import ModelMapper
17
 
18
+
19
+ class ApiService:
20
+ """API服务类,处理所有业务逻辑"""
21
+
22
+ def __init__(self):
23
+ # 初始化Token管理器
24
+ self.token_manager = MultiTokenManager()
25
+
26
+ # 初始化Warp客户端
27
+ self.warp_client = WarpClient(self.token_manager)
28
+
29
+ logger.info("🚀 ApiService初始化完成")
30
+
31
+ def authenticate_request(self, auth_header: str) -> bool:
32
+ """验证API请求"""
33
+ if not auth_header:
34
+ return False
35
+
36
+ token = Utils.extract_bearer_token(auth_header)
37
+ if not token:
38
+ return False
39
+
40
+ return Utils.validate_api_key(token)
41
+
42
+ def get_models(self) -> Dict[str, Any]:
43
+ """获取支持的模型列表"""
44
+ models = []
45
+ for model_name in ModelMapper.get_available_models():
46
+ models.append({
47
+ "id": model_name,
48
+ "object": "model",
49
+ "created": Utils.get_current_timestamp(),
50
+ "owned_by": "warp-proxy"
51
+ })
52
+
53
+ return {
54
+ "object": "list",
55
+ "data": models
56
+ }
57
+
58
+ def chat_completion(self, request_data: Dict[str, Any], stream: bool = False) -> Generator[str, None, None]:
59
+ """处理聊天完成请求"""
60
+ request_id = Utils.generate_request_id()
61
+ # 解析请求
62
+ openai_request = RequestConverter.parse_openai_request(request_data)
63
+ model = openai_request.model
64
+ messages = openai_request.messages
65
+
66
+ logger.info(f"🎯 开始处理聊天请求 [ID: {request_id[:8]}] [模型: {model}] [流式: {stream}]")
67
+ start_time = time.time()
68
+
69
+ try:
70
+ # 创建protobuf数据
71
+ protobuf_data = self.warp_client.create_protobuf_data(messages, model)
72
+ if not protobuf_data:
73
+ error_msg = "创建请求数据失败"
74
+ logger.error(f"❌ {error_msg} [ID: {request_id[:8]}]")
75
+ yield self._create_error_response(error_msg, request_id)
76
+ return
77
+
78
+ # 发送请求并处理响应
79
+ response_chunks = 0
80
+ total_content = ""
81
+
82
+ logger.success(f"🚀 开始接收响应 [ID: {request_id[:8]}]")
83
+
84
+ for chunk_text in self.warp_client.send_request(protobuf_data):
85
+ if chunk_text:
86
+ response_chunks += 1
87
+ total_content += chunk_text
88
+
89
+ logger.debug(f"📦 响应块 #{response_chunks} [ID: {request_id[:8]}] [长度: {len(chunk_text)}]")
90
+
91
+ if stream:
92
+ # 流式响应
93
+ chunk_response = self._create_stream_chunk(chunk_text, request_id)
94
+ yield f"data: {json.dumps(chunk_response)}\n\n"
95
+ else:
96
+ # 非流式响应 - 等待完整内容
97
+ continue
98
+
99
+ # 处理响应结束
100
+ end_time = time.time()
101
+ duration = end_time - start_time
102
+
103
+ if stream:
104
+ # 发送结束标记
105
+ final_chunk = self._create_stream_end_chunk(request_id)
106
+ yield f"data: {json.dumps(final_chunk)}\n\n"
107
+ yield "data: [DONE]\n\n"
108
+
109
+ logger.success(f"✅ 流式响应完成 [ID: {request_id[:8]}] [块数: {response_chunks}] [耗时: {duration:.2f}s]")
110
+ else:
111
+ # 返回完整响应
112
+ response = self._create_complete_response(total_content, request_id)
113
+ yield response
114
+
115
+ logger.success(f"✅ 完整响应完成 [ID: {request_id[:8]}] [长度: {len(total_content)}] [耗时: {duration:.2f}s]")
116
+
117
+ except Exception as e:
118
+ logger.error(f"❌ 聊天请求处理失败 [ID: {request_id[:8]}]: {e}")
119
+ yield self._create_error_response(f"服务器内部错误: {str(e)}", request_id)
120
+
121
+ def _create_stream_chunk(self, content: str, request_id: str) -> Dict[str, Any]:
122
+ """创建流式响应块"""
123
+ return {
124
+ "id": f"chatcmpl-{request_id}",
125
+ "object": "chat.completion.chunk",
126
+ "created": int(time.time()),
127
+ "model": "gemini-2.0-flash",
128
+ "choices": [{
129
+ "index": 0,
130
+ "delta": {"content": content},
131
+ "finish_reason": None
132
+ }]
133
+ }
134
+
135
+ def _create_stream_end_chunk(self, request_id: str) -> Dict[str, Any]:
136
+ """创建流式响应结束块"""
137
+ return {
138
+ "id": f"chatcmpl-{request_id}",
139
+ "object": "chat.completion.chunk",
140
+ "created": int(time.time()),
141
+ "model": "gemini-2.0-flash",
142
+ "choices": [{
143
+ "index": 0,
144
+ "delta": {},
145
+ "finish_reason": "stop"
146
+ }]
147
+ }
148
+
149
+ def _create_complete_response(self, content: str, request_id: str) -> Dict[str, Any]:
150
+ """创建完整响应"""
151
+ return {
152
+ "id": f"chatcmpl-{request_id}",
153
+ "object": "chat.completion",
154
+ "created": int(time.time()),
155
+ "model": "gemini-2.0-flash",
156
+ "choices": [{
157
+ "index": 0,
158
+ "message": {
159
+ "role": "assistant",
160
+ "content": content
161
+ },
162
+ "finish_reason": "stop"
163
+ }],
164
+ "usage": {
165
+ "prompt_tokens": 0,
166
+ "completion_tokens": 0,
167
+ "total_tokens": 0
168
+ }
169
+ }
170
+
171
+ def _create_error_response(self, error_message: str, request_id: str) -> Dict[str, Any]:
172
+ """创建错误响应"""
173
+ return {
174
+ "error": {
175
+ "message": error_message,
176
+ "type": "api_error",
177
+ "code": "internal_error"
178
+ },
179
+ "id": request_id
180
+ }
181
+
182
+ def get_token_status(self) -> Dict[str, Any]:
183
+ """获取Token状态"""
184
+ try:
185
+ status = self.token_manager.get_token_status()
186
+ return {"success": True, **status}
187
+ except Exception as e:
188
+ logger.error(f"❌ 获取Token状态失败: {e}")
189
+ return {"success": False, "message": str(e)}
190
+
191
+ def add_tokens(self, tokens: List[str]) -> Dict[str, Any]:
192
+ """添加Token"""
193
+ try:
194
+ success = self.token_manager.add_refresh_tokens(tokens)
195
+ if success:
196
+ valid_tokens = [t for t in tokens if Utils.validate_refresh_token_format(t)]
197
+ return {
198
+ "success": True,
199
+ "message": "Token添加成功",
200
+ "added_tokens": len(valid_tokens)
201
+ }
202
+ else:
203
+ return {"success": False, "message": "没有有效的Token可添加"}
204
+ except Exception as e:
205
+ logger.error(f"❌ 添加Token失败: {e}")
206
+ return {"success": False, "message": str(e)}
207
+
208
+ def remove_refresh_token(self, refresh_token: str) -> Dict[str, Any]:
209
+ """删除refresh token"""
210
+ try:
211
+ success = self.token_manager.remove_refresh_token(refresh_token)
212
+ if success:
213
+ return {"success": True, "message": "Token删除成功"}
214
+ else:
215
+ return {"success": False, "message": "Token不存在"}
216
+ except Exception as e:
217
+ logger.error(f"❌ 删除Token失败: {e}")
218
+ return {"success": False, "message": str(e)}
219
+
220
+ def refresh_all_tokens(self) -> Dict[str, Any]:
221
+ """刷新所有Token"""
222
+ try:
223
+ self.token_manager.refresh_all_tokens()
224
+ return {"success": True, "message": "Token刷新已开始"}
225
+ except Exception as e:
226
+ logger.error(f"❌ 刷新Token失败: {e}")
227
+ return {"success": False, "message": str(e)}
228
+
229
+ def export_refresh_tokens(self, super_admin_key: str) -> Dict[str, Any]:
230
+ """导出refresh token内容(需要超级管理员密钥验证)"""
231
+ try:
232
+ # 验证超级管理员密钥
233
+ if Config.require_super_admin_auth():
234
+ if not super_admin_key or super_admin_key != Config.get_super_admin_key():
235
+ return {"success": False, "message": "超级管理员密钥验证失败"}
236
+
237
+ # 获取所有refresh token
238
+ with self.token_manager.token_lock:
239
+ refresh_tokens = list(self.token_manager.tokens.keys())
240
+
241
+ if not refresh_tokens:
242
+ return {"success": False, "message": "没有可导出的token"}
243
+
244
+ # 创建分号分割的token字符串
245
+ token_string = ";".join(refresh_tokens)
246
+
247
+ # 生成建议的文件名(带时间戳)
248
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
249
+ suggested_filename = f"refresh_tokens_export_{timestamp}.txt"
250
+
251
+ logger.info(f"🔒 超级管理员请求导出 {len(refresh_tokens)} 个refresh token")
252
+
253
+ return {
254
+ "success": True,
255
+ "message": f"准备导出 {len(refresh_tokens)} 个token",
256
+ "content": token_string,
257
+ "suggested_filename": suggested_filename,
258
+ "token_count": len(refresh_tokens)
259
+ }
260
+
261
+ except Exception as e:
262
+ logger.error(f"❌ 准备导出refresh token失败: {e}")
263
+ return {"success": False, "message": f"导出失败: {str(e)}"}
264
+
265
+ def batch_get_refresh_tokens(self, email_url_dict: Dict[str, str], max_workers: int = 5) -> Dict[str, Any]:
266
+ """批量获取refresh token并自动创建用户"""
267
+ try:
268
+ from login_client import LoginClient
269
+ login_client = LoginClient()
270
+
271
+ # 传递token_manager参数,这样在获取refresh_token后会立即尝试创建用户
272
+ results = login_client.batch_process_emails(email_url_dict, max_workers, self.token_manager)
273
+
274
+ # 提取有效的token并添加到管理器
275
+ valid_tokens = []
276
+ for email, result in results.items():
277
+ if result.get('refresh_token'):
278
+ valid_tokens.append(result['refresh_token'])
279
+
280
+ if valid_tokens:
281
+ self.token_manager.add_refresh_tokens(valid_tokens)
282
+ logger.info(f"✅ 批量获取并添加了 {len(valid_tokens)} 个有效token")
283
+
284
+ return {
285
+ 'success': True,
286
+ 'results': results,
287
+ 'total_count': len(email_url_dict),
288
+ 'success_count': len(valid_tokens)
289
+ }
290
+
291
+ except Exception as e:
292
+ logger.error(f"❌ 批量获取refresh token失败: {e}")
293
+ return {'success': False, 'message': str(e)}
294
+
295
+ def start_services(self):
296
+ """启动后台服务"""
297
+ try:
298
+ self.token_manager.start_auto_refresh()
299
+ logger.success("✅ 后台服务启动成功")
300
+ except Exception as e:
301
+ logger.error(f"❌ 启动后台服务失败: {e}")
302
+
303
+ def stop_services(self):
304
+ """停止后台服务"""
305
+ try:
306
+ self.token_manager.stop_auto_refresh()
307
+ logger.info("⏹️ 后台服务已停止")
308
+ except Exception as e:
309
+ logger.error(f"❌ 停止后台服务失败: {e}")