gemini / app /utils.py
FISHYA's picture
Upload 10 files
5b92ade verified
import random
from fastapi import HTTPException, Request
import time
import re
from datetime import datetime, timedelta
from apscheduler.schedulers.background import BackgroundScheduler
import os
import requests
import httpx
from threading import Lock
import logging
import sys
# Debug mode is switched on by setting the DEBUG environment variable to "true" (any case).
DEBUG = os.getenv("DEBUG", "false").lower() == "true"

# Templates used by format_log_message(): the debug layout adds the timestamp,
# the level name and the error detail around the request context.
LOG_FORMAT_DEBUG = '%(asctime)s - %(levelname)s - [%(key)s]-%(request_type)s-[%(model)s]-%(status_code)s: %(message)s - %(error_message)s'
LOG_FORMAT_NORMAL = '[%(key)s]-%(request_type)s-[%(model)s]-%(status_code)s: %(message)s'

# Application logger, open to every level. In this file messages are passed in
# already rendered by format_log_message(), so the handler keeps logging's
# default formatter.
handler = logging.StreamHandler()
logger = logging.getLogger("my_logger")
logger.setLevel(logging.DEBUG)
logger.addHandler(handler)
def format_log_message(level, message, extra=None):
    """Render one log line from *level*, *message* and optional request context.

    ``extra`` may provide 'key', 'request_type', 'model', 'status_code' and
    'error_message'. Missing entries become 'N/A', except 'error_message',
    which becomes an empty string. A falsy ``extra`` is treated as empty.
    The layout is LOG_FORMAT_DEBUG when DEBUG is true, LOG_FORMAT_NORMAL
    otherwise.
    """
    context = extra if extra else {}
    fields = {
        'asctime': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        'levelname': level,
    }
    for name in ('key', 'request_type', 'model', 'status_code'):
        fields[name] = context.get(name, 'N/A')
    fields['error_message'] = context.get('error_message', '')
    fields['message'] = message
    template = LOG_FORMAT_DEBUG if DEBUG else LOG_FORMAT_NORMAL
    return template % fields
class APIKeyManager:
    """Hand out Gemini API keys in shuffled order without repeating one within a request attempt.

    Keys are extracted from the GEMINI_API_KEYS environment variable and kept on
    a shuffled stack. When the stack is exhausted it is rebuilt from the full list.
    """

    def __init__(self):
        raw_keys = os.environ.get('GEMINI_API_KEYS', "")
        # Pick out every substring shaped like a Google API key; separators are ignored.
        self.api_keys = re.findall(r"AIzaSy[a-zA-Z0-9_-]{33}", raw_keys)
        self.key_stack = []
        self._reset_key_stack()
        # Key blacklisting is disabled, so nothing in this file schedules jobs here.
        # The scheduler is still created and started so the attribute stays available.
        self.scheduler = BackgroundScheduler()
        self.scheduler.start()
        # Keys already handed out during the current request attempt.
        # NOTE(review): this lives on the instance, so if one manager serves concurrent
        # requests they share (and reset) the same set -- confirm that is intended.
        self.tried_keys_for_request = set()

    def _reset_key_stack(self):
        """Replace the key stack with a freshly shuffled copy of api_keys."""
        fresh = list(self.api_keys)  # copy so api_keys keeps its configured order
        random.shuffle(fresh)
        self.key_stack = fresh

    def _pop_untried_key(self):
        """Pop keys until one not yet tried turns up; mark and return it, or None if the stack drains."""
        while self.key_stack:
            candidate = self.key_stack.pop()
            if candidate in self.tried_keys_for_request:
                continue  # already used in this attempt; drop it from the stack
            self.tried_keys_for_request.add(candidate)
            return candidate
        return None

    def get_available_key(self):
        """Return an untried key, rebuilding the stack once if needed; None when every key was tried."""
        key = self._pop_untried_key()
        if key is not None:
            return key
        if not self.api_keys:
            logger.error(format_log_message('ERROR', "没有配置任何 API 密钥!"))
            return None
        # Stack ran dry: reshuffle all keys and give it exactly one more pass.
        self._reset_key_stack()
        return self._pop_untried_key()

    def show_all_keys(self):
        """Log how many keys were loaded, then a masked form of each one."""
        total = len(self.api_keys)
        logger.info(format_log_message('INFO', f"当前可用API key个数: {total} "))
        for idx, key in enumerate(self.api_keys):
            masked = f"{key[:8]}...{key[-3:]}"
            logger.info(format_log_message('INFO', f"API Key{idx}: {masked}"))

    def reset_tried_keys_for_request(self):
        """Forget the keys tried so far, so the next request attempt may use any key again."""
        self.tried_keys_for_request = set()
def _log_gemini_error(level, message, extra):
    """Render *message* with format_log_message and emit it at ERROR or WARNING level."""
    log_msg = format_log_message(level, message, extra=extra)
    if level == 'ERROR':
        logger.error(log_msg)
    else:
        logger.warning(log_msg)


# Non-400 HTTP statuses that get a dedicated message:
# status -> (log level, log text after the key hint, logged error_message, value returned to caller)
_GEMINI_HTTP_ERRORS = {
    429: ('WARNING', "429 官方资源耗尽或其他原因", "API 密钥配额已用尽或其他原因", "API 密钥配额已用尽或其他原因"),
    403: ('ERROR', "403 权限被拒绝", "权限被拒绝", "权限被拒绝"),
    500: ('WARNING', "500 服务器内部错误", "服务器内部错误", "Gemini API 内部错误"),
    503: ('WARNING', "503 服务不可用", "服务不可用", "Gemini API 服务不可用"),
}


def _is_invalid_api_key_error(error_body):
    """Return True if a Google API error object reports that the API key itself is invalid."""
    # Google signals a bad key with HTTP 400, status "INVALID_ARGUMENT" and an
    # ErrorInfo detail whose reason is "API_KEY_INVALID". Its "code" field holds
    # the numeric HTTP status, so comparing it to "invalid_argument" never matches
    # real responses; that check is kept only for payloads built around it.
    if error_body.get('code') == "invalid_argument":
        return True
    details = error_body.get('details')
    if not isinstance(details, list):
        return False
    return any(isinstance(d, dict) and d.get('reason') == "API_KEY_INVALID" for d in details)


def _handle_bad_request(response, api_key, status_code):
    """Classify a 400 response as an invalid key or a generic bad request, log it, and describe it."""
    try:
        error_data = response.json()
    except ValueError:
        error_message = "400 错误请求:响应不是有效的JSON格式"
        _log_gemini_error('WARNING', error_message,
                          {'key': api_key[:8], 'status_code': status_code, 'error_message': error_message})
        return error_message

    error_body = error_data.get('error') if isinstance(error_data, dict) else None
    if not isinstance(error_body, dict):
        # A missing or non-object "error" field is treated as an empty error body,
        # so the caller always receives a string.
        error_body = {}

    if _is_invalid_api_key_error(error_body):
        error_message = "无效的 API 密钥"
        _log_gemini_error('ERROR', f"{api_key[:8]} ... {api_key[-3:]} → 无效,可能已过期或被删除",
                          {'key': api_key[:8], 'status_code': status_code, 'error_message': error_message})
        return error_message

    error_message = error_body.get('message', 'Bad Request')
    _log_gemini_error('WARNING', f"400 错误请求: {error_message}",
                      {'key': api_key[:8], 'status_code': status_code, 'error_message': error_message})
    return f"400 错误请求: {error_message}"


def handle_gemini_error(error, current_api_key, key_manager) -> str:
    """Log a failed Gemini request and return a short description of the failure.

    Args:
        error: The exception raised by the request. requests' HTTPError,
            ConnectionError and Timeout are classified; anything else is
            reported as an unknown error.
        current_api_key: The key used for the request. Only its first 8 and
            last 3 characters are logged; None is treated as an empty string.
        key_manager: Currently unused (key blacklisting is disabled); kept so
            existing callers keep working.

    Returns:
        A description string. It is never None.
    """
    if isinstance(error, requests.exceptions.HTTPError):
        api_key = current_api_key or ""
        status_code = error.response.status_code
        if status_code == 400:
            return _handle_bad_request(error.response, api_key, status_code)

        if status_code in _GEMINI_HTTP_ERRORS:
            level, log_text, error_message, result = _GEMINI_HTTP_ERRORS[status_code]
        else:
            level = 'WARNING'
            log_text = f"{status_code} 未知错误"
            error_message = f"未知错误: {status_code}"
            result = f"未知错误/模型不可用: {status_code}"
        _log_gemini_error(level, f"{api_key[:8]} ... {api_key[-3:]} → {log_text}",
                          {'key': api_key[:8], 'status_code': status_code, 'error_message': error_message})
        return result

    # ConnectionError is tested before Timeout, so requests' ConnectTimeout
    # (a subclass of both) is reported as a connection error.
    if isinstance(error, requests.exceptions.ConnectionError):
        level, error_message = 'WARNING', "连接错误"
    elif isinstance(error, requests.exceptions.Timeout):
        level, error_message = 'WARNING', "请求超时"
    else:
        level, error_message = 'ERROR', f"发生未知错误: {error}"
    _log_gemini_error(level, error_message, {'error_message': error_message})
    return error_message
async def test_api_key(api_key: str) -> bool:
    """Check whether *api_key* is accepted by the Gemini API.

    Sends a GET to the v1beta ``models`` listing with the key in the query
    string. Returns False on any failure: a status rejected by
    ``raise_for_status()``, a network error, or any other exception.
    """
    try:
        models_url = f"https://generativelanguage.googleapis.com/v1beta/models?key={api_key}"
        async with httpx.AsyncClient() as http:
            reply = await http.get(models_url)
            reply.raise_for_status()
    except Exception:
        return False
    return True
# Counters used by protect_from_abuse(): key -> (count, timestamp of the window's first request).
rate_limit_data = dict()
# Held by protect_from_abuse() around its reads and writes of rate_limit_data.
rate_limit_lock = Lock()
def _prune_rate_limit_data(minute, day):
    """Delete counters whose minute/day window has ended; the caller must hold rate_limit_lock."""
    # Every live key ends with ":<current minute>" or ":<current day>". Anything
    # else belongs to a finished window and would otherwise be kept forever.
    # NOTE(review): assumes only protect_from_abuse writes to rate_limit_data.
    live_suffixes = (f":{minute}", f":{day}")
    for stale_key in [k for k in rate_limit_data if not k.endswith(live_suffixes)]:
        del rate_limit_data[stale_key]


def protect_from_abuse(request: Request, max_requests_per_minute: int = 30, max_requests_per_day_per_ip: int = 600):
    """Apply fixed-window rate limits to *request*; raise HTTP 429 when one is exceeded.

    Two counters are kept in ``rate_limit_data``:
      * per URL path per wall-clock minute (shared by all clients), and
      * per client IP per UTC day (epoch-based day boundary).
    Both counters are incremented before either limit is checked, so a rejected
    request still counts. Counters from finished windows are evicted so memory
    stays bounded.

    Raises:
        HTTPException: status 429 when the per-minute or per-day limit is exceeded.
    """
    now = int(time.time())
    minute = now // 60
    day = now // (60 * 60 * 24)
    minute_key = f"{request.url.path}:{minute}"
    # Starlette sets request.client to None when the peer address is unknown;
    # such requests share one day bucket instead of crashing.
    client_host = request.client.host if request.client is not None else "unknown"
    day_key = f"{client_host}:{day}"

    with rate_limit_lock:
        # The first request for a path in a new minute is a cheap, regular point
        # to evict old counters (at most once per path per minute).
        if minute_key not in rate_limit_data:
            _prune_rate_limit_data(minute, day)

        minute_count, minute_timestamp = rate_limit_data.get(minute_key, (0, now))
        # Defensive only: keys are bucketed, so a stored window start should
        # never be a full window old.
        if now - minute_timestamp >= 60:
            minute_count, minute_timestamp = 0, now
        minute_count += 1
        rate_limit_data[minute_key] = (minute_count, minute_timestamp)

        day_count, day_timestamp = rate_limit_data.get(day_key, (0, now))
        if now - day_timestamp >= 86400:
            day_count, day_timestamp = 0, now
        day_count += 1
        rate_limit_data[day_key] = (day_count, day_timestamp)

    if minute_count > max_requests_per_minute:
        raise HTTPException(status_code=429, detail={
            "message": "Too many requests per minute", "limit": max_requests_per_minute})
    if day_count > max_requests_per_day_per_ip:
        raise HTTPException(status_code=429, detail={"message": "Too many requests per day from this IP", "limit": max_requests_per_day_per_ip})