99i committed on
Commit
f03905f
·
verified ·
1 Parent(s): 87ab3df

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -89
app.py CHANGED
@@ -1,4 +1,6 @@
1
- from fastapi import FastAPI, Request, HTTPException
 
 
2
  from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse
3
  from fastapi.middleware.cors import CORSMiddleware
4
  import httpx
@@ -12,9 +14,11 @@ import logging
12
  # 配置日志
13
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
14
 
15
- # 从环境变量或者文件中加载配置
16
- BASE_URL = os.environ.get("BASE_URL", "https://api.siliconflow.cn/v1")
17
- MODEL_MAP = {
 
 
18
  "qwen-72b": "Qwen/Qwen2.5-72B-Instruct",
19
  "qwen-32b": "Qwen/Qwen2.5-32B-Instruct",
20
  "qwen-14b": "Qwen/Qwen2.5-14B-Instruct",
@@ -34,24 +38,20 @@ MODEL_MAP = {
34
  "flux-s":"black-forest-labs/FLUX.1-schnell",
35
  "flux-d":"black-forest-labs/FLUX.1-dev"
36
  }
37
- if os.environ.get("MODEL_MAP"):
38
- MODEL_MAP = json.loads(os.environ.get("MODEL_MAP"))
39
-
40
- KEY_STR = os.environ.get("SI_KEY")
41
- if KEY_STR is None:
42
- logging.error("SI_KEY not found in env")
43
- raise EnvironmentError("SI_KEY not found in env")
 
 
 
44
 
45
- KEYS = KEY_STR.split(",")
46
- KEY_BALANCE = {}
47
- KEY_BALANCE_NOTES = ""
48
- # 创建一个东八区的时区对象
49
- TIMEZONE = pytz.timezone("Asia/Shanghai")
50
- LAST_UPDATE_TIME = ""
51
 
52
  app = FastAPI()
53
-
54
-
55
  app.add_middleware(
56
  CORSMiddleware,
57
  allow_origins=["*"],
@@ -62,16 +62,16 @@ app.add_middleware(
62
 
63
  # API Key 管理类
64
  class ApiKeyManager:
65
- def __init__(self, keys):
66
  self.keys = keys
67
- self.key_balance = {}
68
  self.key_balance_notes = ""
69
  self.last_update_time = ""
70
 
71
  def get_key(self) -> str:
72
  """随机获取一个可用 API Key"""
73
  if not self.keys:
74
- raise HTTPException(status_code=500, detail="No available API keys")
75
  random.shuffle(self.keys)
76
  return self.keys[0]
77
 
@@ -104,7 +104,7 @@ class ApiKeyManager:
104
  key_to_remove.append(key)
105
  for remove_key in key_to_remove:
106
  self.remove_key(remove_key)
107
- self.last_update_time = datetime.now(TIMEZONE)
108
 
109
  async def _fetch_balance(self, key: str) -> float:
110
  """发送 API 请求,获取 API Key 的余额"""
@@ -118,19 +118,21 @@ class ApiKeyManager:
118
  return float(balance)
119
  except httpx.HTTPError as exc:
120
  logging.error("httpx request error, detail:" + str(exc))
121
- raise HTTPException(status_code=500, detail=f"Check balance failed with status:{exc.response.status_code},url:{exc.request.url}")
122
- key_manager = ApiKeyManager(KEYS)
123
- async def get_model_info():
 
 
124
  models = ""
125
- for key, value in MODEL_MAP.items():
126
  models += f"<h2>{key}————{value}</h2>"
127
  return models
128
 
129
  @app.get("/", response_class=HTMLResponse)
130
- async def root():
131
  """根路由,返回 HTML 页面,展示模型信息和更新时间"""
132
  models_info = await get_model_info()
133
- return f"""
134
  <html>
135
  <head>
136
  <title>富文本示例</title>
@@ -142,98 +144,87 @@ async def root():
142
  {key_manager.key_balance_notes}
143
  </body>
144
  </html>
145
- """
146
 
147
 
148
  @app.get("/check")
149
- async def check():
150
  """手动触发检查 API Key 余额的路由"""
151
  await key_manager.check_keys_balance()
152
  return f"更新成功:{key_manager.last_update_time}"
153
 
154
- async def _forward_request(request: Request, api_type: str, image_generations:bool = False):
 
 
 
 
 
 
 
 
 
 
 
155
  """转发请求到硅流 API"""
156
-
157
  try:
158
  key = key_manager.get_key()
159
  except HTTPException as e:
160
  return e
161
- logging.info(f"using key {key[0:4]}***{key[-4:]} to {api_type}")
162
  headers = {"Authorization": f"Bearer {key}"}
163
- if api_type == "chat" and not image_generations :
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  body = await request.json()
165
- # 处理模型映射
166
- if "model" in body and body["model"] in MODEL_MAP:
167
- body["model"] = MODEL_MAP[body["model"]]
168
-
169
- if api_type == "chat" and "stream" in body and body["stream"]:
170
- async def generate_response():
171
  async with httpx.AsyncClient() as client:
172
- async with client.stream("POST", f"{BASE_URL}/chat/completions" , headers=headers, json=body) as response:
173
  response.raise_for_status() # 检查响应状态码
174
  async for chunk in response.aiter_bytes():
175
  if chunk:
176
  yield chunk
177
- return StreamingResponse(generate_response(), media_type="text/event-stream")
178
- else :
179
- async with httpx.AsyncClient() as client:
 
 
180
  try:
181
- response = await client.post(f"{BASE_URL}/chat/completions", headers=headers, json=body)
182
  response.raise_for_status()
183
  return response.json()
184
  except httpx.HTTPError as exc :
185
- logging.error("httpx request error:" + str(exc))
186
  raise HTTPException(
187
- status_code=500,
188
  detail=f"Request failed with status: {exc.response.status_code},url:{exc.request.url},detail:{exc.response.text}"
189
  )
190
- elif image_generations:
191
- url =f"{BASE_URL}/images/generations"
192
- body = await request.json()
193
- headers["Content-Type"] = "application/json"
194
- async with httpx.AsyncClient() as client:
195
-
196
- try:
197
- response = await client.post(url, headers=headers,json=body)
198
- response.raise_for_status()
199
- return JSONResponse(content=response.json(), status_code=response.status_code)
200
- except httpx.HTTPError as exc :
201
- logging.error("httpx request error:" + str(exc))
202
- raise HTTPException(
203
- status_code=500,
204
- detail=f"Request failed with status: {exc.response.status_code}, url:{exc.request.url} ,detail:{exc.response.text}"
205
- )
206
- elif api_type =="embedding":
207
- url=f"{BASE_URL}/embeddings"
208
- body=await request.json()
209
- if "model" in body and body["model"] in MODEL_MAP:
210
- body["model"] = MODEL_MAP[body["model"]]
211
- async with httpx.AsyncClient() as client:
212
- try:
213
- response = await client.post(url, json=body,headers=headers)
214
- response.raise_for_status()
215
- return response.json()
216
- except httpx.HTTPError as exc :
217
- logging.error(f"httpx request embedding error :{exc}")
218
- raise HTTPException(
219
- status_code=500,
220
- detail=f"Request failed with status: {exc.response.status_code}, url:{exc.request.url}, detail:{exc.response.text}"
221
- )
222
-
223
-
224
-
225
  @app.post("/hf/v1/chat/completions")
226
- async def chat_completions(request: Request):
227
  """转发 chat 完成请求"""
228
- return await _forward_request(request, "chat")
229
 
230
  @app.post("/hf/v1/embeddings")
231
- async def embeddings(request: Request):
232
  """转发 embeddings 请求"""
233
- return await _forward_request(request, "embedding")
234
 
235
  @app.post("/hf/v1/images/generations")
236
- async def image_generations(request:Request):
237
  """转发图片生成请求"""
238
- return await _forward_request(request, api_type="image",image_generations=True)
 
239
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Any, Dict, Optional
3
+ from fastapi import FastAPI, Request, HTTPException, status
4
  from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse
5
  from fastapi.middleware.cors import CORSMiddleware
6
  import httpx
 
14
  # 配置日志
15
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
16
 
17
+ @dataclass
18
+ class ApiConfig:
19
+ """API 配置数据类"""
20
+ base_url: str = os.environ.get("BASE_URL", "https://api.siliconflow.cn/v1")
21
+ model_map: Dict[str, str] = {
22
  "qwen-72b": "Qwen/Qwen2.5-72B-Instruct",
23
  "qwen-32b": "Qwen/Qwen2.5-32B-Instruct",
24
  "qwen-14b": "Qwen/Qwen2.5-14B-Instruct",
 
38
  "flux-s":"black-forest-labs/FLUX.1-schnell",
39
  "flux-d":"black-forest-labs/FLUX.1-dev"
40
  }
41
+ api_key_str: Optional[str] = os.environ.get("SI_KEY")
42
+ timezone: str = "Asia/Shanghai"
43
+
44
+ def __post_init__(self):
45
+ if not self.api_key_str :
46
+ logging.error("SI_KEY not found in env")
47
+ raise EnvironmentError("SI_KEY not found in env")
48
+ self.api_keys:list[str] = self.api_key_str.split(",")
49
+ if os.environ.get("MODEL_MAP"):
50
+ self.model_map = json.loads(os.environ.get("MODEL_MAP"))
51
 
52
+ config = ApiConfig()
 
 
 
 
 
53
 
54
  app = FastAPI()
 
 
55
  app.add_middleware(
56
  CORSMiddleware,
57
  allow_origins=["*"],
 
62
 
63
  # API Key 管理类
64
  class ApiKeyManager:
65
+ def __init__(self, keys:list[str]):
66
  self.keys = keys
67
+ self.key_balance :Dict[str,float] = {}
68
  self.key_balance_notes = ""
69
  self.last_update_time = ""
70
 
71
  def get_key(self) -> str:
72
  """随机获取一个可用 API Key"""
73
  if not self.keys:
74
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="No available API keys")
75
  random.shuffle(self.keys)
76
  return self.keys[0]
77
 
 
104
  key_to_remove.append(key)
105
  for remove_key in key_to_remove:
106
  self.remove_key(remove_key)
107
+ self.last_update_time = datetime.now(pytz.timezone(config.timezone))
108
 
109
  async def _fetch_balance(self, key: str) -> float:
110
  """发送 API 请求,获取 API Key 的余额"""
 
118
  return float(balance)
119
  except httpx.HTTPError as exc:
120
  logging.error("httpx request error, detail:" + str(exc))
121
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Check balance failed with status:{exc.response.status_code},url:{exc.request.url}")
122
+ key_manager = ApiKeyManager(config.api_keys)
123
+
124
+
125
async def get_model_info() -> str:
    """Render the configured model aliases as a run of HTML headings.

    Returns:
        One ``<h2>alias————upstream name</h2>`` element per entry of
        ``config.model_map``, concatenated in the map's iteration order.
        An empty string when the map has no entries.
    """
    headings = [
        f"<h2>{alias}————{target}</h2>"
        for alias, target in config.model_map.items()
    ]
    return "".join(headings)
130
 
131
  @app.get("/", response_class=HTMLResponse)
132
+ async def root() -> HTMLResponse:
133
  """根路由,返回 HTML 页面,展示模型信息和更新时间"""
134
  models_info = await get_model_info()
135
+ return HTMLResponse(f"""
136
  <html>
137
  <head>
138
  <title>富文本示例</title>
 
144
  {key_manager.key_balance_notes}
145
  </body>
146
  </html>
147
+ """)
148
 
149
 
150
  @app.get("/check")
151
async def check() -> str:
    """Manually trigger a balance check of the API keys.

    Returns:
        A confirmation message that embeds the key manager's
        ``last_update_time`` as it stands after the check.
    """
    await key_manager.check_keys_balance()
    refreshed_at = key_manager.last_update_time
    return f"更新成功:{refreshed_at}"
155
 
156
def is_chat_stream_request(request: "Request") -> bool:
    """Best-effort check for a JSON chat request that asks for streaming.

    The previous implementation parsed ``request.headers._list[4][1]``, i.e.
    the value of whichever header happened to be fifth, as if it were the
    request body. That is never the body, so the answer was effectively
    always False and an error was logged for most JSON requests.

    A synchronous function cannot await the body, so this version only
    inspects a body that has already been read and cached on the request.
    When no cached body is available it returns False; callers that need a
    reliable answer must parse the body themselves after awaiting it.

    Args:
        request: The incoming request. Only ``headers`` and, if present,
            the cached raw body are read.

    Returns:
        True only when the content type is JSON and the cached body is a
        JSON object with a non-null ``model`` and ``stream`` exactly True.
    """
    content_type = request.headers.get("content-type") or ""
    # Compare the bare media type so "application/json; charset=utf-8" matches.
    media_type = content_type.split(";", 1)[0].strip().lower()
    if media_type != "application/json":
        return False

    # NOTE(review): relies on the framework caching the raw bytes as ``_body``
    # once the body has been awaited -- confirm against the pinned Starlette
    # version. Absent cache means "unknown", which is reported as False.
    raw_body = getattr(request, "_body", None)
    if not raw_body:
        return False

    try:
        payload = json.loads(raw_body)
    except (ValueError, TypeError) as e:
        # ValueError covers JSONDecodeError and undecodable bytes.
        logging.error("parse request body error detail:" + str(e))
        return False

    if not isinstance(payload, dict):
        return False
    return payload.get("model") is not None and payload.get("stream") is True
166
+
167
async def _post_json(
    url: str,
    headers: dict[str, str],
    body: Any,
    *,
    log_prefix: str,
    detail_sep: str,
) -> Any:
    """POST ``body`` as JSON to ``url`` and return the decoded JSON reply.

    Args:
        url: Fully qualified upstream URL.
        headers: Headers to send, including the bearer token.
        body: JSON-serialisable request payload.
        log_prefix: Text placed before the exception in the error log line.
        detail_sep: Separator between the fields of the error detail.
            Both exist so each endpoint keeps its established message text.

    Raises:
        HTTPException: 500 when the upstream replies with an error status
            or cannot be reached at all.
    """
    async with httpx.AsyncClient() as client:
        try:
            response = await client.post(url, headers=headers, json=body)
            response.raise_for_status()
            return response.json()
        except httpx.HTTPStatusError as exc:
            logging.error(f"{log_prefix}{exc}")
            detail = detail_sep.join((
                f"Request failed with status: {exc.response.status_code}",
                f"url:{exc.request.url}",
                f"detail:{exc.response.text}",
            ))
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=detail,
            ) from exc
        except httpx.HTTPError as exc:
            # Transport failures (connect error, timeout, ...) carry no
            # response, so reading ``exc.response`` here would itself raise.
            logging.error(f"{log_prefix}{exc}")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Request failed without a response,url:{url},detail:{exc}",
            ) from exc


async def _forward_request(request: Request, api_endpoint: str, is_stream:bool = False ) -> Any:
    """Forward an incoming request to the upstream API and relay the reply.

    Args:
        request: The incoming request; its body must be JSON.
        api_endpoint: Upstream path appended to ``config.base_url``,
            e.g. ``"/chat/completions"``.
        is_stream: Force a streamed relay. A streamed relay is also used
            when the JSON body itself carries ``"stream": true``, so the
            outcome no longer depends on the caller guessing correctly.
            ``"/embeddings"`` is never streamed.

    Returns:
        The decoded upstream JSON, or a ``StreamingResponse`` relaying the
        upstream bytes as ``text/event-stream``.

    Raises:
        HTTPException: 500 when no API key is available or the upstream
            call fails; 400 when the request body is not valid JSON.
    """
    # Let the "no keys" HTTPException propagate: returning the exception
    # object would hand it to the framework as an ordinary response payload.
    key = key_manager.get_key()
    logging.info(f"using key {key[0:4]}***{key[-4:]} to {api_endpoint}")
    headers = {"Authorization": f"Bearer {key}"}
    url = f"{config.base_url}{api_endpoint}"

    try:
        body = await request.json()
    except ValueError as exc:
        # json.JSONDecodeError is a ValueError; report it as a client error.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Request body must be valid JSON",
        ) from exc

    wants_stream = is_stream
    if isinstance(body, dict):
        # Translate a short alias into the upstream model name for every
        # endpoint; the alias map holds chat, embedding and image models.
        model = body.get("model")
        if isinstance(model, str) and model in config.model_map:
            body["model"] = config.model_map[model]
        wants_stream = wants_stream or body.get("stream") is True

    if api_endpoint == "/embeddings":
        return await _post_json(
            url,
            headers,
            body,
            log_prefix="httpx request embedding error :",
            detail_sep=", ",
        )

    if wants_stream:
        async def generate_response():
            async with httpx.AsyncClient() as client:
                async with client.stream("POST", url, headers=headers, json=body) as response:
                    # NOTE(review): this runs lazily while the body is being
                    # iterated, so an upstream error status surfaces after the
                    # streamed reply has begun -- confirm that is acceptable.
                    response.raise_for_status()
                    async for chunk in response.aiter_bytes():
                        if chunk:
                            yield chunk
        return StreamingResponse(generate_response(), media_type="text/event-stream")

    return await _post_json(
        url,
        headers,
        body,
        log_prefix="httpx request error:",
        detail_sep=",",
    )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
  @app.post("/hf/v1/chat/completions")
216
async def chat_completions(request: Request) -> Any:
    """Forward a chat-completion request to the upstream API.

    Whether the reply should be streamed is decided from the request by
    ``is_chat_stream_request`` before forwarding.
    """
    wants_stream = is_chat_stream_request(request)
    return await _forward_request(request, "/chat/completions", wants_stream)
219
 
220
  @app.post("/hf/v1/embeddings")
221
async def embeddings(request: Request)-> Any:
    """Forward an embeddings request to the upstream API and return its reply."""
    upstream_reply = await _forward_request(request, "/embeddings")
    return upstream_reply
224
 
225
  @app.post("/hf/v1/images/generations")
226
async def image_generations(request:Request) -> Any:
    """Forward an image-generation request to the upstream API and return its reply."""
    upstream_reply = await _forward_request(request, "/images/generations")
    return upstream_reply
229
+
230