from fastapi import FastAPI, HTTPException, Request, Depends, status
from fastapi.responses import JSONResponse, StreamingResponse, HTMLResponse
from .models import ChatCompletionRequest, ChatCompletionResponse, ErrorResponse, ModelList
from .gemini import GeminiClient, ResponseWrapper
from .utils import handle_gemini_error, protect_from_abuse, APIKeyManager, test_api_key, format_log_message
import os
import json
import asyncio
from typing import Literal
import random
import requests
from datetime import datetime, timedelta
from apscheduler.schedulers.background import BackgroundScheduler
import sys
import logging

logging.getLogger("uvicorn").disabled = True
logging.getLogger("uvicorn.access").disabled = True

logger = logging.getLogger("my_logger")
logger.setLevel(logging.DEBUG)


def translate_error(message: str) -> str:
    if "quota exceeded" in message.lower():
        return "API 密钥配额已用尽"
    if "invalid argument" in message.lower():
        return "无效参数"
    if "internal server error" in message.lower():
        return "服务器内部错误"
    if "service unavailable" in message.lower():
        return "服务不可用"
    return message


def handle_exception(exc_type, exc_value, exc_traceback):
    if issubclass(exc_type, KeyboardInterrupt):
        # Delegate Ctrl-C to the original default hook so this handler does not recurse into itself.
        sys.__excepthook__(exc_type, exc_value, exc_traceback)
        return
    error_message = translate_error(str(exc_value))
    log_msg = format_log_message('ERROR', f"未捕获的异常: {error_message}", extra={'status_code': 500, 'error_message': error_message})
    logger.error(log_msg)


sys.excepthook = handle_exception
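
# Note: sys.excepthook only sees uncaught exceptions raised outside the ASGI request path;
# errors inside route handlers are reported by the @app.exception_handler registered below.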


app = FastAPI()

STATS_FILE = "stats.json"


def load_stats():
    """Read the persisted usage statistics, creating the file with zeroed counters if it is missing."""
    try:
        stats_path = os.path.abspath(STATS_FILE)
        os.makedirs(os.path.dirname(stats_path), exist_ok=True)

        if not os.path.exists(stats_path):
            initial_stats = {
                "total_calls": 0,
                "today_calls": 0,
                "total_tokens": 0,
                "today_tokens": 0,
                "last_reset": datetime.now().isoformat()
            }
            with open(stats_path, "w") as f:
                json.dump(initial_stats, f, indent=2)
            logger.info(f"已创建初始统计文件: {stats_path}")
            return initial_stats

        with open(stats_path, "r") as f:
            return json.load(f)
    except Exception as e:
        logger.error(f"加载统计文件失败: {str(e)}")
        return {
            "total_calls": 0,
            "today_calls": 0,
            "total_tokens": 0,
            "today_tokens": 0,
            "last_reset": datetime.now().isoformat()
        }


def save_stats(stats):
    try:
        stats_path = os.path.abspath(STATS_FILE)
        os.makedirs(os.path.dirname(stats_path), exist_ok=True)
        with open(stats_path, "w") as f:
            json.dump(stats, f, indent=2)
    except Exception as e:
        logger.error(f"保存统计文件失败: {str(e)}")


def update_stats(calls=0, tokens=0):
    stats = load_stats()
    stats["total_calls"] += calls
    stats["today_calls"] += calls
    stats["total_tokens"] += tokens
    stats["today_tokens"] += tokens
    save_stats(stats)


def reset_daily_stats():
    stats = load_stats()
    stats["today_calls"] = 0
    stats["today_tokens"] = 0
    stats["last_reset"] = datetime.now().isoformat()
    save_stats(stats)
    logger.info("每日统计数据已重置")
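
# Note: update_stats() performs an unlocked read-modify-write of stats.json, so counts may race
# (and under-count) if the app is run with multiple worker processes; for the rough dashboard
# statistics served below this is usually acceptable.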


scheduler = BackgroundScheduler()
scheduler.add_job(reset_daily_stats, 'cron', hour=0, minute=0)
scheduler.start()

PASSWORD = os.environ.get("PASSWORD", "123")
MAX_REQUESTS_PER_MINUTE = int(os.environ.get("MAX_REQUESTS_PER_MINUTE", "30"))
MAX_REQUESTS_PER_DAY_PER_IP = int(
    os.environ.get("MAX_REQUESTS_PER_DAY_PER_IP", "600"))

RETRY_DELAY = 1
MAX_RETRY_DELAY = 16
safety_settings = [
    {
        "category": "HARM_CATEGORY_HARASSMENT",
        "threshold": "BLOCK_NONE"
    },
    {
        "category": "HARM_CATEGORY_HATE_SPEECH",
        "threshold": "BLOCK_NONE"
    },
    {
        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "threshold": "BLOCK_NONE"
    },
    {
        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
        "threshold": "BLOCK_NONE"
    },
    {
        "category": "HARM_CATEGORY_CIVIC_INTEGRITY",
        "threshold": "BLOCK_NONE"
    }
]
safety_settings_g2 = [
    {
        "category": "HARM_CATEGORY_HARASSMENT",
        "threshold": "OFF"
    },
    {
        "category": "HARM_CATEGORY_HATE_SPEECH",
        "threshold": "OFF"
    },
    {
        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "threshold": "OFF"
    },
    {
        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
        "threshold": "OFF"
    },
    {
        "category": "HARM_CATEGORY_CIVIC_INTEGRITY",
        "threshold": "OFF"
    }
]
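
# Two safety-setting variants: most models are sent the "BLOCK_NONE" thresholds above, while the
# request path below sends the "OFF" variant (safety_settings_g2) to gemini-2.0-flash-exp.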


key_manager = APIKeyManager()
current_api_key = key_manager.get_available_key()


def switch_api_key():
    """Rotate the module-level current_api_key to the next available key, logging the outcome."""
    global current_api_key
    key = key_manager.get_available_key()
    if key:
        current_api_key = key
        log_msg = format_log_message('INFO', f"API key 替换为 → {current_api_key[:8]}...", extra={'key': current_api_key[:8], 'request_type': 'switch_key'})
        logger.info(log_msg)
    else:
        log_msg = format_log_message('ERROR', "API key 替换失败,所有API key都已尝试,请重新配置或稍后重试", extra={'key': 'N/A', 'request_type': 'switch_key', 'status_code': 'N/A'})
        logger.error(log_msg)


async def check_keys():
    """Validate every configured key against the Gemini API and return the usable ones."""
    available_keys = []
    for key in key_manager.api_keys:
        is_valid = await test_api_key(key)
        status_msg = "有效" if is_valid else "无效"
        log_msg = format_log_message('INFO', f"API Key {key[:10]}... {status_msg}.")
        logger.info(log_msg)
        if is_valid:
            available_keys.append(key)
    if not available_keys:
        log_msg = format_log_message('ERROR', "没有可用的 API 密钥!", extra={'key': 'N/A', 'request_type': 'startup', 'status_code': 'N/A'})
        logger.error(log_msg)
    return available_keys


@app.on_event("startup")
async def startup_event():
    log_msg = format_log_message('INFO', "Starting Gemini API proxy...")
    logger.info(log_msg)
    available_keys = await check_keys()
    if available_keys:
        key_manager.api_keys = available_keys
        key_manager._reset_key_stack()
        key_manager.show_all_keys()
        log_msg = format_log_message('INFO', f"可用 API 密钥数量:{len(key_manager.api_keys)}")
        logger.info(log_msg)

        log_msg = format_log_message('INFO', f"最大重试次数设置为:{len(key_manager.api_keys)}")
        logger.info(log_msg)
    if key_manager.api_keys:
        # The model list is fetched with the first key only; all keys are assumed to expose the same models.
        all_models = await GeminiClient.list_available_models(key_manager.api_keys[0])
        GeminiClient.AVAILABLE_MODELS = [model.replace(
            "models/", "") for model in all_models]
        log_msg = format_log_message('INFO', "Available models loaded.")
        logger.info(log_msg)


@app.get("/v1/models", response_model=ModelList)
def list_models():
    log_msg = format_log_message('INFO', "Received request to list models", extra={'request_type': 'list_models', 'status_code': 200})
    logger.info(log_msg)
    return ModelList(data=[{"id": model, "object": "model", "created": 1678888888, "owned_by": "organization-owner"} for model in GeminiClient.AVAILABLE_MODELS])


async def verify_password(request: Request):
    """FastAPI dependency: require a Bearer token matching PASSWORD when a password is configured."""
    if PASSWORD:
        auth_header = request.headers.get("Authorization")
        if not auth_header or not auth_header.startswith("Bearer "):
            raise HTTPException(
                status_code=401, detail="Unauthorized: Missing or invalid token")
        token = auth_header.split(" ")[1]
        if token != PASSWORD:
            raise HTTPException(
                status_code=401, detail="Unauthorized: Invalid token")
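
# Clients authenticate with a header of the form (illustrative value, matching the PASSWORD
# environment variable; the default is "123" as configured above):
#
#   Authorization: Bearer 123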


async def process_request(chat_request: ChatCompletionRequest, http_request: Request, request_type: Literal['stream', 'non-stream']):
    """Handle one chat-completion request, rotating through the configured Gemini keys until an attempt succeeds."""
    global current_api_key
    protect_from_abuse(
        http_request, MAX_REQUESTS_PER_MINUTE, MAX_REQUESTS_PER_DAY_PER_IP)
    if chat_request.model not in GeminiClient.AVAILABLE_MODELS:
        error_msg = "无效的模型"
        extra_log = {'request_type': request_type, 'model': chat_request.model, 'status_code': 400, 'error_message': error_msg}
        log_msg = format_log_message('ERROR', error_msg, extra=extra_log)
        logger.error(log_msg)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=error_msg)

    key_manager.reset_tried_keys_for_request()

    contents, system_instruction = GeminiClient.convert_messages(
        GeminiClient, chat_request.messages)

    # Each configured key gets at most one attempt per request; keys are rotated on failure.
    retry_attempts = len(key_manager.api_keys) if key_manager.api_keys else 1
    for attempt in range(1, retry_attempts + 1):
        if attempt == 1:
            current_api_key = key_manager.get_available_key()

        if current_api_key is None:
            log_msg_no_key = format_log_message('WARNING', "没有可用的 API 密钥,跳过本次尝试", extra={'request_type': request_type, 'model': chat_request.model, 'status_code': 'N/A'})
            logger.warning(log_msg_no_key)
            break

        extra_log = {'key': current_api_key[:8], 'request_type': request_type, 'model': chat_request.model, 'status_code': 'N/A', 'error_message': ''}
        log_msg = format_log_message('INFO', f"第 {attempt}/{retry_attempts} 次尝试 ... 使用密钥: {current_api_key[:8]}...", extra=extra_log)
        logger.info(log_msg)

        gemini_client = GeminiClient(current_api_key)
        try:
            if chat_request.stream:
                async def stream_generator():
                    try:
                        async for chunk in gemini_client.stream_chat(chat_request, contents, safety_settings_g2 if 'gemini-2.0-flash-exp' in chat_request.model else safety_settings, system_instruction):
                            formatted_chunk = {"id": "chatcmpl-someid", "object": "chat.completion.chunk", "created": 1234567,
                                               "model": chat_request.model, "choices": [{"delta": {"role": "assistant", "content": chunk}, "index": 0, "finish_reason": None}]}
                            yield f"data: {json.dumps(formatted_chunk)}\n\n"
                        yield "data: [DONE]\n\n"

                    except asyncio.CancelledError:
                        extra_log_cancel = {'key': current_api_key[:8], 'request_type': request_type, 'model': chat_request.model, 'error_message': '客户端已断开连接'}
                        log_msg = format_log_message('INFO', "客户端连接已中断", extra=extra_log_cancel)
                        logger.info(log_msg)
                    except Exception as e:
                        error_detail = handle_gemini_error(
                            e, current_api_key, key_manager)
                        yield f"data: {json.dumps({'error': {'message': error_detail, 'type': 'gemini_error'}})}\n\n"
                return StreamingResponse(stream_generator(), media_type="text/event-stream")
            else:
                async def run_gemini_completion():
                    try:
                        response_content = await asyncio.to_thread(gemini_client.complete_chat, chat_request, contents, safety_settings_g2 if 'gemini-2.0-flash-exp' in chat_request.model else safety_settings, system_instruction)
                        return response_content
                    except asyncio.CancelledError:
                        extra_log_gemini_cancel = {'key': current_api_key[:8], 'request_type': request_type, 'model': chat_request.model, 'error_message': '客户端断开导致API调用取消'}
                        log_msg = format_log_message('INFO', "API调用因客户端断开而取消", extra=extra_log_gemini_cancel)
                        logger.info(log_msg)
                        raise

                async def check_client_disconnect():
                    while True:
                        if await http_request.is_disconnected():
                            extra_log_client_disconnect = {'key': current_api_key[:8], 'request_type': request_type, 'model': chat_request.model, 'error_message': '检测到客户端断开连接'}
                            log_msg = format_log_message('INFO', "客户端连接已中断,正在取消API请求", extra=extra_log_client_disconnect)
                            logger.info(log_msg)
                            return True
                        await asyncio.sleep(0.5)

                # Run the blocking Gemini call in a worker thread while polling for client disconnects,
                # so an abandoned request can be cancelled promptly.
                gemini_task = asyncio.create_task(run_gemini_completion())
                disconnect_task = asyncio.create_task(check_client_disconnect())

                try:
                    done, pending = await asyncio.wait(
                        [gemini_task, disconnect_task],
                        return_when=asyncio.FIRST_COMPLETED
                    )

                    if disconnect_task in done:
                        gemini_task.cancel()
                        try:
                            await gemini_task
                        except asyncio.CancelledError:
                            extra_log_gemini_task_cancel = {'key': current_api_key[:8], 'request_type': request_type, 'model': chat_request.model, 'error_message': 'API任务已终止'}
                            log_msg = format_log_message('INFO', "API任务已成功取消", extra=extra_log_gemini_task_cancel)
                            logger.info(log_msg)

                        raise HTTPException(status_code=status.HTTP_408_REQUEST_TIMEOUT, detail="客户端连接已中断")

                    if gemini_task in done:
                        disconnect_task.cancel()
                        try:
                            await disconnect_task
                        except asyncio.CancelledError:
                            pass
                        response_content = gemini_task.result()
                        if response_content.text == "":
                            extra_log_empty_response = {'key': current_api_key[:8], 'request_type': request_type, 'model': chat_request.model, 'status_code': 204}
                            log_msg = format_log_message('INFO', "Gemini API 返回空响应", extra=extra_log_empty_response)
                            logger.info(log_msg)

                            continue
                        response = ChatCompletionResponse(id="chatcmpl-someid", object="chat.completion", created=1234567890, model=chat_request.model,
                                                          choices=[{"index": 0, "message": {"role": "assistant", "content": response_content.text}, "finish_reason": "stop"}])
                        extra_log_success = {'key': current_api_key[:8], 'request_type': request_type, 'model': chat_request.model, 'status_code': 200}
                        log_msg = format_log_message('INFO', "请求处理成功", extra=extra_log_success)
                        logger.info(log_msg)

                        tokens = response_content.total_token_count or 0
                        update_stats(calls=1, tokens=tokens)
                        return response

                except asyncio.CancelledError:
                    extra_log_request_cancel = {'key': current_api_key[:8], 'request_type': request_type, 'model': chat_request.model, 'error_message': "请求被取消"}
                    log_msg = format_log_message('INFO', "请求取消", extra=extra_log_request_cancel)
                    logger.info(log_msg)
                    raise

        except HTTPException as e:
            if e.status_code == status.HTTP_408_REQUEST_TIMEOUT:
                extra_log = {'key': current_api_key[:8], 'request_type': request_type, 'model': chat_request.model,
                             'status_code': 408, 'error_message': '客户端连接中断'}
                log_msg = format_log_message('ERROR', "客户端连接中断,终止后续重试", extra=extra_log)
                logger.error(log_msg)
                raise
            else:
                raise
        except Exception as e:
            handle_gemini_error(e, current_api_key, key_manager)
            if attempt < retry_attempts:
                switch_api_key()
                continue

    msg = "所有API密钥均失败,请稍后重试"
    extra_log_all_fail = {'key': "ALL", 'request_type': request_type, 'model': chat_request.model, 'status_code': 500, 'error_message': msg}
    log_msg = format_log_message('ERROR', msg, extra=extra_log_all_fail)
    logger.error(log_msg)
    raise HTTPException(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=msg)


@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def chat_completions(request: ChatCompletionRequest, http_request: Request, _: None = Depends(verify_password)):
    return await process_request(request, http_request, "stream" if request.stream else "non-stream")
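
# Minimal client sketch for this OpenAI-compatible endpoint (assumes the proxy is reachable at
# http://localhost:8000, PASSWORD is "123", and the chosen model appears in GET /v1/models;
# adjust all three to your deployment):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/v1/chat/completions",
#       headers={"Authorization": "Bearer 123"},
#       json={
#           "model": "gemini-2.0-flash-exp",
#           "messages": [{"role": "user", "content": "Hello"}],
#           "stream": False,
#       },
#   )
#   print(resp.json()["choices"][0]["message"]["content"])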


@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    error_message = translate_error(str(exc))
    extra_log_unhandled_exception = {'status_code': 500, 'error_message': error_message}
    log_msg = format_log_message('ERROR', f"Unhandled exception: {error_message}", extra=extra_log_unhandled_exception)
    logger.error(log_msg)
    return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=ErrorResponse(message=str(exc), type="internal_error").dict())


@app.get("/api/stats")
async def get_stats():
    stats = load_stats()
    return {
        "today_calls": stats["today_calls"],
        "total_calls": stats["total_calls"],
        "today_tokens": stats["today_tokens"],
        "total_tokens": stats["total_tokens"],
        "last_reset": stats["last_reset"]
    }


@app.get("/", response_class=HTMLResponse)
async def root():
    html_content = f"""
    <!DOCTYPE html>
    <html>
    <head>
        <title>Gemini API 代理服务</title>
        <style>
            body {{
                font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
                max-width: 800px;
                margin: 0 auto;
                padding: 20px;
                line-height: 1.6;
            }}
            h1 {{
                color: #333;
                text-align: center;
                margin-bottom: 30px;
            }}
            .info-box {{
                background-color: #f8f9fa;
                border: 1px solid #dee2e6;
                border-radius: 4px;
                padding: 20px;
                margin-bottom: 20px;
            }}
            .status {{
                color: #28a745;
                font-weight: bold;
            }}
        </style>
    </head>
    <body>
        <h1>🤖 Gemini API 代理服务</h1>

        <div class="info-box">
            <h2>🟢 运行状态</h2>
            <p class="status">服务运行中</p>
            <p>可用API密钥数量: {len(key_manager.api_keys)}</p>
            <p>可用模型数量: {len(GeminiClient.AVAILABLE_MODELS)}</p>
        </div>

        <div class="info-box">
            <h2>⚙️ 环境配置</h2>
            <p>每分钟请求限制: {MAX_REQUESTS_PER_MINUTE}</p>
            <p>每IP每日请求限制: {MAX_REQUESTS_PER_DAY_PER_IP}</p>
            <p>最大重试次数: {len(key_manager.api_keys)}</p>
        </div>

        <div class="info-box">
            <h2>📊 使用统计</h2>
            <p>今日调用次数: <span id="todayCalls">加载中...</span></p>
            <p>累计调用次数: <span id="totalCalls">加载中...</span></p>
            <p>今日Token数: <span id="todayTokens">加载中...</span></p>
            <p>累计Token数: <span id="totalTokens">加载中...</span></p>
            <p>最后重置时间: <span id="lastReset">加载中...</span></p>
        </div>

        <script>
            async function loadStats() {{
                try {{
                    const response = await fetch('/api/stats');
                    const data = await response.json();

                    document.getElementById('todayCalls').textContent = data.today_calls;
                    document.getElementById('totalCalls').textContent = data.total_calls;
                    document.getElementById('todayTokens').textContent = data.today_tokens;
                    document.getElementById('totalTokens').textContent = data.total_tokens;
                    document.getElementById('lastReset').textContent = new Date(data.last_reset).toLocaleString();
                }} catch (error) {{
                    console.error('加载统计信息失败:', error);
                }}
            }}

            // Initial load
            loadStats();
            // Refresh every 10 seconds
            setInterval(loadStats, 10000);
        </script>
    </body>
    </html>
    """
    return html_content
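
# Local run sketch (assumes this module is importable as app.main; adjust the import path to
# match your project layout):
#
#   uvicorn app.main:app --host 0.0.0.0 --port 8000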