99i committed on
Commit
ec3625c
·
verified ·
1 Parent(s): 71c9caf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -30
app.py CHANGED
@@ -1,5 +1,5 @@
1
  from fastapi import FastAPI, Request, HTTPException
2
- from fastapi.responses import StreamingResponse, HTMLResponse
3
  from fastapi.middleware.cors import CORSMiddleware
4
  import httpx
5
  import json
@@ -13,7 +13,7 @@ import logging
13
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
14
 
15
  # 从环境变量或者文件中加载配置
16
- BASE_URL = os.environ.get("BASE_URL", "https://api.siliconflow.cn/v1/chat/completions")
17
  MODEL_MAP = {
18
  "qwen-72b": "Qwen/Qwen2.5-72B-Instruct",
19
  "qwen-32b": "Qwen/Qwen2.5-32B-Instruct",
@@ -47,6 +47,7 @@ LAST_UPDATE_TIME = ""
47
 
48
  app = FastAPI()
49
 
 
50
  app.add_middleware(
51
  CORSMiddleware,
52
  allow_origins=["*"],
@@ -121,7 +122,6 @@ async def get_model_info():
121
  models += f"<h2>{key}————{value}</h2>"
122
  return models
123
 
124
-
125
  @app.get("/", response_class=HTMLResponse)
126
  async def root():
127
  """根路由,返回 HTML 页面,展示模型信息和更新时间"""
@@ -140,53 +140,96 @@ async def root():
140
  </html>
141
  """
142
 
 
143
  @app.get("/check")
144
  async def check():
145
  """手动触发检查 API Key 余额的路由"""
146
  await key_manager.check_keys_balance()
147
  return f"更新成功:{key_manager.last_update_time}"
148
 
149
- async def _forward_request(request: Request, api_type: str):
150
  """转发请求到硅流 API"""
151
- body = await request.json()
152
  try:
153
- key = key_manager.get_key()
154
  except HTTPException as e:
155
  return e
156
  logging.info(f"using key {key[0:4]}***{key[-4:]} to {api_type}")
157
  headers = {"Authorization": f"Bearer {key}"}
158
- # 处理模型映射
159
- if "model" in body and body["model"] in MODEL_MAP:
160
- body["model"] = MODEL_MAP[body["model"]]
161
-
162
- if api_type == "chat" and "stream" in body and body["stream"]:
163
- async def generate_response():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  async with httpx.AsyncClient() as client:
165
- async with client.stream("POST", BASE_URL, headers=headers, json=body) as response:
166
- response.raise_for_status() # 检查响应状态码
167
- async for chunk in response.aiter_bytes():
168
- if chunk:
169
- yield chunk
170
- return StreamingResponse(generate_response(), media_type="text/event-stream")
171
- else:
172
- async with httpx.AsyncClient() as client:
173
- try:
174
- response = await client.post(BASE_URL, headers=headers, json=body)
175
- response.raise_for_status()
176
- return response.json()
177
- except httpx.HTTPError as exc :
178
- logging.error("httpx request error:" + str(exc))
179
- raise HTTPException(
180
- status_code=500,
181
- detail=f"Request failed with status: {exc.response.status_code},url:{exc.request.url},detail:{exc.response.text}"
182
- )
 
 
 
 
 
 
 
 
 
 
 
183
 
184
  @app.post("/hf/v1/chat/completions")
185
  async def chat_completions(request: Request):
186
  """转发 chat 完成请求"""
187
  return await _forward_request(request, "chat")
 
188
  @app.post("/hf/v1/embeddings")
189
  async def embeddings(request: Request):
190
  """转发 embeddings 请求"""
191
  return await _forward_request(request, "embedding")
192
 
 
 
 
 
 
 
1
  from fastapi import FastAPI, Request, HTTPException
2
+ from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse
3
  from fastapi.middleware.cors import CORSMiddleware
4
  import httpx
5
  import json
 
13
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
14
 
15
  # 从环境变量或者文件中加载配置
16
+ BASE_URL = os.environ.get("BASE_URL", "https://api.siliconflow.cn/v1")
17
  MODEL_MAP = {
18
  "qwen-72b": "Qwen/Qwen2.5-72B-Instruct",
19
  "qwen-32b": "Qwen/Qwen2.5-32B-Instruct",
 
47
 
48
  app = FastAPI()
49
 
50
+
51
  app.add_middleware(
52
  CORSMiddleware,
53
  allow_origins=["*"],
 
122
  models += f"<h2>{key}————{value}</h2>"
123
  return models
124
 
 
125
  @app.get("/", response_class=HTMLResponse)
126
  async def root():
127
  """根路由,返回 HTML 页面,展示模型信息和更新时间"""
 
140
  </html>
141
  """
142
 
143
+
144
  @app.get("/check")
145
  async def check():
146
  """手动触发检查 API Key 余额的路由"""
147
  await key_manager.check_keys_balance()
148
  return f"更新成功:{key_manager.last_update_time}"
149
 
150
def _apply_model_alias(body) -> None:
    """Swap a short model alias in *body* for its MODEL_MAP target, in place."""
    # Only a JSON object with a string "model" can carry an alias; anything
    # else is forwarded untouched and left for the upstream API to validate.
    if isinstance(body, dict):
        model = body.get("model")
        if isinstance(model, str) and model in MODEL_MAP:
            body["model"] = MODEL_MAP[model]


def _upstream_failure(exc: httpx.HTTPError, url: str, detail_template: str) -> HTTPException:
    """Build the HTTPException reported to the client for a failed upstream call.

    Args:
        exc: The httpx error raised while calling the upstream API.
        url: The upstream URL that was being called.
        detail_template: ``str.format`` template with ``{status}``, ``{url}``
            and ``{text}`` fields, used when the upstream returned a response.

    Returns:
        An HTTPException with status code 500 describing the failure.
    """
    if isinstance(exc, httpx.HTTPStatusError):
        detail = detail_template.format(
            status=exc.response.status_code,
            url=exc.request.url,
            text=exc.response.text,
        )
    else:
        # Transport-level failures (timeouts, connection errors, ...) carry no
        # response object, so only the exception text can be reported.
        detail = f"Request failed without a response, url:{url}, detail:{exc}"
    return HTTPException(status_code=500, detail=detail)


async def _forward_request(request: Request, api_type: str, image_generations: bool = False):
    """Forward the incoming request to the SiliconFlow API.

    Args:
        request: The incoming FastAPI request; its JSON body is forwarded.
        api_type: ``"chat"`` or ``"embedding"``; selects the upstream endpoint.
        image_generations: When true, the request goes to the image generation
            endpoint regardless of ``api_type``.

    Returns:
        A StreamingResponse for streamed chat completions, a JSONResponse for
        image generation, the decoded upstream JSON for other chat and
        embedding calls, or None when no endpoint matches ``api_type``.

    Raises:
        HTTPException: Propagated from ``key_manager.get_key()``, or raised
            with status 500 when the non-streaming upstream call fails.
    """
    # An HTTPException from get_key() is left to propagate: returning the
    # exception object instead would bypass FastAPI's exception handlers.
    key = key_manager.get_key()
    logging.info(f"using key {key[0:4]}***{key[-4:]} to {api_type}")
    headers = {"Authorization": f"Bearer {key}"}

    if image_generations:
        url = f"{BASE_URL}/images/generations"
        body = await request.json()
        headers["Content-Type"] = "application/json"
        async with httpx.AsyncClient() as client:
            try:
                response = await client.post(url, headers=headers, json=body)
                response.raise_for_status()
                return JSONResponse(content=response.json(), status_code=response.status_code)
            except httpx.HTTPError as exc:
                logging.error("httpx request error:" + str(exc))
                raise _upstream_failure(
                    exc, url, "Request failed with status: {status}, url:{url} ,detail:{text}"
                ) from exc

    if api_type == "chat":
        url = f"{BASE_URL}/chat/completions"
        body = await request.json()
        # Map short model aliases to upstream model names.
        _apply_model_alias(body)

        if isinstance(body, dict) and body.get("stream"):
            async def generate_response():
                async with httpx.AsyncClient() as client:
                    async with client.stream("POST", url, headers=headers, json=body) as response:
                        # NOTE(review): this check runs while the body is being
                        # streamed, so an upstream error presumably surfaces
                        # after the response has started - confirm.
                        response.raise_for_status()
                        async for chunk in response.aiter_bytes():
                            if chunk:
                                yield chunk

            return StreamingResponse(generate_response(), media_type="text/event-stream")

        async with httpx.AsyncClient() as client:
            try:
                response = await client.post(url, headers=headers, json=body)
                response.raise_for_status()
                return response.json()
            except httpx.HTTPError as exc:
                logging.error("httpx request error:" + str(exc))
                raise _upstream_failure(
                    exc, url, "Request failed with status: {status},url:{url},detail:{text}"
                ) from exc

    if api_type == "embedding":
        url = f"{BASE_URL}/embeddings"
        body = await request.json()
        _apply_model_alias(body)
        async with httpx.AsyncClient() as client:
            try:
                response = await client.post(url, json=body, headers=headers)
                response.raise_for_status()
                return response.json()
            except httpx.HTTPError as exc:
                logging.error(f"httpx request embedding error :{exc}")
                raise _upstream_failure(
                    exc, url, "Request failed with status: {status}, url:{url}, detail:{text}"
                ) from exc

    # No upstream endpoint is mapped for this api_type.
    return None
218
+
219
+
220
 
221
  @app.post("/hf/v1/chat/completions")
222
  async def chat_completions(request: Request):
223
  """转发 chat 完成请求"""
224
  return await _forward_request(request, "chat")
225
+
226
  @app.post("/hf/v1/embeddings")
227
  async def embeddings(request: Request):
228
  """转发 embeddings 请求"""
229
  return await _forward_request(request, "embedding")
230
 
231
@app.post("/hf/v1/images/generations")
async def image_generations(request: Request):
    """Forward an image generation request."""
    upstream_result = await _forward_request(request, "image", image_generations=True)
    return upstream_result
235
+