# seekr / src/utils/api_key_manager.py
# Hemang Thakur
# fixed chat anthropic
# c8abe84
from threading import Lock
import os
from typing import List, Optional, Literal, Union, Dict
from dotenv import load_dotenv
import re
from langchain_xai import ChatXAI
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
from functools import wraps
import time
from openai import RateLimitError, OpenAIError
from anthropic import RateLimitError as AnthropicRateLimitError, APIError as AnthropicAPIError
from google.api_core.exceptions import ResourceExhausted, BadRequest, InvalidArgument
from tenacity import retry, wait_random_exponential, stop_after_attempt, retry_if_exception_type
import asyncio
# Provider identifiers accepted by APIKeyManager; these are the same four
# names used as keys of APIKeyManager.SUPPORTED_MODELS.
ModelProvider = Literal[
    "openai",
    "anthropic",
    "google",
    "xai",
]
class APIKeyManager:
    """Singleton that loads provider API keys and hands out LangChain chat models.

    Keys are read from the environment (after ``load_dotenv(override=True)``)
    and handed out round-robin per provider. The provider, model name,
    temperature and top_p defaults come from ``MODEL_PROVIDER``,
    ``MODEL_NAME``, ``MODEL_TEMPERATURE`` and ``MODEL_TOP_P``.
    """

    _instance = None
    # Class-level lock guarding singleton creation in __new__.
    _lock = Lock()

    # Env-var prefix per provider. Numbered variables ``<PREFIX>_<n>`` are
    # used when present; otherwise the single ``<PREFIX>`` variable is used.
    # Order here is the order of the "Loaded ... API keys" log lines.
    _KEY_ENV_PREFIXES = {
        "openai": "OPENAI_API_KEY",
        "anthropic": "ANTHROPIC_API_KEY",
        "google": "GOOGLE_API_KEY",
        "xai": "XAI_API_KEY",
    }

    # Define supported models
    SUPPORTED_MODELS = {
        "openai": [
            "gpt-3.5-turbo",
            "gpt-3.5-turbo-instruct",
            "gpt-3.5-turbo-1106",
            "gpt-3.5-turbo-0125",
            "gpt-4-0314",
            "gpt-4-0613",
            "gpt-4",
            "gpt-4-1106-preview",
            "gpt-4-0125-preview",
            "gpt-4-turbo-preview",
            "gpt-4-turbo-2024-04-09",
            "gpt-4-turbo",
            "o1-mini-2024-09-12",
            "o1-mini",
            "o1-preview-2024-09-12",
            "o1-preview",
            "o1",
            "gpt-4o-mini-2024-07-18",
            "gpt-4o-mini",
            "chatgpt-4o-latest",
            "gpt-4o-2024-05-13",
            "gpt-4o-2024-08-06",
            "gpt-4o-2024-11-20",
            "gpt-4o"
        ],
        "google": [
            "gemini-1.5-flash",
            "gemini-1.5-flash-latest",
            "gemini-1.5-flash-exp-0827",
            "gemini-1.5-flash-001",
            "gemini-1.5-flash-002",
            "gemini-1.5-flash-8b-exp-0924",
            "gemini-1.5-flash-8b-exp-0827",
            "gemini-1.5-flash-8b-001",
            "gemini-1.5-flash-8b",
            "gemini-1.5-flash-8b-latest",
            "gemini-1.5-pro",
            "gemini-1.5-pro-latest",
            "gemini-1.5-pro-001",
            "gemini-1.5-pro-002",
            "gemini-1.5-pro-exp-0827",
            "gemini-1.0-pro",
            "gemini-1.0-pro-latest",
            "gemini-1.0-pro-001",
            "gemini-pro",
            "gemini-exp-1114",
            "gemini-exp-1121",
            "gemini-2.0-pro-exp-02-05",
            "gemini-2.0-flash-lite-preview-02-05",
            "gemini-2.0-flash-exp",
            "gemini-2.0-flash",
            "gemini-2.0-flash-thinking-exp-1219",
        ],
        "xai": [
            "grok-beta",
            "grok-vision-beta",
            "grok-2-vision-1212",
            "grok-2-1212"
        ],
        "anthropic": [
            "claude-3-5-sonnet-20241022",
            "claude-3-5-sonnet-latest",
            "claude-3-5-haiku-20241022",
            "claude-3-5-haiku-latest",
            "claude-3-opus-20240229",
            "claude-3-opus-latest",
            "claude-3-sonnet-20240229",
            "claude-3-haiku-20240307"
        ]
    }

    def __new__(cls):
        """Return the single shared instance, creating it on first use."""
        with cls._lock:
            if cls._instance is None:
                cls._instance = super(APIKeyManager, cls).__new__(cls)
                cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        """Load keys and model defaults once; later constructions are no-ops.

        ``_initialized`` is set only after everything loaded successfully, so
        a failure (e.g. no API keys) does not leave a half-built singleton
        that later constructions would silently reuse.

        Raises:
            Exception: if no API keys are found for any provider.
        """
        if self._initialized:
            return
        # 1) Always load env
        load_dotenv(override=True)
        self._current_indices = {
            "openai": 0,
            "anthropic": 0,
            "google": 0,
            "xai": 0
        }
        # Instance lock guarding round-robin index updates.
        self._lock = Lock()
        # 2) load all provider keys from environment
        self._api_keys = self._load_api_keys()
        self._llm = None
        self._current_provider = None
        # Settings the cached _llm was built with:
        # (provider, model_name, temperature, top_p, max_tokens, streaming).
        self._llm_config = None
        # 3) read user's chosen model, temperature, top_p from env
        # (the provider itself is re-read on every call, see _get_provider_for_model)
        self.model_name = os.getenv("MODEL_NAME", "gpt-3.5-turbo").strip()
        try:
            self.temperature = float(os.getenv("MODEL_TEMPERATURE", "0"))
        except ValueError:
            self.temperature = 0.0
        try:
            self.top_p = float(os.getenv("MODEL_TOP_P", "1"))
        except ValueError:
            self.top_p = 1.0
        self._initialized = True

    def _reinit(self):
        """Force a full reload of env, keys and defaults on this instance."""
        self._initialized = False
        self.__init__()

    @staticmethod
    def _keys_from_env(env_vars: Dict[str, str], prefix: str) -> List[str]:
        """Return one provider's keys from ``env_vars``.

        Numbered variables ``<prefix>_<n>`` are returned ordered by ``n``;
        if there are none, the single ``<prefix>`` variable is used. Blank
        values are skipped and surrounding whitespace is stripped.
        """
        pattern = re.compile(rf"{re.escape(prefix)}_\d+$")
        numbered = {
            name: value.strip()
            for name, value in env_vars.items()
            if pattern.match(name) and value.strip()
        }
        if numbered:
            ordered = sorted(numbered, key=lambda name: int(name.rsplit("_", 1)[-1]))
            return [numbered[name] for name in ordered]
        default_key = (env_vars.get(prefix) or "").strip()
        return [default_key] if default_key else []

    def _load_api_keys(self) -> Dict[str, List[str]]:
        """Load API keys for every provider from environment variables.

        Returns:
            Mapping of provider name to its ordered list of keys (possibly empty).

        Raises:
            Exception: if no provider has any key.
        """
        env_vars = dict(os.environ)
        api_keys = {
            provider: self._keys_from_env(env_vars, prefix)
            for provider, prefix in self._KEY_ENV_PREFIXES.items()
        }
        if not any(api_keys.values()):
            raise Exception("No valid API keys found in environment variables")
        for provider, keys in api_keys.items():
            if keys:
                print(f"Loaded {len(keys)} {provider} API keys for rotation")
        return api_keys

    def get_next_api_key(self, provider: ModelProvider) -> str:
        """Get the next API key in round-robin fashion for the specified provider.

        Raises:
            Exception: if the provider has no keys.
        """
        with self._lock:
            keys = self._api_keys.get(provider)
            if not keys:
                raise Exception(f"No API key found for {provider}")
            index = self._current_indices.setdefault(provider, 0)
            self._current_indices[provider] = (index + 1) % len(keys)
            return keys[index]

    def _get_provider_for_model(self) -> ModelProvider:
        """Return ``MODEL_PROVIDER`` from the (re-loaded) env after validating it.

        Raises:
            Exception: if the provider is unknown, or ``self.model_name`` is
                not listed for it in ``SUPPORTED_MODELS``.
        """
        load_dotenv(override=True)  # to refresh in case .env changed
        provider_env = os.getenv("MODEL_PROVIDER", "openai").lower().strip()
        if provider_env not in self.SUPPORTED_MODELS:
            raise Exception(
                f"Invalid or missing MODEL_PROVIDER in env: '{provider_env}'. "
                f"Must be one of: {list(self.SUPPORTED_MODELS.keys())}"
            )
        # check if user-chosen model is in that provider's list
        if self.model_name not in self.SUPPORTED_MODELS[provider_env]:
            available = self.SUPPORTED_MODELS[provider_env]
            raise Exception(
                f"Model '{self.model_name}' is not available under provider '{provider_env}'. "
                f"Available: {available}"
            )
        return provider_env

    def _resolve_settings(
        self,
        model_name: Optional[str],
        temperature: Optional[float],
        top_p: Optional[float],
    ):
        """Fill omitted settings from the env defaults.

        ``None`` checks (not truthiness) so an explicit ``0`` / ``0.0``
        temperature or top_p is honoured instead of being replaced.
        """
        return (
            model_name if model_name else self.model_name,
            self.temperature if temperature is None else temperature,
            self.top_p if top_p is None else top_p,
        )

    @staticmethod
    def _mask_key(api_key: str) -> str:
        """Return a log-safe form of ``api_key`` exposing at most its last 4 characters."""
        return f"...{api_key[-4:]}" if len(api_key) > 8 else "****"

    def _initialize_llm(
        self,
        model_name: Optional[str] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        max_tokens: Optional[int] = None,
        streaming: bool = False
    ):
        """Build a new chat model with the next API key in rotation and cache it."""
        provider = self._get_provider_for_model()  # also refreshes .env
        model_name, temperature, top_p = self._resolve_settings(model_name, temperature, top_p)
        api_key = self.get_next_api_key(provider)
        # Never print the raw key: stdout commonly ends up in logs.
        print(f"Using provider={provider}, model_name={model_name}, "
              f"temperature={temperature}, top_p={top_p}, key={self._mask_key(api_key)}")
        kwargs = {
            "model": model_name,
            "temperature": temperature,
            "top_p": top_p,
            "max_retries": 0,  # retries/rotation are handled by with_api_manager
            "streaming": streaming,
            "api_key": api_key,
        }
        if max_tokens is not None:
            kwargs["max_tokens"] = max_tokens
        llm_class = {
            "openai": ChatOpenAI,
            "google": ChatGoogleGenerativeAI,
            "anthropic": ChatAnthropic,
        }.get(provider, ChatXAI)
        self._llm = llm_class(**kwargs)
        self._current_provider = provider
        self._llm_config = (provider, model_name, temperature, top_p, max_tokens, streaming)

    def get_llm(
        self,
        model_name: Optional[str] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        max_tokens: Optional[int] = None,
        streaming: bool = False
    ) -> Union[ChatOpenAI, ChatGoogleGenerativeAI, ChatAnthropic, ChatXAI]:
        """Return a chat model for the requested settings.

        The cached model is reused only while the provider and every setting
        match those it was built with. Otherwise a new model is built with
        the next key in rotation, so changes such as a reduced ``max_tokens``
        actually take effect.
        """
        provider = self._get_provider_for_model()
        model_name, temperature, top_p = self._resolve_settings(model_name, temperature, top_p)
        requested = (provider, model_name, temperature, top_p, max_tokens, streaming)
        if self._llm is None or requested != self._llm_config:
            self._initialize_llm(model_name, temperature, top_p, max_tokens, streaming)
        return self._llm

    def rotate_key(self, provider: Optional[ModelProvider] = None, streaming: bool = False) -> None:
        """Manually rotate to the next API key.

        The model is rebuilt with the settings of the previously built model
        (model name, temperature, top_p, max_tokens) and the given
        ``streaming`` flag. ``provider`` is accepted for compatibility; the
        provider is always resolved from ``MODEL_PROVIDER``.
        """
        if self._llm_config is None:
            self._initialize_llm(streaming=streaming)
            return
        _, model_name, temperature, top_p, max_tokens, _ = self._llm_config
        self._initialize_llm(model_name, temperature, top_p, max_tokens, streaming)

    def get_all_api_keys(self, provider: Optional[ModelProvider] = None) -> Union[Dict[str, List[str]], List[str]]:
        """Return a copy of one provider's keys, or of all providers' keys."""
        if provider:
            return self._api_keys[provider].copy()
        return {name: keys.copy() for name, keys in self._api_keys.items()}

    def get_key_count(self, provider: Optional[ModelProvider] = None) -> Union[Dict[str, int], int]:
        """Return one provider's key count, or a per-provider mapping of counts."""
        if provider:
            return len(self._api_keys[provider])
        return {name: len(keys) for name, keys in self._api_keys.items()}

    def __len__(self) -> int:
        """Return the total number of API keys loaded across all providers.

        ``len()`` requires an int (returning the per-provider dict made
        ``len(manager)`` raise TypeError); use ``get_key_count()`` for the
        per-provider breakdown.
        """
        return sum(len(keys) for keys in self._api_keys.values())

    def __bool__(self) -> bool:
        """Check if there are any API keys available."""
        return any(bool(keys) for keys in self._api_keys.values())
def with_api_manager(
    model_name: Optional[str] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    max_tokens: Optional[int] = None,
    streaming: bool = False,
    delay_on_timeout: int = 20,
    max_token_reduction_attempts: int = 0
):
    """Decorator for automatic key rotation on error with delay on timeout.

    The decorated function (sync or ``async def``) must accept an ``llm``
    keyword argument; the wrapper injects a chat model obtained from
    ``APIKeyManager.get_llm`` with the settings given here.

    With more than one key for the configured provider:
      * Rate-limit errors rotate to the next key. Once every key has failed,
        the wrapper waits ``delay_on_timeout`` seconds (if > 0) and makes one
        final attempt starting again from the first key. Otherwise it
        re-raises.
      * Errors whose message mentions "token" or "context length" first
        shrink ``max_tokens`` by 20%. This happens up to
        ``max_token_reduction_attempts`` times, and only when ``max_tokens``
        was given. After that the wrapper rotates keys, re-raising once every
        key has failed.
      * Any other handled API error is re-raised immediately.

    With exactly one key, the call is retried with tenacity: random
    exponential wait between 10 and 60 s, at most 6 attempts. The last
    error is then re-raised.

    The manager, provider and key count are resolved once, when the
    decorator is applied, not on every call.

    Raises:
        Exception: at call time, if the provider has no API keys.
    """
    manager = APIKeyManager()
    provider = manager._get_provider_for_model()
    model_name = model_name if model_name else manager.model_name
    # `is None` checks so an explicit 0 / 0.0 is honoured rather than replaced.
    temperature = manager.temperature if temperature is None else temperature
    top_p = manager.top_p if top_p is None else top_p
    key_count = manager.get_key_count(provider)

    rate_limit_errors = (RateLimitError, ResourceExhausted, AnthropicRateLimitError)
    api_errors = (OpenAIError, AnthropicAPIError, BadRequest, InvalidArgument)
    handled_errors = rate_limit_errors + api_errors

    def _mask(api_key: str) -> str:
        """Log-safe form of a key: at most its last 4 characters."""
        return f"...{api_key[-4:]}" if len(api_key) > 8 else "****"

    def _get_llm(current_max_tokens: Optional[int]):
        """Fetch a chat model from the manager with this decorator's settings."""
        return manager.get_llm(
            model_name=model_name,
            temperature=temperature,
            top_p=top_p,
            max_tokens=current_max_tokens,
            streaming=streaming,
        )

    def _no_keys_error() -> Exception:
        """Log and build the error raised when the provider has no keys."""
        message = f"No API keys found for provider: {provider}"
        print(message)
        return Exception(message)

    class _RotationState:
        """Per-call bookkeeping for the multi-key rotation loop."""

        def __init__(self):
            self.key_total = len(manager.get_all_api_keys(provider))
            # Count failures rather than collecting key strings: with
            # duplicate key values a set never fills up and the loop never ends.
            self.failed_keys = 0
            self.max_tokens = max_tokens
            self.reductions = 0

        def keys_left(self) -> bool:
            return self.failed_keys < self.key_total

        def _last_used_key(self) -> str:
            keys = manager._api_keys[provider]
            # get_next_api_key advances the index after handing a key out,
            # so the key in use is the one just before the current index.
            return keys[(manager._current_indices[provider] - 1) % len(keys)]

        def on_rate_limit(self, error) -> int:
            """Record a rate-limit failure.

            Returns 0 after rotating to the next key. Once every key has
            failed, returns the seconds to wait before the final attempt.
            Re-raises ``error`` when every key has failed and no delay is
            configured.
            """
            print(f"Rate limit error with {provider} API key {_mask(self._last_used_key())}: {str(error)}")
            self.failed_keys += 1
            if self.keys_left():
                manager.rotate_key(provider=provider, streaming=streaming)
                print(f"Using next available {provider} API key")
                return 0
            if delay_on_timeout > 0:
                print(f"Waiting for {delay_on_timeout} seconds before retrying with the first key...")
                return delay_on_timeout
            print(f"All {provider} API keys failed due to rate limits: {str(error)}")
            raise error

        def restart_from_first_key(self) -> None:
            """Rebuild the model with the first key for the final attempt."""
            manager._current_indices[provider] = 0
            # Resetting the index alone is not enough: get_llm would keep
            # returning the cached model that was built with the last key.
            manager.rotate_key(provider=provider, streaming=streaming)

        def on_api_error(self, error) -> None:
            """Handle a non-rate-limit API error (see the decorator docstring)."""
            error_str = str(error)
            lowered = error_str.lower()
            if "token" not in lowered and "context length" not in lowered:
                raise error  # Re-raise other API errors
            print(f"Token limit error encountered: {error_str}")
            if (max_token_reduction_attempts > 0
                    and self.max_tokens is not None
                    and self.reductions < max_token_reduction_attempts):
                # Never shrink to 0, which providers reject.
                self.max_tokens = max(1, int(self.max_tokens * 0.8))
                self.reductions += 1
                print(f"Retrying with reduced max_tokens: {self.max_tokens}")
                return  # Retry with reduced max_tokens
            print("Max token reduction attempts reached or token reduction disabled. Proceeding with key rotation.")
            self.failed_keys += 1
            if not self.keys_left():
                raise error  # All keys tried, raise the token limit error
            manager.rotate_key(provider=provider, streaming=streaming)
            print(f"Using next available {provider} API key after token limit error.")

    single_key_retry = retry(
        wait=wait_random_exponential(min=10, max=60),
        stop=stop_after_attempt(6),
        retry=retry_if_exception_type(handled_errors),
        # Re-raise the provider's own exception instead of tenacity.RetryError,
        # so the `except handled_errors` clauses below can see it.
        reraise=True,
    )

    def decorator(func):
        if asyncio.iscoroutinefunction(func):
            @wraps(func)
            async def wrapper(*args, **kwargs):
                if key_count > 1:
                    state = _RotationState()
                    while state.keys_left():
                        try:
                            return await func(*args, **kwargs, llm=_get_llm(state.max_tokens))
                        except rate_limit_errors as e:
                            delay = state.on_rate_limit(e)
                            if delay:
                                # Non-blocking wait: time.sleep would stall the event loop.
                                await asyncio.sleep(delay)
                                state.restart_from_first_key()
                        except api_errors as e:
                            state.on_api_error(e)
                    # Attempt one final time after trying all keys (for rate limits with delay)
                    try:
                        return await func(*args, **kwargs, llm=_get_llm(state.max_tokens))
                    except handled_errors as e:
                        print(f"Error after retrying all {provider} API keys: {str(e)}")
                        raise
                if key_count == 1:
                    @single_key_retry
                    async def attempt_function_call():
                        return await func(*args, **kwargs, llm=_get_llm(max_tokens))
                    try:
                        return await attempt_function_call()
                    except handled_errors as e:
                        print(f"Error encountered for {provider} after multiple retries: {str(e)}")
                        raise
                raise _no_keys_error()
        else:
            @wraps(func)
            def wrapper(*args, **kwargs):
                if key_count > 1:
                    state = _RotationState()
                    while state.keys_left():
                        try:
                            return func(*args, **kwargs, llm=_get_llm(state.max_tokens))
                        except rate_limit_errors as e:
                            delay = state.on_rate_limit(e)
                            if delay:
                                time.sleep(delay)
                                state.restart_from_first_key()
                        except api_errors as e:
                            state.on_api_error(e)
                    # Attempt one final time after trying all keys (for rate limits with delay)
                    try:
                        return func(*args, **kwargs, llm=_get_llm(state.max_tokens))
                    except handled_errors as e:
                        print(f"Error after retrying all {provider} API keys: {str(e)}")
                        raise
                if key_count == 1:
                    @single_key_retry
                    def attempt_function_call():
                        return func(*args, **kwargs, llm=_get_llm(max_tokens))
                    try:
                        return attempt_function_call()
                    except handled_errors as e:
                        print(f"Error encountered for {provider} after multiple retries: {str(e)}")
                        raise
                raise _no_keys_error()
        return wrapper
    return decorator
if __name__ == "__main__":
    # asyncio is already imported at module level.
    prompt = "What is the capital of France?"

    # Test key rotation
    async def test_load_balancing(prompt: str, test_count: int = 10, stream: bool = False):
        """Send ``prompt`` ``test_count`` times through a ``with_api_manager``-wrapped coroutine.

        The decorator injects one chat model for the whole run. With
        ``stream=True`` responses are printed chunk by chunk.
        """

        @with_api_manager(streaming=stream)
        async def test(prompt: str, test_count: int = 10, *, llm):
            print("=" * 50)
            for i in range(test_count):
                # Separator after each response; the last one closes the run.
                separator = "=" * 50 if i == test_count - 1 else "-" * 50
                try:
                    print(f"\nTest {i+1} of {test_count}")
                    if stream:
                        async for chunk in llm.astream(prompt):
                            print(chunk.content, end="", flush=True)
                        print("\n" + separator)
                    else:
                        response = await llm.ainvoke(prompt)
                        print(f"Response: {response.content.strip()}")
                        print(separator)
                except (RateLimitError, ResourceExhausted, AnthropicRateLimitError) as e:
                    # Log, then let with_api_manager handle rotation/retry.
                    print(f"Error: {str(e)}")
                    raise

        await test(prompt, test_count=test_count)

    # Test without load balancing
    def test_without_load_balancing(model_name: str, prompt: str, test_count: int = 10):
        """Invoke ``prompt`` ``test_count`` times directly via ``APIKeyManager.get_llm``."""
        manager = APIKeyManager()
        print(f"Using model: {model_name}")
        print("=" * 50)
        for i in range(test_count):
            try:
                print(f"Test {i+1} of {test_count}")
                llm = manager.get_llm(model_name=model_name)
                response = llm.invoke(prompt)
                print(f"Response: {response.content.strip()}")
                print("=" * 50 if i == test_count - 1 else "-" * 50)
            except Exception as e:
                raise Exception(f"Error with {model_name}: {str(e)}") from e

    # Alternative run without rotation:
    # test_without_load_balancing(model_name="gemini-exp-1121", prompt=prompt, test_count=50)
    asyncio.run(test_load_balancing(prompt=prompt, test_count=100, stream=True))