|
import random
|
|
from fastapi import HTTPException, Request
|
|
import time
|
|
import re
|
|
from datetime import datetime, timedelta
|
|
from apscheduler.schedulers.background import BackgroundScheduler
|
|
import os
|
|
import requests
|
|
import httpx
|
|
from threading import Lock
|
|
import logging
|
|
import sys
|
|
|
|
# Debug mode is enabled when the DEBUG environment variable equals "true" (case-insensitive).
DEBUG = os.getenv("DEBUG", "false").lower() == "true"

# Verbose layout used in debug mode: adds timestamp, level name and the error detail.
LOG_FORMAT_DEBUG = (
    '%(asctime)s - %(levelname)s - [%(key)s]-%(request_type)s-[%(model)s]-%(status_code)s: '
    '%(message)s - %(error_message)s'
)

# Compact layout used otherwise.
LOG_FORMAT_NORMAL = (
    '[%(key)s]-%(request_type)s-[%(model)s]-%(status_code)s: %(message)s'
)

# Module logger. The handler keeps logging's default formatter; callers pass it text
# that format_log_message has already rendered.
logger = logging.getLogger("my_logger")
logger.setLevel(logging.DEBUG)

handler = logging.StreamHandler()
logger.addHandler(handler)
|
|
|
|
def format_log_message(level, message, extra=None):
    """Render one log line using the active layout (debug or normal).

    Args:
        level: level name stored in the ``levelname`` field (only the debug
            layout prints it).
        message: main message text.
        extra: optional mapping that may provide ``key``, ``request_type``,
            ``model``, ``status_code`` and ``error_message``. Missing fields
            fall back to ``'N/A'`` (``''`` for ``error_message``).

    Returns:
        The formatted string. Nothing is logged here.
    """
    source = extra or {}

    # (field name, fallback when the caller did not supply it)
    optional_fields = (
        ('key', 'N/A'),
        ('request_type', 'N/A'),
        ('model', 'N/A'),
        ('status_code', 'N/A'),
        ('error_message', ''),
    )
    values = {name: source.get(name, fallback) for name, fallback in optional_fields}
    values['asctime'] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    values['levelname'] = level
    values['message'] = message

    template = LOG_FORMAT_NORMAL if not DEBUG else LOG_FORMAT_DEBUG
    return template % values
|
|
|
|
|
|
class APIKeyManager:
    """Hand out Gemini API keys read from the ``GEMINI_API_KEYS`` environment variable.

    Keys are served from a shuffled stack. ``tried_keys_for_request`` records every
    key handed out since the last ``reset_tried_keys_for_request()`` call, so a retry
    loop never receives the same key twice.

    NOTE(review): ``tried_keys_for_request`` is a single set on the instance.
    Concurrent requests that share one manager therefore also share it.
    """

    def __init__(self):
        # Extract every substring shaped like a Google API key ("AIzaSy" + 33 chars).
        # Any separator used in the environment variable is therefore accepted.
        self.api_keys = re.findall(
            r"AIzaSy[a-zA-Z0-9_-]{33}", os.environ.get('GEMINI_API_KEYS', ""))
        self.key_stack = []
        self._reset_key_stack()

        # Started here, but no jobs are registered on it in this class.
        # Presumably jobs are added elsewhere; confirm before removing.
        self.scheduler = BackgroundScheduler()
        self.scheduler.start()

        self.tried_keys_for_request = set()

    def _reset_key_stack(self):
        """Refill the key stack with every configured key in random order."""
        shuffled_keys = self.api_keys[:]
        random.shuffle(shuffled_keys)
        self.key_stack = shuffled_keys

    def _pop_untried_key(self):
        """Pop the topmost key not yet tried for this request, mark it tried, and return it.

        Returns None when the stack holds no untried key. Keys skipped because
        they were already tried go back to the bottom of the stack, in their
        original relative order. They stay in the rotation for later requests
        instead of being dropped.
        """
        skipped = []
        found = None
        while self.key_stack:
            key = self.key_stack.pop()
            if key in self.tried_keys_for_request:
                skipped.append(key)
                continue
            self.tried_keys_for_request.add(key)
            found = key
            break
        if skipped:
            # `skipped` is in pop (top-first) order; reversing restores bottom-to-top order.
            self.key_stack[:0] = reversed(skipped)
        return found

    def get_available_key(self):
        """Return a key not yet tried for the current request, or None.

        Keys are taken from the top of the stack. If the stack holds no untried
        key, it is reshuffled from all configured keys and searched once more.

        Returns None in two cases:
            * no keys are configured (an error is logged);
            * every configured key was already tried for this request.
        """
        key = self._pop_untried_key()
        if key is not None:
            return key

        if not self.api_keys:
            log_msg = format_log_message('ERROR', "没有配置任何 API 密钥!")
            logger.error(log_msg)
            return None

        self._reset_key_stack()
        return self._pop_untried_key()

    def show_all_keys(self):
        """Log the number of configured keys and a masked form of each one."""
        log_msg = format_log_message('INFO', f"当前可用API key个数: {len(self.api_keys)} ")
        logger.info(log_msg)
        for i, api_key in enumerate(self.api_keys):
            # Only the first 8 and last 3 characters are logged.
            log_msg = format_log_message('INFO', f"API Key{i}: {api_key[:8]}...{api_key[-3:]}")
            logger.info(log_msg)

    def reset_tried_keys_for_request(self):
        """Forget which keys were tried; call this at the start of a new request."""
        self.tried_keys_for_request = set()
|
|
|
|
|
|
# HTTP statuses handled from the status code alone (no response-body inspection).
# status -> (log level, error_message for the log extra, text logged after the key,
#            string returned to the caller)
_GEMINI_STATUS_ERRORS = {
    429: ('WARNING', "API 密钥配额已用尽或其他原因", "429 官方资源耗尽或其他原因",
          "API 密钥配额已用尽或其他原因"),
    403: ('ERROR', "权限被拒绝", "403 权限被拒绝", "权限被拒绝"),
    500: ('WARNING', "服务器内部错误", "500 服务器内部错误", "Gemini API 内部错误"),
    503: ('WARNING', "服务不可用", "503 服务不可用", "Gemini API 服务不可用"),
}


def _emit_gemini_log(level, message, extra=None):
    """Render *message* with format_log_message and log it at *level* ('WARNING' or 'ERROR')."""
    log_msg = format_log_message(level, message, extra=extra)
    if level == 'ERROR':
        logger.error(log_msg)
    else:
        logger.warning(log_msg)


def _is_invalid_api_key_error(error_info):
    """Return True when a Google API ``error`` object reports an invalid API key."""
    # Kept for compatibility: an error object whose "code" is this string.
    if error_info.get('code') == "invalid_argument":
        return True
    # Google APIs report a bad key as HTTP 400 with an integer "code" (400) and an
    # ErrorInfo entry in "details" whose reason is API_KEY_INVALID.
    details = error_info.get('details')
    if not isinstance(details, list):
        return False
    return any(
        isinstance(detail, dict) and detail.get('reason') == 'API_KEY_INVALID'
        for detail in details
    )


def _handle_bad_request(response, status_code, key_prefix, key_label):
    """Classify a 400 response, log it, and return the message for the caller."""
    try:
        error_data = response.json()
    except ValueError:
        error_message = "400 错误请求:响应不是有效的JSON格式"
        _emit_gemini_log('WARNING', error_message,
                         {'key': key_prefix, 'status_code': status_code, 'error_message': error_message})
        return error_message

    error_info = error_data.get('error') if isinstance(error_data, dict) else None
    if not isinstance(error_info, dict):
        # Valid JSON but no usable "error" object: report a generic bad request
        # instead of falling through and returning None.
        error_info = {}

    if _is_invalid_api_key_error(error_info):
        error_message = "无效的 API 密钥"
        _emit_gemini_log('ERROR', f"{key_label} → 无效,可能已过期或被删除",
                         {'key': key_prefix, 'status_code': status_code, 'error_message': error_message})
        return error_message

    error_message = error_info.get('message', 'Bad Request')
    _emit_gemini_log('WARNING', f"400 错误请求: {error_message}",
                     {'key': key_prefix, 'status_code': status_code, 'error_message': error_message})
    return f"400 错误请求: {error_message}"


def handle_gemini_error(error, current_api_key, key_manager) -> str:
    """Log an exception raised by a Gemini API call and return a short description.

    Args:
        error: the exception that was raised.
            * ``requests`` HTTP errors are classified by status code (400 bodies
              are inspected to detect an invalid key).
            * Connection errors and timeouts get fixed messages.
            * Anything else is reported as an unknown error.
        current_api_key: the key used for the failed call. Only its first 8 and
            last 3 characters are logged. May be None or empty.
        key_manager: not used here; kept for interface compatibility.

    Returns:
        A (Chinese) message string describing the failure; always a str.
    """
    if current_api_key:
        key_prefix = current_api_key[:8]
        key_label = f"{current_api_key[:8]} ... {current_api_key[-3:]}"
    else:
        key_prefix = key_label = 'N/A'

    response = getattr(error, 'response', None)
    # Compare with None explicitly: a requests Response is falsy for 4xx/5xx statuses.
    if isinstance(error, requests.exceptions.HTTPError) and response is not None:
        status_code = response.status_code
        if status_code == 400:
            return _handle_bad_request(response, status_code, key_prefix, key_label)

        known = _GEMINI_STATUS_ERRORS.get(status_code)
        if known is not None:
            level, error_message, log_text, result = known
        else:
            level = 'WARNING'
            error_message = f"未知错误: {status_code}"
            log_text = f"{status_code} 未知错误"
            result = f"未知错误/模型不可用: {status_code}"
        _emit_gemini_log(level, f"{key_label} → {log_text}",
                         {'key': key_prefix, 'status_code': status_code, 'error_message': error_message})
        return result

    if isinstance(error, requests.exceptions.ConnectionError):
        level, error_message = 'WARNING', "连接错误"
    elif isinstance(error, requests.exceptions.Timeout):
        level, error_message = 'WARNING', "请求超时"
    else:
        level, error_message = 'ERROR', f"发生未知错误: {error}"
    _emit_gemini_log(level, error_message, {'error_message': error_message})
    return error_message
|
|
|
|
|
|
async def test_api_key(api_key: str) -> bool:
    """
    Check whether *api_key* is accepted by the Gemini API.

    Lists the available models using the key. A 2xx response means the key is
    usable. Every failure (error status, network error, timeout, ...) is
    reported as False rather than raised.

    The key is sent in the ``x-goog-api-key`` header rather than the ``key``
    query parameter. This keeps it out of the request URL, which httpx may
    write to its INFO-level request log.
    """
    url = "https://generativelanguage.googleapis.com/v1beta/models"
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(url, headers={"x-goog-api-key": api_key})
            response.raise_for_status()
        return True
    except Exception:
        # Deliberate best-effort probe: any failure means "key not usable".
        return False
|
|
|
|
|
|
# Counters shared by protect_from_abuse: key -> (count, timestamp of the first request
# in that key's bucket). Each key embeds its minute/day bucket number, so a key is
# never reused once its bucket has passed.
rate_limit_data = {}
rate_limit_lock = Lock()

# An entry whose first request is at least this many seconds old belongs to a past
# minute bucket AND a past day bucket, so it can be deleted safely.
_RATE_LIMIT_RETENTION_SECONDS = 86400
# Minute bucket in which rate_limit_data was last pruned (pruning runs at most once per minute).
_rate_limit_last_prune_minute = None


def _prune_rate_limit_data(now):
    """Delete counters whose first request is at least a day old. Caller must hold rate_limit_lock."""
    stale_keys = [
        key for key, (_, first_seen) in rate_limit_data.items()
        if now - first_seen >= _RATE_LIMIT_RETENTION_SECONDS
    ]
    for key in stale_keys:
        del rate_limit_data[key]


def protect_from_abuse(request: Request, max_requests_per_minute: int = 30, max_requests_per_day_per_ip: int = 600):
    """Enforce simple in-memory, fixed-window rate limits for *request*.

    Two counters are kept:
        * per URL path per minute (shared by all clients hitting that path),
          limited to ``max_requests_per_minute``;
        * per client host per epoch day (UTC), limited to
          ``max_requests_per_day_per_ip``.

    Every call is counted, including calls that end up rejected. Counters
    older than a day are pruned, at most once per minute, so memory stays
    bounded.

    Raises:
        HTTPException: status 429 with a ``detail`` dict when a limit is
            exceeded. The minute limit is checked first.
    """
    global _rate_limit_last_prune_minute

    now = int(time.time())
    minute = now // 60
    day = now // (60 * 60 * 24)

    # request.client is None when the ASGI scope carries no client address.
    client_host = request.client.host if request.client is not None else "unknown"
    minute_key = f"{request.url.path}:{minute}"
    day_key = f"{client_host}:{day}"

    with rate_limit_lock:
        if _rate_limit_last_prune_minute != minute:
            _prune_rate_limit_data(now)
            _rate_limit_last_prune_minute = minute

        # Keys are bucketed, so a counter never needs an in-place reset: a new
        # minute/day simply produces a new key.
        minute_count, minute_started = rate_limit_data.get(minute_key, (0, now))
        minute_count += 1
        rate_limit_data[minute_key] = (minute_count, minute_started)

        day_count, day_started = rate_limit_data.get(day_key, (0, now))
        day_count += 1
        rate_limit_data[day_key] = (day_count, day_started)

    if minute_count > max_requests_per_minute:
        raise HTTPException(status_code=429, detail={
            "message": "Too many requests per minute", "limit": max_requests_per_minute})
    if day_count > max_requests_per_day_per_ip:
        raise HTTPException(status_code=429, detail={"message": "Too many requests per day from this IP", "limit": max_requests_per_day_per_ip})