"""Shared helpers: structured log formatting, Gemini API key rotation,
Gemini error classification and a simple in-memory request rate limiter."""
import logging
import os
import random
import re
import sys
import time
from datetime import datetime, timedelta
from threading import Lock

import httpx
import requests
from apscheduler.schedulers.background import BackgroundScheduler
from fastapi import HTTPException, Request

DEBUG = os.environ.get("DEBUG", "false").lower() == "true"

LOG_FORMAT_DEBUG = '%(asctime)s - %(levelname)s - [%(key)s]-%(request_type)s-[%(model)s]-%(status_code)s: %(message)s - %(error_message)s'
LOG_FORMAT_NORMAL = '[%(key)s]-%(request_type)s-[%(model)s]-%(status_code)s: %(message)s'

# Configure the module logger. The formatting itself is done by
# format_log_message(), so the handler uses the default (message-only) formatter.
logger = logging.getLogger("my_logger")
logger.setLevel(logging.DEBUG)
if not logger.handlers:
    # Guard so a re-import/reload does not stack duplicate handlers
    # (which would print every record more than once).
    logger.addHandler(logging.StreamHandler())


def format_log_message(level, message, extra=None):
    """Render a log line using the DEBUG or NORMAL format.

    Args:
        level: Level name written into the ``levelname`` field (DEBUG format only).
        message: The main log text.
        extra: Optional dict that may supply ``key``, ``request_type``, ``model``,
            ``status_code`` and ``error_message``; missing fields become 'N/A'
            (or '' for ``error_message``).

    Returns:
        The fully formatted string; the caller passes it to the logger.
    """
    extra = extra or {}
    log_values = {
        'asctime': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        'levelname': level,
        'key': extra.get('key', 'N/A'),
        'request_type': extra.get('request_type', 'N/A'),
        'model': extra.get('model', 'N/A'),
        'status_code': extra.get('status_code', 'N/A'),
        'error_message': extra.get('error_message', ''),
        'message': message,
    }
    log_format = LOG_FORMAT_DEBUG if DEBUG else LOG_FORMAT_NORMAL
    return log_format % log_values


# Gemini keys start with "AIzaSy" followed by 33 URL-safe characters; keys are
# extracted by pattern, so any separator may be used in GEMINI_API_KEYS.
_API_KEY_PATTERN = re.compile(r"AIzaSy[a-zA-Z0-9_-]{33}")


class APIKeyManager:
    """Hand out configured Gemini API keys in random order.

    Keys are served from a shuffled stack; keys already handed out during the
    current request attempt (``tried_keys_for_request``) are skipped until
    ``reset_tried_keys_for_request`` is called.
    """

    def __init__(self):
        self.api_keys = _API_KEY_PATTERN.findall(os.environ.get('GEMINI_API_KEYS', ""))
        self.key_stack = []  # shuffled copy of api_keys; keys are popped from the end
        self._reset_key_stack()
        # NOTE: the scheduler backed a (now removed) temporary key-blacklist
        # feature. Nothing visible in this module schedules jobs on it; it is
        # still created and started so the attribute remains available.
        self.scheduler = BackgroundScheduler()
        self.scheduler.start()
        # Keys already handed out during the current request attempt.
        self.tried_keys_for_request = set()

    def _reset_key_stack(self):
        """Rebuild the key stack as a freshly shuffled copy of ``api_keys``."""
        shuffled_keys = self.api_keys[:]  # copy so api_keys keeps its original order
        random.shuffle(shuffled_keys)
        self.key_stack = shuffled_keys

    def _pop_untried_key(self):
        """Pop keys until one not yet tried for this request is found.

        Keys that were already tried are discarded from the stack. The returned
        key is recorded as tried. Returns None once the stack is exhausted.
        """
        while self.key_stack:
            key = self.key_stack.pop()
            if key not in self.tried_keys_for_request:
                self.tried_keys_for_request.add(key)
                return key
        return None

    def get_available_key(self):
        """Return an API key not yet tried for the current request, or None.

        When the stack runs out it is reshuffled once from ``api_keys``. None is
        returned if no keys are configured (an error is logged) or if every key
        has already been tried for this request.
        """
        key = self._pop_untried_key()
        if key is not None:
            return key

        if not self.api_keys:
            log_msg = format_log_message('ERROR', "没有配置任何 API 密钥!")
            logger.error(log_msg)
            return None

        self._reset_key_stack()
        # One more pass over the freshly shuffled stack.
        return self._pop_untried_key()

    def show_all_keys(self):
        """Log the number of configured keys and a masked form of each."""
        log_msg = format_log_message('INFO', f"当前可用API key个数: {len(self.api_keys)} ")
        logger.info(log_msg)
        for i, api_key in enumerate(self.api_keys):
            log_msg = format_log_message('INFO', f"API Key{i}: {api_key[:8]}...{api_key[-3:]}")
            logger.info(log_msg)

    def reset_tried_keys_for_request(self):
        """Forget which keys were tried; call at the start of a new request attempt."""
        self.tried_keys_for_request = set()


def _mask_key(api_key):
    """Return (first 8 chars, last 3 chars) of a key for logging.

    Returns ('N/A', 'N/A') when no key string is available (e.g. None).
    """
    if not isinstance(api_key, str):
        return 'N/A', 'N/A'
    return api_key[:8], api_key[-3:]


def _log_key_event(level, message, api_key, status_code, error_message):
    """Format and emit a key-related log line at the given level name."""
    extra = {'key': _mask_key(api_key)[0], 'status_code': status_code,
             'error_message': error_message}
    logger.log(getattr(logging, level), format_log_message(level, message, extra=extra))


def _is_invalid_key_error(error_body):
    """Return True if a Google API error body reports an invalid API key.

    Google APIs report a bad key with an ErrorInfo detail whose ``reason`` is
    "API_KEY_INVALID" (``code`` itself is the numeric HTTP status). The legacy
    string comparison on ``code`` is kept for compatibility.
    """
    if error_body.get('code') == "invalid_argument":
        return True
    details = error_body.get('details')
    if not isinstance(details, list):
        return False
    return any(isinstance(detail, dict) and detail.get('reason') == "API_KEY_INVALID"
               for detail in details)


def _handle_bad_request(response, api_key, status_code):
    """Classify a 400 response, log it, and return the caller-facing message."""
    try:
        error_data = response.json()
    except ValueError:  # requests' JSONDecodeError subclasses ValueError
        error_message = "400 错误请求:响应不是有效的JSON格式"
        _log_key_event('WARNING', error_message, api_key, status_code, error_message)
        return error_message

    # Tolerate payloads that are not {"error": {...}} instead of crashing.
    error_body = error_data.get('error') if isinstance(error_data, dict) else None
    if not isinstance(error_body, dict):
        error_body = {}

    if _is_invalid_key_error(error_body):
        error_message = "无效的 API 密钥"
        head, tail = _mask_key(api_key)
        _log_key_event('ERROR', f"{head} ... {tail} → 无效,可能已过期或被删除",
                       api_key, status_code, error_message)
        return error_message

    error_message = error_body.get('message') or 'Bad Request'
    _log_key_event('WARNING', f"400 错误请求: {error_message}", api_key, status_code, error_message)
    return f"400 错误请求: {error_message}"


# HTTP status -> (log level, error_message recorded in the log,
#                 text after the masked key in the log line, value returned to the caller).
_STATUS_ERRORS = {
    429: ('WARNING', "API 密钥配额已用尽或其他原因", "429 官方资源耗尽或其他原因",
          "API 密钥配额已用尽或其他原因"),
    403: ('ERROR', "权限被拒绝", "403 权限被拒绝", "权限被拒绝"),
    500: ('WARNING', "服务器内部错误", "500 服务器内部错误", "Gemini API 内部错误"),
    503: ('WARNING', "服务不可用", "503 服务不可用", "Gemini API 服务不可用"),
}


def _handle_http_error(response, api_key):
    """Classify an HTTP error response, log it, and return the caller-facing message."""
    status_code = response.status_code
    if status_code == 400:
        return _handle_bad_request(response, api_key, status_code)

    head, tail = _mask_key(api_key)
    entry = _STATUS_ERRORS.get(status_code)
    if entry is not None:
        level, error_message, log_suffix, result = entry
        _log_key_event(level, f"{head} ... {tail} → {log_suffix}", api_key, status_code, error_message)
        return result

    error_message = f"未知错误: {status_code}"
    _log_key_event('WARNING', f"{head} ... {tail} → {status_code} 未知错误",
                   api_key, status_code, error_message)
    return f"未知错误/模型不可用: {status_code}"


def handle_gemini_error(error, current_api_key, key_manager) -> str:
    """Log a failed Gemini call and return a human-readable error message.

    Args:
        error: The exception raised by the ``requests`` call.
        current_api_key: The key used for the failed call (may be None).
        key_manager: Unused; kept for call-site compatibility (it was used by
            the removed key-blacklist feature).

    Returns:
        A message describing the failure; always a string.
    """
    response = getattr(error, 'response', None)
    if isinstance(error, requests.exceptions.HTTPError) and response is not None:
        return _handle_http_error(response, current_api_key)

    if isinstance(error, requests.exceptions.ConnectionError):
        level, error_message = 'WARNING', "连接错误"
    elif isinstance(error, requests.exceptions.Timeout):
        level, error_message = 'WARNING', "请求超时"
    else:
        level, error_message = 'ERROR', f"发生未知错误: {error}"
    log_msg = format_log_message(level, error_message, extra={'error_message': error_message})
    logger.log(getattr(logging, level), log_msg)
    return error_message


async def test_api_key(api_key: str) -> bool:
    """Return True if listing models with ``api_key`` succeeds, else False.

    Any failure (HTTP error status, network error, ...) is deliberately
    treated as "key not usable" rather than propagated.
    """
    try:
        url = "https://generativelanguage.googleapis.com/v1beta/models?key={}".format(api_key)
        async with httpx.AsyncClient() as client:
            response = await client.get(url)
            response.raise_for_status()
            return True
    except Exception:
        return False


# Counters keyed by "<path>:<minute index>" and "<client host>:<day index>",
# each mapped to (count, timestamp of the window's first request).
rate_limit_data = {}
rate_limit_lock = Lock()

# No counter can still be current once it is this old (the longest window is one day).
_RATE_LIMIT_RETENTION_SECONDS = 86400
# Sweep stale counters at most this often, so cleanup cost is amortised.
_RATE_LIMIT_PRUNE_INTERVAL_SECONDS = 60
_last_rate_limit_prune = 0


def _prune_rate_limit_data(now):
    """Delete counters whose window has expired. Caller must hold rate_limit_lock.

    A live counter's timestamp always lies inside its own minute/day bucket, so
    it is younger than the retention period and is never removed here.
    """
    stale_keys = [key for key, (_, timestamp) in rate_limit_data.items()
                  if now - timestamp >= _RATE_LIMIT_RETENTION_SECONDS]
    for key in stale_keys:
        del rate_limit_data[key]


def protect_from_abuse(request: Request, max_requests_per_minute: int = 30, max_requests_per_day_per_ip: int = 600):
    """Count this request and raise HTTP 429 if a limit is exceeded.

    The per-minute limit is counted per URL path across all clients; the
    per-day limit is counted per client host. Both counters are incremented
    before the limits are checked.

    Raises:
        HTTPException: status 429 when either limit is exceeded.
    """
    global _last_rate_limit_prune

    now = int(time.time())
    minute = now // 60
    day = now // (60 * 60 * 24)
    # Starlette may report no client (e.g. some test clients / ASGI servers).
    client_host = request.client.host if request.client is not None else "unknown"
    minute_key = f"{request.url.path}:{minute}"
    day_key = f"{client_host}:{day}"

    with rate_limit_lock:
        if now - _last_rate_limit_prune >= _RATE_LIMIT_PRUNE_INTERVAL_SECONDS:
            _prune_rate_limit_data(now)
            _last_rate_limit_prune = now

        minute_count, minute_timestamp = rate_limit_data.get(minute_key, (0, now))
        if now - minute_timestamp >= 60:
            minute_count = 0
            minute_timestamp = now
        minute_count += 1
        rate_limit_data[minute_key] = (minute_count, minute_timestamp)

        day_count, day_timestamp = rate_limit_data.get(day_key, (0, now))
        if now - day_timestamp >= 86400:
            day_count = 0
            day_timestamp = now
        day_count += 1
        rate_limit_data[day_key] = (day_count, day_timestamp)

    if minute_count > max_requests_per_minute:
        raise HTTPException(status_code=429, detail={
            "message": "Too many requests per minute", "limit": max_requests_per_minute})
    if day_count > max_requests_per_day_per_ip:
        raise HTTPException(status_code=429, detail={"message": "Too many requests per day from this IP",
                                                     "limit": max_requests_per_day_per_ip})