99i committed on
Commit
65c1969
·
verified ·
1 Parent(s): f03905f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +130 -176
app.py CHANGED
@@ -1,7 +1,5 @@
1
- from dataclasses import dataclass
2
- from typing import Any, Dict, Optional
3
- from fastapi import FastAPI, Request, HTTPException, status
4
- from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse
5
  from fastapi.middleware.cors import CORSMiddleware
6
  import httpx
7
  import json
@@ -9,16 +7,22 @@ import os
9
  import random
10
  from datetime import datetime
11
  import pytz
12
- import logging
13
 
14
- # 配置日志
15
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
 
 
16
 
17
- @dataclass
18
- class ApiConfig:
19
- """API 配置数据类"""
20
- base_url: str = os.environ.get("BASE_URL", "https://api.siliconflow.cn/v1")
21
- model_map: Dict[str, str] = {
 
 
 
22
  "qwen-72b": "Qwen/Qwen2.5-72B-Instruct",
23
  "qwen-32b": "Qwen/Qwen2.5-32B-Instruct",
24
  "qwen-14b": "Qwen/Qwen2.5-14B-Instruct",
@@ -33,25 +37,31 @@ class ApiConfig:
33
  "bce": "netease-youdao/bce-embedding-base_v1",
34
  "bge-m3": "BAAI/bge-m3",
35
  "bge-zh": "BAAI/bge-large-zh-v1.5",
36
- "sd":"stabilityai/stable-diffusion-3-5-large",
37
- "sd-turbo":"stabilityai/stable-diffusion-3-5-large-turbo",
38
- "flux-s":"black-forest-labs/FLUX.1-schnell",
39
- "flux-d":"black-forest-labs/FLUX.1-dev"
40
  }
41
- api_key_str: Optional[str] = os.environ.get("SI_KEY")
42
- timezone: str = "Asia/Shanghai"
43
-
44
- def __post_init__(self):
45
- if not self.api_key_str :
46
- logging.error("SI_KEY not found in env")
47
- raise EnvironmentError("SI_KEY not found in env")
48
- self.api_keys:list[str] = self.api_key_str.split(",")
49
- if os.environ.get("MODEL_MAP"):
50
- self.model_map = json.loads(os.environ.get("MODEL_MAP"))
51
 
52
- config = ApiConfig()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
 
54
  app = FastAPI()
 
55
  app.add_middleware(
56
  CORSMiddleware,
57
  allow_origins=["*"],
@@ -60,171 +70,115 @@ app.add_middleware(
60
  allow_headers=["*"],
61
  )
62
 
63
- # API Key 管理类
64
- class ApiKeyManager:
65
- def __init__(self, keys:list[str]):
66
- self.keys = keys
67
- self.key_balance :Dict[str,float] = {}
68
- self.key_balance_notes = ""
69
- self.last_update_time = ""
70
-
71
- def get_key(self) -> str:
72
- """随机获取一个可用 API Key"""
73
- if not self.keys:
74
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="No available API keys")
75
- random.shuffle(self.keys)
76
- return self.keys[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
- def remove_key(self,key:str):
79
- """移除不可用 API Key"""
80
- if key in self.keys:
81
- self.keys.remove(key)
82
- else:
83
- logging.warning("try remove a not exists key from key_pool")
84
-
85
- async def check_keys_balance(self):
86
- """检查所有 API Key 的余额,并移除余额不足的 API Key"""
87
- self.key_balance_notes = ""
88
- key_to_remove = []
89
-
90
- for key in self.keys:
91
- try:
92
- balance = await self._fetch_balance(key)
93
- if balance < 0.1:
94
- key_to_remove.append(key)
95
- else:
96
- balance_info = f"<h2>{key.strip()[0:4]}****{key.strip()[-4:]}————{balance}</h2>"
97
- self.key_balance[key.strip()] = balance
98
- self.key_balance_notes += balance_info
99
- except HTTPException as e:
100
- logging.error(f"Key {key} check balance failed, detail:{e.detail}")
101
- key_to_remove.append(key)
102
- except Exception as e:
103
- logging.error(f"Key {key} check balance failed, unexcept error:{e}")
104
- key_to_remove.append(key)
105
- for remove_key in key_to_remove:
106
- self.remove_key(remove_key)
107
- self.last_update_time = datetime.now(pytz.timezone(config.timezone))
108
-
109
- async def _fetch_balance(self, key: str) -> float:
110
- """发送 API 请求,获取 API Key 的余额"""
111
- url = "https://api.siliconflow.cn/v1/user/info"
112
- headers = {"Authorization": f"Bearer {key.strip()}"}
113
  async with httpx.AsyncClient() as client:
114
- try:
115
- res = await client.get(url, headers=headers)
116
- res.raise_for_status()
117
- balance = res.json()["data"]["balance"]
118
- return float(balance)
119
- except httpx.HTTPError as exc:
120
- logging.error("httpx request error, detail:" + str(exc))
121
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Check balance failed with status:{exc.response.status_code},url:{exc.request.url}")
122
- key_manager = ApiKeyManager(config.api_keys)
123
-
124
-
125
- async def get_model_info() -> str:
126
- models = ""
127
- for key, value in config.model_map.items():
128
- models += f"<h2>{key}————{value}</h2>"
129
- return models
130
 
 
131
  @app.get("/", response_class=HTMLResponse)
132
- async def root() -> HTMLResponse:
133
- """根路由,返回 HTML 页面,展示模型信息和更新时间"""
134
- models_info = await get_model_info()
135
- return HTMLResponse(f"""
 
 
 
136
  <html>
137
  <head>
138
- <title>富文本示例</title>
139
  </head>
140
  <body>
141
- <h1>有效key数量:{len(key_manager.keys)}</h1>
142
  {models_info}
143
- <h1>最后更新时间:{key_manager.last_update_time}</h1>
144
- {key_manager.key_balance_notes}
145
  </body>
146
  </html>
147
- """)
148
 
149
 
150
  @app.get("/check")
151
- async def check() -> str:
152
- """手动触发检查 API Key 余额的路由"""
153
- await key_manager.check_keys_balance()
154
- return f"更新成功:{key_manager.last_update_time}"
155
-
156
- def is_chat_stream_request(request:Request) -> bool:
157
- headers = request.headers
158
- if headers.get("content-type") == 'application/json':
159
- try:
160
- request_body=json.loads(request.headers._list[4][1])
161
- return request_body.get('model') is not None and request_body.get("stream") is True
162
- except Exception as e:
163
- logging.error("parse request body error detail:" + str(e))
164
- return False
165
- return False
166
-
167
- async def _forward_request(request: Request, api_endpoint: str, is_stream:bool = False ) -> Any:
168
- """转发请求到硅流 API"""
169
- try:
170
- key = key_manager.get_key()
171
- except HTTPException as e:
172
- return e
173
- logging.info(f"using key {key[0:4]}***{key[-4:]} to {api_endpoint}")
174
- headers = {"Authorization": f"Bearer {key}"}
175
- url=f"{config.base_url}{api_endpoint}"
176
- if api_endpoint == "/embeddings":
177
- body=await request.json()
178
- if "model" in body and body["model"] in config.model_map:
179
- body["model"] = config.model_map[body["model"]]
180
- async with httpx.AsyncClient() as client:
181
- try:
182
- response = await client.post(url,headers=headers, json=body)
183
- response.raise_for_status()
184
- return response.json()
185
- except httpx.HTTPError as exc :
186
- logging.error(f"httpx request embedding error :{exc}")
187
- raise HTTPException(
188
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
189
- detail=f"Request failed with status: {exc.response.status_code}, url:{exc.request.url}, detail:{exc.response.text}"
190
- )
191
- elif is_stream:
192
- body = await request.json()
193
- async def generate_response():
194
- async with httpx.AsyncClient() as client:
195
- async with client.stream("POST",url, headers=headers, json=body) as response:
196
- response.raise_for_status() # 检查响应状态码
197
- async for chunk in response.aiter_bytes():
198
- if chunk:
199
- yield chunk
200
- return StreamingResponse(generate_response(), media_type="text/event-stream")
201
-
202
- else :
203
- body = await request.json()
204
- async with httpx.AsyncClient() as client:
205
- try:
206
- response = await client.post(url, headers=headers, json=body)
207
- response.raise_for_status()
208
- return response.json()
209
- except httpx.HTTPError as exc :
210
- logging.error(f"httpx request error:{exc}")
211
- raise HTTPException(
212
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
213
- detail=f"Request failed with status: {exc.response.status_code},url:{exc.request.url},detail:{exc.response.text}"
214
- )
215
  @app.post("/hf/v1/chat/completions")
216
- async def chat_completions(request: Request) -> Any:
217
- """转发 chat 完成请求"""
218
- return await _forward_request(request, "/chat/completions",is_chat_stream_request(request))
219
 
220
  @app.post("/hf/v1/embeddings")
221
- async def embeddings(request: Request)-> Any:
222
- """转发 embeddings 请求"""
223
- return await _forward_request(request, "/embeddings")
224
-
225
- @app.post("/hf/v1/images/generations")
226
- async def image_generations(request:Request) -> Any:
227
- """转发图片生成请求"""
228
- return await _forward_request(request, "/images/generations")
229
 
230
 
 
 
 
 
 
1
+ from fastapi import FastAPI, Request, HTTPException
2
+ from fastapi.responses import StreamingResponse, HTMLResponse
 
 
3
  from fastapi.middleware.cors import CORSMiddleware
4
  import httpx
5
  import json
 
7
  import random
8
  from datetime import datetime
9
  import pytz
10
+ from typing import Dict, List, Optional, Any, Callable
11
 
12
+ # --------------------------- 1. 环境变量和配置加载 ---------------------------
13
+ # 提取环境变量,并设置默认值(如果不存在)
14
# Raw comma-separated list of upstream API keys; entries are not trimmed here.
SI_KEYS = os.environ.get("SI_KEY", "").split(",")
# Optional JSON text that replaces the built-in alias -> model table.
MODEL_MAP_JSON = os.environ.get("MODEL_MAP")
# Base URL of the upstream API. BASE_URL is still honoured as a fallback
# because it was the variable name read before the rename to API_BASE_URL.
# The trailing slash is removed so that joining with "/path" never yields "//".
API_BASE_URL = (
    os.environ.get("API_BASE_URL")
    or os.environ.get("BASE_URL")
    or "https://api.siliconflow.cn/v1"
).rstrip("/")

# Type of the model table: short alias -> upstream model identifier.
ModelMap = Dict[str, str]
# Type of the balance table: API key -> balance.
KeyBalance = Dict[str, float]
22
+
23
+
24
+ # 加载模型映射,如果环境变量 MODEL_MAP 存在,则使用它,否则使用默认值
25
+ DEFAULT_MODEL_MAP: ModelMap = {
26
  "qwen-72b": "Qwen/Qwen2.5-72B-Instruct",
27
  "qwen-32b": "Qwen/Qwen2.5-32B-Instruct",
28
  "qwen-14b": "Qwen/Qwen2.5-14B-Instruct",
 
37
  "bce": "netease-youdao/bce-embedding-base_v1",
38
  "bge-m3": "BAAI/bge-m3",
39
  "bge-zh": "BAAI/bge-large-zh-v1.5",
40
+ "sd": "stabilityai/stable-diffusion-3-5-large",
41
+ "sd-turbo": "stabilityai/stable-diffusion-3-5-large-turbo",
42
+ "flux-s": "black-forest-labs/FLUX.1-schnell",
43
+ "flux-d": "black-forest-labs/FLUX.1-dev",
44
  }
 
 
 
 
 
 
 
 
 
 
45
 
46
# Use the MODEL_MAP override when it is set, otherwise the built-in table.
model_map: ModelMap
if MODEL_MAP_JSON:
    _model_map_override = json.loads(MODEL_MAP_JSON)
    # Fail at start-up rather than on the first request: the handlers index
    # this object by alias and call .items() on it, so it must be a mapping.
    if not isinstance(_model_map_override, dict):
        raise ValueError("MODEL_MAP must be a JSON object mapping alias to model name.")
    model_map = _model_map_override
else:
    model_map = DEFAULT_MODEL_MAP
47
+
48
+ # --------------------------- 2. 全局变量初始化 ---------------------------
49
# Usable API keys: every entry trimmed, blank entries discarded.
keys: List[str] = [trimmed for trimmed in map(str.strip, SI_KEYS) if trimmed]
# Balance per key and its pre-rendered HTML form; both start out empty.
key_balance: KeyBalance = {}
key_balance_notes: str = ""
# Time zone used when stamping the last refresh time.
tz = pytz.timezone("Asia/Shanghai")
last_updated_time: str = ""
54
+
55
+ # --------------------------- 3. 密钥选择策略 ---------------------------
56
def get_api_key() -> str:
    """Return one API key picked uniformly at random.

    Returns:
        An entry of the module-level ``keys`` list, or ``""`` when the list
        is empty so that callers can detect that no key is available.
    """
    # random.choice leaves ``keys`` untouched. Shuffling in place would
    # reorder the shared list while another coroutine (e.g. the balance
    # check) may be part-way through iterating it across an ``await``.
    return random.choice(keys) if keys else ""
60
+
61
 
62
+ # --------------------------- 4. FastAPI 应用初始化 ---------------------------
63
  app = FastAPI()
64
+
65
  app.add_middleware(
66
  CORSMiddleware,
67
  allow_origins=["*"],
 
70
  allow_headers=["*"],
71
  )
72
 
73
+
74
+ # --------------------------- 5. 辅助函数 ---------------------------
75
def format_key_balance_note(key: str, balance: float) -> str:
    """Render one key/balance pair as an ``<h2>`` HTML fragment.

    The key is shown as its first four characters, ``****``, then its last
    four characters. For keys shorter than eight characters the two slices
    overlap.
    """
    head, tail = key[:4], key[-4:]
    masked = head + "****" + tail
    return "<h2>{}————{}</h2>".format(masked, balance)
79
+
80
async def check_key(client: httpx.AsyncClient, key: str) -> Optional[float]:
    """Fetch the balance that the upstream account endpoint reports for ``key``.

    Args:
        client: Open ``httpx.AsyncClient`` used to send the request.
        key: API key sent as the bearer token.

    Returns:
        The balance as a float, or ``None`` when the request fails, the
        upstream answers with an error status, or the payload does not carry
        a numeric ``data.balance`` field.
    """
    url = f"{API_BASE_URL}/user/info"
    headers = {"Authorization": f"Bearer {key}"}
    try:
        res = await client.get(url, headers=headers)
        res.raise_for_status()  # 4xx/5xx answers raise httpx.HTTPStatusError
        return float(res.json()["data"]["balance"])
    except (httpx.HTTPError, KeyError, TypeError, ValueError) as e:
        # KeyError/TypeError/ValueError cover a non-JSON or unexpectedly shaped
        # payload; that should disqualify this key, not abort the whole check.
        # Only a masked form of the key is logged so the secret stays out of
        # the process output.
        print(f"Error checking key {key[0:4]}****{key[-4:]}: {e}")
        return None
92
+
93
async def forward_request(
    request: Request,
    url_path: str,
    is_stream: bool = False,
) -> Any:
    """Forward a JSON POST request to the upstream API.

    Args:
        request: Incoming request; its body must be a JSON object.
        url_path: Path appended to ``API_BASE_URL``, e.g. ``"/embeddings"``.
        is_stream: Whether the route allows a streamed reply. Streaming is
            only used when the body also carries a truthy ``"stream"`` field.

    Returns:
        A ``StreamingResponse`` relaying the upstream bytes when streaming,
        otherwise the decoded JSON body of the upstream response.

    Raises:
        HTTPException: 400 when the body is not a JSON object or no API key
            is available; the upstream status code when the upstream rejects
            a non-streaming request; 502 when the upstream cannot be reached.
    """
    try:
        body = await request.json()
    except ValueError as exc:  # json.JSONDecodeError is a ValueError
        raise HTTPException(status_code=400, detail="Request body must be valid JSON.") from exc
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object.")

    key = get_api_key()
    if not key:
        raise HTTPException(status_code=400, detail="No valid API key available.")
    headers = {"Authorization": f"Bearer {key}"}
    url = f"{API_BASE_URL}{url_path}"

    # Translate a short alias into the upstream model identifier.
    model = body.get("model")
    if isinstance(model, str) and model in model_map:
        body["model"] = model_map[model]

    # Stream only when the route allows it AND the client asked for it.
    if is_stream and body.get("stream"):
        async def generate_response():
            async with httpx.AsyncClient() as client:
                async with client.stream(
                    "POST", url, headers=headers, json=body
                ) as response:
                    # NOTE(review): this generator is consumed while the reply
                    # is being sent, so an upstream error raised here most
                    # likely aborts the stream instead of producing a clean
                    # error status - confirm against Starlette's behaviour.
                    response.raise_for_status()
                    async for chunk in response.aiter_bytes():
                        if chunk:
                            yield chunk
        return StreamingResponse(generate_response(), media_type="text/event-stream")

    async with httpx.AsyncClient() as client:
        try:
            response = await client.post(url, headers=headers, json=body)
            response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            # Pass the upstream status and message through to the caller.
            raise HTTPException(
                status_code=exc.response.status_code,
                detail=exc.response.text,
            ) from exc
        except httpx.HTTPError as exc:
            # Transport-level failure: no upstream response exists.
            raise HTTPException(
                status_code=502,
                detail=f"Upstream request failed: {exc}",
            ) from exc
        return response.json()
 
 
 
 
 
 
 
 
 
 
 
127
 
128
+ # --------------------------- 6. API 路由处理 ---------------------------
129
  @app.get("/", response_class=HTMLResponse)
130
async def root():
    """Return the HTML status page.

    The page shows the number of entries in ``keys``, one line per
    alias/model pair in ``model_map``, the value of ``last_updated_time``
    and the pre-rendered ``key_balance_notes`` fragment.
    """
    import html  # local import: only this handler needs it

    # The model table may be supplied through configuration, so its text is
    # escaped before being embedded in markup. join() replaces repeated
    # string concatenation in a loop.
    models_info = "".join(
        f"<h2>{html.escape(str(key))}————{html.escape(str(value))}</h2>"
        for key, value in model_map.items()
    )

    return f"""
    <html>
        <head>
            <title>API 状态</title>
        </head>
        <body>
            <h1>有效Key数量: {len(keys)}</h1>
            {models_info}
            <h1>最后更新时间:{last_updated_time}</h1>
            {key_balance_notes}
        </body>
    </html>
    """
149
 
150
 
151
  @app.get("/check")
152
async def check():
    """Re-check the balance of every key and publish the refreshed state.

    Keys whose balance cannot be fetched, or is below 0.1, are dropped from
    ``keys``. ``key_balance``, ``key_balance_notes`` and
    ``last_updated_time`` are replaced with the results of this run.

    Returns:
        A confirmation string containing the refresh time.
    """
    global key_balance, key_balance_notes, last_updated_time, keys
    usable_keys = []
    fresh_balance = {}
    notes = []
    # Work on a snapshot and collect results locally. The globals are shared
    # with other handlers that can run whenever this coroutine is suspended
    # at an ``await``, so they are replaced only after the full pass instead
    # of being emptied up front and refilled piecemeal.
    async with httpx.AsyncClient() as client:
        for key in list(keys):
            balance = await check_key(client, key)
            if balance is not None and balance >= 0.1:
                fresh_balance[key] = balance
                notes.append(format_key_balance_note(key, balance))
                usable_keys.append(key)
    keys = usable_keys
    key_balance = fresh_balance
    key_balance_notes = "".join(notes)
    last_updated_time = datetime.now(tz).strftime("%Y-%m-%d %H:%M:%S")
    return f"更新成功:{last_updated_time}"
168
+
169
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  @app.post("/hf/v1/chat/completions")
171
async def chat_completions(request: Request):
    """Proxy a chat-completion request; a streamed reply is permitted."""
    upstream_reply = await forward_request(
        request,
        "/chat/completions",
        is_stream=True,
    )
    return upstream_reply
174
 
175
  @app.post("/hf/v1/embeddings")
176
async def embeddings(request: Request):
    """Proxy an embeddings request to the upstream API."""
    upstream_reply = await forward_request(request, "/embeddings")
    return upstream_reply
 
 
 
 
 
179
 
180
 
181
+ @app.post("/hf/v1/images/generations")
182
async def images_generations(request: Request):
    """Proxy an image-generation request to the upstream API."""
    upstream_reply = await forward_request(request, "/images/generations")
    return upstream_reply