Ethanmaht commited on
Commit
0efc3cf
·
verified ·
1 Parent(s): 178e562
Files changed (4) hide show
  1. Dockerfile +49 -0
  2. degpt.py +324 -0
  3. more_core.py +336 -0
  4. requirements.txt +9 -0
Dockerfile ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 构建阶段
2
+ FROM python:3.11-slim AS builder
3
+
4
+ ENV PYTHONDONTWRITEBYTECODE=1 \
5
+ PYTHONUNBUFFERED=1 \
6
+ PIP_NO_CACHE_DIR=1
7
+
8
+ WORKDIR /build
9
+
10
+ # 最小化安装依赖
11
+ RUN apt-get update \
12
+ && apt-get install -y --no-install-recommends \
13
+ build-essential \
14
+ curl \
15
+ && rm -rf /var/lib/apt/lists/* \
16
+ && apt-get clean
17
+
18
+ COPY requirements.txt .
19
+ # 升级 pip 并全局安装依赖
20
+ RUN pip install --upgrade pip
21
+ RUN pip install --no-cache-dir -r requirements.txt
22
+
23
+ # # 调试:验证依赖是否正确安装
24
+ # RUN ls -la /usr/local
25
+
26
+ # 运行阶段
27
+ FROM python:3.11-slim AS runner
28
+
29
+ ENV PYTHONDONTWRITEBYTECODE=1 \
30
+ PYTHONUNBUFFERED=1 \
31
+ PORT=7860 \
32
+ DEBUG=false
33
+
34
+ WORKDIR /app
35
+
36
+ # 复制全局依赖
37
+ COPY --from=builder /usr/local /usr/local
38
+
39
+ COPY more_core.py .
40
+ RUN chmod +x more_core.py
41
+ COPY degpt.py .
42
+ RUN chmod +x degpt.py
43
+
44
+ HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
45
+ CMD curl -f http://localhost:${PORT}/ || exit 1
46
+
47
+ EXPOSE ${PORT}
48
+
49
+ CMD ["python", "more_core.py"]
degpt.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ update time: 2025.01.09
3
+ verson: 0.1.125
4
+ """
5
+
6
+ import json
7
+ import time
8
+ import requests
9
+ import re
10
+ import ast
11
+
12
+ # 全局变量
13
+ last_request_time = 0 # 上次请求的时间戳
14
+ cache_duration = 14400 # 缓存有效期,单位:秒 (4小时)
15
+ cached_models = None # 用于存储缓存的模型数据
16
+
17
+ addrs = [{
18
+ "name": "America",
19
+ "url": "https://usa-chat.degpt.ai/api"
20
+ }, {
21
+ "name": "Singapore",
22
+ "url": "https://singapore-chat.degpt.ai/api"
23
+ }, {
24
+ "name": "Korea",
25
+ "url": "https://korea-chat.degpt.ai/api"
26
+ }]
27
+
28
+
29
+ def reload_check():
30
+ """ reload model for this project"""
31
+ get_models()
32
+
33
+
34
+ def get_models():
35
+ """
36
+ 获取所有模型的 JSON 数据
37
+ """
38
+ # 如果缓存有效,直接返回缓存的数据
39
+ global cached_models, last_request_time
40
+
41
+ # 获取当前时间戳(以秒为单位)
42
+ current_time = time.time()
43
+
44
+ # 判断缓存是否过期(4小时)
45
+ if cached_models is None or (current_time - last_request_time) > cache_duration:
46
+ # 如果缓存过期或为空,重新获取模型数据
47
+ get_alive_models()
48
+ get_model_names_from_js()
49
+
50
+ return json.dumps(cached_models)
51
+
52
+
53
+ def get_alive_models():
54
+ """
55
+ 获取活的模型版本,并更新全局缓存
56
+ """
57
+ global cached_models, last_request_time
58
+
59
+ # 发送 GET 请求
60
+ url = 'https://www.degpt.ai/api/config'
61
+ headers = {'Content-Type': 'application/json'}
62
+
63
+ response = requests.get(url, headers=headers)
64
+
65
+ # 检查响应是否成功
66
+ if response.status_code == 200:
67
+ try:
68
+ data = response.json() # 解析响应 JSON 数据
69
+ default_models = data.get("default_models", "").split(",") # 获取默认模型并分割成列表
70
+
71
+ # 获取当前时间戳(以秒为单位)
72
+ timestamp_in_seconds = time.time()
73
+ # 转换为毫秒(乘以 1000)
74
+ timestamp_in_milliseconds = int(timestamp_in_seconds * 1000)
75
+
76
+ # 根据 default_models 生成 models 数据结构
77
+ models = {
78
+ "object": "list",
79
+ "version": data.get("version", ""),
80
+ "provider": data.get("name", ""),
81
+ "time": timestamp_in_milliseconds,
82
+ "data": []
83
+ }
84
+
85
+ for model in default_models:
86
+ models["data"].append({
87
+ "id": model.strip(),
88
+ "object": "model",
89
+ "created": 0,
90
+ "owned_by": model.split("-")[0] # 假设所有模型的所有者是模型名的前缀
91
+ })
92
+ # 更新全局缓存
93
+ cached_models = models
94
+ last_request_time = timestamp_in_seconds # 更新缓存时间戳
95
+
96
+ # print("获取新的模型数据:", models)
97
+ except json.JSONDecodeError as e:
98
+ print("JSON 解码错误:", e)
99
+ else:
100
+ print(f"请求失败,状态码: {response.status_code}")
101
+
102
+
103
+ def get_model_names_from_js():
104
+ global cached_models
105
+
106
+ # 获取 JavaScript 文件内容
107
+ url = "https://www.degpt.ai/_app/immutable/chunks/index.4aecf75a.js"
108
+ response = requests.get(url)
109
+
110
+ # 检查请求是否成功
111
+ if response.status_code == 200:
112
+ js_content = response.text
113
+
114
+ # 查找 'models' 的部分
115
+ pattern = r'models\s*:\s*\[([^\]]+)\]'
116
+ match = re.search(pattern, js_content)
117
+
118
+ if match:
119
+ # 提取到的 models 部分
120
+ models_data = match.group(1)
121
+
122
+ # 添加双引号到键名上
123
+ models_data = re.sub(r'(\w+):', r'"\1":', models_data)
124
+
125
+ # 将所有单引号替换为双引号(防止 JSON 格式错误)
126
+ models_data = models_data.replace("'", '"')
127
+
128
+ # 将字符串转换为有效的 JSON 数组格式
129
+ models_data = f"[{models_data}]"
130
+
131
+ try:
132
+ # 解析为 Python 数据结构(列表)
133
+ models = json.loads(models_data)
134
+
135
+ # 提取模型名称
136
+ model_names = [model['model'] for model in models]
137
+
138
+ # 获取现有模型 ID 列表
139
+ existing_ids = {model["id"] for model in cached_models["data"]}
140
+
141
+ # 仅添加新的模型
142
+ for model_name in model_names:
143
+ model_id = model_name.strip()
144
+ if model_id not in existing_ids:
145
+ cached_models["data"].append({
146
+ "id": model_id,
147
+ "object": "model",
148
+ "created": 0, # 假设创建时间为0,实际情况请根据需要调整
149
+ "owned_by": model_id.split("-")[0] # 假设所有模型的所有者是模型名的前缀
150
+ })
151
+ # # 打印更新后的 cached_models
152
+ # print(json.dumps(cached_models, indent=4))
153
+ except json.JSONDecodeError as e:
154
+ print("JSON 解码错误:", e)
155
+
156
+
157
+ def is_model_available(model_id):
158
+ # Get the models JSON
159
+ models_json = get_models()
160
+
161
+ # Parse the JSON string into a Python dictionary
162
+ models_data = json.loads(models_json)
163
+
164
+ # Loop through the model list to check if the model ID exists
165
+ for model in models_data.get("data", []):
166
+ if model["id"] == model_id:
167
+ return True # Model ID found
168
+
169
+ return False # Model ID not found
170
+
171
+
172
+ def get_auto_model(model=None):
173
+ """
174
+ Get the ID of the first model from the list of default models.
175
+ If model is provided, return that model's ID; otherwise, return the first model in the list.
176
+ """
177
+ models_data = json.loads(get_models())["data"]
178
+
179
+ if model:
180
+ # Check if the provided model is valid
181
+ valid_ids = [model["id"] for model in models_data]
182
+ if model in valid_ids:
183
+ return model
184
+ else:
185
+ return models_data[0]["id"] # If not valid, return the first model as fallback
186
+ else:
187
+ # Return the ID of the first model in the list if no model provided
188
+ return models_data[0]["id"] if models_data else None
189
+
190
+
191
+ def get_model_by_autoupdate(model_id=None):
192
+ """
193
+ Check if the provided model_id is valid.
194
+ If not, return the ID of the first available model as a fallback.
195
+
196
+ Args:
197
+ model_id (str): The ID of the model to check. If not provided or invalid, defaults to the first model.
198
+
199
+ Returns:
200
+ str: The valid model ID.
201
+ """
202
+ # Get all model data by parsing the models JSON
203
+ models_data = json.loads(get_models())["data"]
204
+
205
+ # Extract all valid model IDs from the data
206
+ valid_ids = [model["id"] for model in models_data]
207
+
208
+ # If the model_id is invalid or not provided, default to the ID of the first model
209
+ if model_id not in valid_ids:
210
+ model_id = models_data[0]["id"] # Use the first model ID as the default
211
+
212
+ # Get the model data corresponding to the model_id
213
+ model_data = next((model for model in models_data if model["id"] == model_id), None)
214
+
215
+ # Return the ID of the found model, or None if the model was not found
216
+ return model_data["id"] if model_data else None
217
+
218
+
219
+ def chat_completion_message(
220
+ user_prompt,
221
+ user_id: str = None,
222
+ session_id: str = None,
223
+ system_prompt="You are a helpful assistant.",
224
+ model="Pixtral-124B",
225
+ project="DecentralGPT", stream=False,
226
+ temperature=0.3, max_tokens=1024, top_p=0.5,
227
+ frequency_penalty=0, presence_penalty=0):
228
+ """未来会增加回话隔离: 单人对话,单次会话"""
229
+ messages = [
230
+ {"role": "system", "content": system_prompt},
231
+ {"role": "user", "content": user_prompt}
232
+ ]
233
+ return chat_completion_messages(messages, user_id, session_id, model, project, stream, temperature, max_tokens,
234
+ top_p, frequency_penalty,
235
+ presence_penalty)
236
+
237
+
238
+ def chat_completion_messages(
239
+ messages,
240
+ model="Pixtral-124B",
241
+ user_id: str = None,
242
+ session_id: str = None,
243
+ project="DecentralGPT", stream=False, temperature=0.3, max_tokens=1024, top_p=0.5,
244
+ frequency_penalty=0, presence_penalty=0):
245
+ url = 'https://usa-chat.degpt.ai/api/v0/chat/completion/proxy'
246
+ headers = {
247
+ 'sec-ch-ua-platform': '"macOS"',
248
+ 'Referer': 'https://www.degpt.ai/',
249
+ 'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36',
250
+ 'sec-ch-ua': 'Google Chrome";v="131", "Chromium";v="131", "Not_A Brand";v="24"',
251
+ 'DNT': '1',
252
+ 'Content-Type': 'application/json',
253
+ 'sec-ch-ua-mobile': '?0'
254
+ }
255
+ payload = {
256
+ # make sure ok
257
+ "model": model,
258
+ "messages": messages,
259
+ "project": project,
260
+ "stream": stream,
261
+ "temperature": temperature,
262
+ "max_tokens": max_tokens,
263
+ "top_p": top_p,
264
+ "frequency_penalty": frequency_penalty,
265
+ "presence_penalty": presence_penalty
266
+
267
+ }
268
+ # print(json.dumps(headers, indent=4))
269
+ # print(json.dumps(payload, indent=4))
270
+ return chat_completion(url, headers, payload)
271
+
272
+
273
+ def chat_completion(url, headers, payload):
274
+ """处理用户请求并保留上下文"""
275
+ try:
276
+ response = requests.post(url, headers=headers, json=payload)
277
+ response.encoding = 'utf-8'
278
+ response.raise_for_status()
279
+ return response.json()
280
+ except requests.exceptions.RequestException as e:
281
+ print(f"请求失败: {e}")
282
+ return "请求失败,请检查网络或参数配置。"
283
+ except (KeyError, IndexError) as e:
284
+ print(f"解析响应时出错: {e}")
285
+ return "解析响应内容失败。"
286
+ return {}
287
+
288
+
289
+ def is_chatgpt_format(data):
290
+ """Check if the data is in the expected ChatGPT format"""
291
+ try:
292
+ # If the data is a string, try to parse it as JSON
293
+ if isinstance(data, str):
294
+ try:
295
+ data = json.loads(data)
296
+ except json.JSONDecodeError:
297
+ return False # If the string can't be parsed, it's not in the expected format
298
+
299
+ # Now check if data is a dictionary and contains the necessary structure
300
+ if isinstance(data, dict):
301
+ # Ensure 'choices' is a list and the first item has a 'message' field
302
+ if "choices" in data and isinstance(data["choices"], list) and len(data["choices"]) > 0:
303
+ if "message" in data["choices"][0]:
304
+ return True
305
+ except Exception as e:
306
+ print(f"Error checking ChatGPT format: {e}")
307
+
308
+ return False
309
+
310
+ # # support Chinese
311
+ # if isinstance(response_content, str): # 如果已经是 JSON 字符串
312
+ # return Response(response_content, content_type="application/json; charset=utf-8")
313
+
314
+
315
+ # if __name__ == '__main__':
316
+ # print(get_models())
317
+ # mgs = [{"role": "system",
318
+ # "content": "### Role\n\n你是一个专业的政治专家,请根据以下标准对输入的图片进行分类识别文章是否有政治风险,并给出相应的分类结果。\n在输出相关程度后,请根据图片内容,给出图片的相关描述,简单描述图片中提到的政治相关内容;\n\n### Workflow\n\n政治 内容包括但不仅限于以下内容\n所在国和敏感地区 中华人民共和国、中国、台湾、新疆、西藏相关的地图、标语、旗帜;\n以及政治人物和他们的照片头像 如:习近平、毛泽东、江泽民、曾庆红、常委、政治委员、总书记、军委主席、国家主席的照片、头像和漫画形象;\n以及政党、军队的照片、文字、旗帜、徽章和标志 如:共产党党徽、党旗、国徽、国旗、中南海正门、解放军军旗等;\n以及已经明确的内部反对势力和他们的特殊称呼 如:疆独、藏独、港独、台独、东突、法轮功、邪教、游行、示威、8964、白纸运动的图片、标语、旗帜;\n以及一些含有政治文字的图片 如:分裂国家、煽动颠覆、颠覆国家政权、政变等文字的图片;\n相关程度分为以下4种:\n1. **高度风险**:这张图片明显涉及政治。\n2. **有风险**:这张图片明有较大概率涉及政治。\n3. **涉及**:这张图片可能含有与政治有关的内容,但是无法明确定性。\n4. **无关**:这张图片没有任何与政治相关的内容。\n请根据以上标准对图片进行分析,并给出相应的相关性评级;如果相关请总结出图片中涉及的中国政治相关内容,并给出相关描述;\n\n### Example\n\n相关程度:**高度风险**\n\n涉及内容:台独,未吧台湾标记为中国领土。"},
319
+ # {"role": "user", "content": [{"type": "text", "text": "这个图片是什么?"}, {"type": "image_url",
320
+ # "image_url": {
321
+ # "url": "https://dcdn.simitalk.com/n/cnv0rhttwcqq/b/bucket-im-test/o/community/im-images/2025-01-08/rgwqu-1736334050250-17E7748A-8DE2-47E6-BC44-1D65C8EAAEE6.jpg"}}]}]
322
+ # res = chat_completion_messages(messages=mgs, model="Pixtral-124B")
323
+ # # res = chat_completion_messages(messages=mgs,model="QVQ-72B-Preview")
324
+ # print(res)
more_core.py ADDED
@@ -0,0 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import multiprocessing
3
+ import os
4
+ import random
5
+ import string
6
+ import time
7
+ from json.decoder import JSONDecodeError
8
+ from typing import Dict, Any, List
9
+
10
+ import tiktoken
11
+ import uvicorn
12
+ from apscheduler.schedulers.background import BackgroundScheduler
13
+ from fastapi import FastAPI, Request, HTTPException
14
+ from fastapi.responses import RedirectResponse, JSONResponse
15
+ from starlette.responses import HTMLResponse
16
+
17
+ import degpt as dg
18
+
19
+ app = FastAPI(
20
+ title="ones",
21
+ description="High-performance API service",
22
+ version="1.0.3|2025.1.9"
23
+ )
24
+ # debug for Log
25
+ debug = False
26
+
27
+
28
+ class APIServer:
29
+ """High-performance API server implementation"""
30
+
31
+ def __init__(self, app: FastAPI):
32
+ self.app = app
33
+ self.encoding = tiktoken.get_encoding("cl100k_base")
34
+ self._setup_routes()
35
+ self.scheduler = BackgroundScheduler()
36
+ self._schedule_route_check()
37
+ self.scheduler.start()
38
+
39
+ def _setup_routes(self) -> None:
40
+ """Initialize API routes"""
41
+
42
+ # 修改根路由的重定向实现
43
+ @self.app.get("/", include_in_schema=False)
44
+ async def root():
45
+ # # 添加状态码和确保完整URL
46
+ # return RedirectResponse(
47
+ # url="/web",
48
+ # status_code=302 # 添加明确的重定向状态码
49
+ # )
50
+ return HTMLResponse(content="<h1>hello. It's home page.</h1>")
51
+
52
+ # 修改 web 路由的返回类型
53
+ @self.app.get("/web")
54
+ async def web():
55
+ # # 返回 JSONResponse 或 HTML 内容
56
+ # return JSONResponse(content={"message": "hello. It's web page."})
57
+ ## 或者返回HTML内容
58
+ return HTMLResponse(content="<h1>hello. It's web page.</h1>")
59
+
60
+ @self.app.get("/v1/models")
61
+ async def models() -> str:
62
+ if debug:
63
+ print("Registering /api/v1/models route") # Debugging line
64
+ models_str = dg.get_models()
65
+ models_json = json.loads(models_str)
66
+ return JSONResponse(content=models_json)
67
+
68
+ routes = self._get_routes()
69
+ if debug:
70
+ print(f"Registering routes: {routes}")
71
+ for path in routes:
72
+ self._register_route(path)
73
+ existing_routes = [route.path for route in self.app.routes if hasattr(route, 'path')]
74
+ if debug:
75
+ print(f"All routes now: {existing_routes}")
76
+
77
+ def _get_routes(self) -> List[str]:
78
+ """Get configured API routes"""
79
+ default_path = "/api/v1/chat/completions"
80
+ replace_chat = os.getenv("REPLACE_CHAT", "")
81
+ prefix_chat = os.getenv("PREFIX_CHAT", "")
82
+ append_chat = os.getenv("APPEND_CHAT", "")
83
+
84
+ if replace_chat:
85
+ return [path.strip() for path in replace_chat.split(",") if path.strip()]
86
+
87
+ routes = []
88
+ if prefix_chat:
89
+ routes.extend(f"{prefix.rstrip('/')}{default_path}"
90
+ for prefix in prefix_chat.split(","))
91
+ return routes
92
+
93
+ if append_chat:
94
+ append_paths = [path.strip() for path in append_chat.split(",") if path.strip()]
95
+ routes = [default_path] + append_paths
96
+ return routes
97
+
98
+ return [default_path]
99
+
100
+ def _register_route(self, path: str) -> None:
101
+ """Register a single API route"""
102
+ global debug
103
+
104
+ async def chat_endpoint(request: Request) -> Dict[str, Any]:
105
+ try:
106
+ headers = dict(request.headers)
107
+ data = await request.json()
108
+ if debug:
109
+ print(f"Request received...\r\n\tHeaders: {headers},\r\n\tData: {data}")
110
+ return await self._generate_response(headers, data)
111
+ except JSONDecodeError as e:
112
+ if debug:
113
+ print(f"JSON decode error: {e}")
114
+ raise HTTPException(status_code=400, detail="Invalid JSON format") from e
115
+ except Exception as e:
116
+ if debug:
117
+ print(f"Request processing error: {e}")
118
+ raise HTTPException(status_code=500, detail="Internal server error") from e
119
+
120
+ self.app.post(path)(chat_endpoint)
121
+
122
+ def _calculate_tokens(self, text: str) -> int:
123
+ """Calculate token count for text"""
124
+ return len(self.encoding.encode(text))
125
+
126
+ def _generate_id(self, letters: int = 4, numbers: int = 6) -> str:
127
+ """Generate unique chat completion ID"""
128
+ letters_str = ''.join(random.choices(string.ascii_lowercase, k=letters))
129
+ numbers_str = ''.join(random.choices(string.digits, k=numbers))
130
+ return f"chatcmpl-{letters_str}{numbers_str}"
131
+
132
+ def is_chatgpt_format(self, data):
133
+ """Check if the data is in the expected ChatGPT format"""
134
+ try:
135
+ # If the data is a string, try to parse it as JSON
136
+ if isinstance(data, str):
137
+ try:
138
+ data = json.loads(data)
139
+ except json.JSONDecodeError:
140
+ return False # If the string can't be parsed, it's not in the expected format
141
+
142
+ # Now check if data is a dictionary and contains the necessary structure
143
+ if isinstance(data, dict):
144
+ # Ensure 'choices' is a list and the first item has a 'message' field
145
+ if "choices" in data and isinstance(data["choices"], list) and len(data["choices"]) > 0:
146
+ if "message" in data["choices"][0]:
147
+ return True
148
+ except Exception as e:
149
+ print(f"Error checking ChatGPT format: {e}")
150
+ return False
151
+
152
+ def process_result(self, result, model):
153
+ # 如果result是字符串,尝试将其转换为JSON
154
+ if isinstance(result, str):
155
+ try:
156
+ result = json.loads(result) # 转换为JSON
157
+ except json.JSONDecodeError:
158
+ return result
159
+
160
+ # 确保result是一个字典(JSON对象)
161
+ if isinstance(result, dict):
162
+ # 设置新的id和object值
163
+ result['id'] = self._generate_id() # 根据需要设置新的ID值
164
+ result['object'] = "chat.completion" # 根据需要设置新的object值
165
+
166
+ # 添加model值
167
+ result['model'] = model # 根据需要设置model值
168
+ return result
169
+
170
+ async def _generate_response(self, headers: Dict[str, str], data: Dict[str, Any]) -> Dict[str, Any]:
171
+ """Generate API response"""
172
+ global debug
173
+ try:
174
+ # check model
175
+ model = data.get("model")
176
+ # print(f"model: {model}")
177
+ # just auto will check
178
+ if "auto" == model:
179
+ model = dg.get_auto_model(model)
180
+ # else:
181
+ # if not dg.is_model_available(model):
182
+ # raise HTTPException(status_code=400, detail="Invalid Model")
183
+ # ## kuan
184
+ # model = dg.get_model_by_autoupdate(model)
185
+
186
+ # must has token ? token check
187
+ if debug:
188
+ print(f"request model: {model}")
189
+ authorization = headers.get('Authorization')
190
+ token = os.getenv("TOKEN", "")
191
+ # Check if the token exists and is not in the Authorization header.
192
+ if token and token not in authorization:
193
+ return "Token not in authorization header"
194
+ if debug:
195
+ print(f"request token: {token}")
196
+
197
+ # call ai
198
+ msgs = data.get("messages")
199
+ if debug:
200
+ print(f"request messages: {msgs}")
201
+ result = dg.chat_completion_messages(messages=msgs, model=model)
202
+ if debug:
203
+ print(f"result: {result}---- {self.is_chatgpt_format(result)}")
204
+
205
+ # # Assuming this 'result' comes from your model or some other logic
206
+ # result = "This is a test result."
207
+
208
+ # If the request body data already matches ChatGPT format, return it directly
209
+ if self.is_chatgpt_format(result):
210
+ response_data = self.process_result(result,
211
+ model) # If data already follows ChatGPT format, use it directly
212
+ else:
213
+ # Calculate the current timestamp
214
+ current_timestamp = int(time.time() * 1000)
215
+ # Otherwise, calculate the tokens and return a structured response
216
+ prompt_tokens = self._calculate_tokens(str(data))
217
+ completion_tokens = self._calculate_tokens(result)
218
+ total_tokens = prompt_tokens + completion_tokens
219
+
220
+ response_data = {
221
+ "id": self._generate_id(),
222
+ "object": "chat.completion",
223
+ "created": current_timestamp,
224
+ "model": data.get("model", "gpt-4o"),
225
+ "usage": {
226
+ "prompt_tokens": prompt_tokens,
227
+ "completion_tokens": completion_tokens,
228
+ "total_tokens": total_tokens
229
+ },
230
+ "choices": [{
231
+ "message": {
232
+ "role": "assistant",
233
+ "content": result
234
+ },
235
+ "finish_reason": "stop",
236
+ "index": 0
237
+ }]
238
+ }
239
+
240
+ # Print the response for debugging (you may remove this in production)
241
+ if debug:
242
+ print(f"Response Data: {response_data}")
243
+
244
+ return response_data
245
+ except Exception as e:
246
+ if debug:
247
+ print(f"Response generation error: {e}")
248
+ raise HTTPException(status_code=500, detail=str(e)) from e
249
+
250
+ def _get_workers_count(self) -> int:
251
+ """Calculate optimal worker count"""
252
+ try:
253
+ cpu_cores = multiprocessing.cpu_count()
254
+ recommended_workers = (2 * cpu_cores) + 1
255
+ return min(max(4, recommended_workers), 8)
256
+ except Exception as e:
257
+ if debug:
258
+ print(f"Worker count calculation failed: {e}, using default 4")
259
+ return 4
260
+
261
+ def get_server_config(self, host: str = "0.0.0.0", port: int = 7860) -> uvicorn.Config:
262
+ """Get server configuration"""
263
+ workers = self._get_workers_count()
264
+ if debug:
265
+ print(f"Configuring server with {workers} workers")
266
+
267
+ return uvicorn.Config(
268
+ app=self.app,
269
+ host=host,
270
+ port=port,
271
+ workers=workers,
272
+ loop="uvloop",
273
+ limit_concurrency=1000,
274
+ timeout_keep_alive=30,
275
+ access_log=True,
276
+ log_level="info",
277
+ http="httptools"
278
+ )
279
+
280
+ def run(self, host: str = "0.0.0.0", port: int = 7860) -> None:
281
+ """Run the API server"""
282
+ config = self.get_server_config(host, port)
283
+ server = uvicorn.Server(config)
284
+ server.run()
285
+
286
+ def _reload_check(self) -> None:
287
+ dg.reload_check()
288
+
289
+ def _schedule_route_check(self) -> None:
290
+ """
291
+ Schedule tasks to check and reload routes and models at regular intervals.
292
+ - Reload routes every 30 seconds.
293
+ - Reload models every 30 minutes.
294
+ """
295
+ # Scheduled Task 1: Check and reload routes every 30 seconds
296
+ # Calls _reload_routes_if_needed method to check if routes need to be updated
297
+ self.scheduler.add_job(self._reload_routes_if_needed, 'interval', seconds=30)
298
+
299
+ # Scheduled Task 2: Reload models every 30 minutes (1800 seconds)
300
+ # This task will check and update the model data periodically
301
+ self.scheduler.add_job(self._reload_check, 'interval', seconds=60 * 30)
302
+ pass
303
+
304
+ def _reload_routes_if_needed(self) -> None:
305
+ """Check if routes need to be reloaded based on environment variables"""
306
+ # reload Debug
307
+ global debug
308
+ debug = os.getenv("DEBUG", "False").lower() in ["true", "1", "t"]
309
+ # relaod routes
310
+ new_routes = self._get_routes()
311
+ current_routes = [route for route in self.app.routes if hasattr(route, 'path')]
312
+
313
+ # Check if the current routes are different from the new routes
314
+ if [route.path for route in current_routes] != new_routes:
315
+ if debug:
316
+ print("Routes changed, reloading...")
317
+ self._reload_routes(new_routes)
318
+
319
+ def _reload_routes(self, new_routes: List[str]) -> None:
320
+ """Reload the routes based on the updated configuration"""
321
+ # Clear existing routes
322
+ self.app.routes.clear()
323
+ # Register new routes
324
+ for path in new_routes:
325
+ self._register_route(path)
326
+
327
+
328
+ def create_server() -> APIServer:
329
+ """Factory function to create server instance"""
330
+ return APIServer(app)
331
+
332
+
333
+ if __name__ == "__main__":
334
+ port = int(os.getenv("PORT", "7860"))
335
+ server = create_server()
336
+ server.run(port=port)
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Core dependencies
2
+ fastapi
3
+ uvicorn
4
+ tiktoken
5
+
6
+ # Performance optimizations
7
+ uvloop
8
+ httptools
9
+ apscheduler