from threading import Lock import os from typing import List, Optional, Literal, Union, Dict from dotenv import load_dotenv import re from langchain_xai import ChatXAI from langchain_openai import ChatOpenAI from langchain_anthropic import ChatAnthropic from langchain_google_genai import ChatGoogleGenerativeAI from functools import wraps import time from openai import RateLimitError, OpenAIError from anthropic import RateLimitError as AnthropicRateLimitError, APIError as AnthropicAPIError from google.api_core.exceptions import ResourceExhausted, BadRequest, InvalidArgument from tenacity import retry, wait_random_exponential, stop_after_attempt, retry_if_exception_type import asyncio ModelProvider = Literal["openai", "anthropic", "google", "xai"] class APIKeyManager: _instance = None _lock = Lock() # Define supported models SUPPORTED_MODELS = { "openai": [ "gpt-3.5-turbo", "gpt-3.5-turbo-instruct", "gpt-3.5-turbo-1106", "gpt-3.5-turbo-0125", "gpt-4-0314", "gpt-4-0613", "gpt-4", "gpt-4-1106-preview", "gpt-4-0125-preview", "gpt-4-turbo-preview", "gpt-4-turbo-2024-04-09", "gpt-4-turbo", "o1-mini-2024-09-12", "o1-mini", "o1-preview-2024-09-12", "o1-preview", "o1", "gpt-4o-mini-2024-07-18", "gpt-4o-mini", "chatgpt-4o-latest", "gpt-4o-2024-05-13", "gpt-4o-2024-08-06", "gpt-4o-2024-11-20", "gpt-4o" ], "google": [ "gemini-1.5-flash", "gemini-1.5-flash-latest", "gemini-1.5-flash-exp-0827", "gemini-1.5-flash-001", "gemini-1.5-flash-002", "gemini-1.5-flash-8b-exp-0924", "gemini-1.5-flash-8b-exp-0827", "gemini-1.5-flash-8b-001", "gemini-1.5-flash-8b", "gemini-1.5-flash-8b-latest", "gemini-1.5-pro", "gemini-1.5-pro-latest", "gemini-1.5-pro-001", "gemini-1.5-pro-002", "gemini-1.5-pro-exp-0827", "gemini-1.0-pro", "gemini-1.0-pro-latest", "gemini-1.0-pro-001", "gemini-pro", "gemini-exp-1114", "gemini-exp-1121", "gemini-2.0-pro-exp-02-05", "gemini-2.0-flash-lite-preview-02-05", "gemini-2.0-flash-exp", "gemini-2.0-flash", "gemini-2.0-flash-thinking-exp-1219", ], "xai": [ "grok-beta", "grok-vision-beta", "grok-2-vision-1212", "grok-2-1212" ], "anthropic": [ "claude-3-5-sonnet-20241022", "claude-3-5-sonnet-latest", "claude-3-5-haiku-20241022", "claude-3-5-haiku-latest", "claude-3-opus-20240229", "claude-3-opus-latest", "claude-3-sonnet-20240229", "claude-3-haiku-20240307" ] } def __new__(cls): with cls._lock: if cls._instance is None: cls._instance = super(APIKeyManager, cls).__new__(cls) cls._instance._initialized = False return cls._instance def __init__(self): if not self._initialized: self._initialized = True # 1) Always load env load_dotenv(override=True) self._current_indices = { "openai": 0, "anthropic": 0, "google": 0, "xai": 0 } self._lock = Lock() # 2) load all provider keys from environment self._api_keys = self._load_api_keys() self._llm = None self._current_provider = None # 3) read user’s chosen provider, model, temperature, top_p from env provider_env = os.getenv("MODEL_PROVIDER", "openai").strip().lower() self.model_name = os.getenv("MODEL_NAME", "gpt-3.5-turbo").strip() temp_str = os.getenv("MODEL_TEMPERATURE", "0") topp_str = os.getenv("MODEL_TOP_P", "1") try: self.temperature = float(temp_str) except ValueError: self.temperature = 0.0 try: self.top_p = float(topp_str) except ValueError: self.top_p = 1.0 def _reinit(self): self._initialized = False self.__init__() def _load_api_keys(self) -> Dict[str, List[str]]: """Load API keys from environment variables dynamically.""" api_keys = { "openai": [], "anthropic": [], "google": [], "xai": [] } # Get all environment variables env_vars = dict(os.environ) # Load OpenAI API keys openai_pattern = re.compile(r'OPENAI_API_KEY_\d+$') openai_keys = {k: v for k, v in env_vars.items() if openai_pattern.match(k) and v.strip()} if not openai_keys: default_key = os.getenv('OPENAI_API_KEY') if default_key and default_key.strip(): api_keys["openai"].append(default_key) else: sorted_keys = sorted(openai_keys.keys(), key=lambda x: int(x.split('_')[-1])) for key_name in sorted_keys: api_key = openai_keys[key_name] if api_key and api_key.strip(): api_keys["openai"].append(api_key) # Load Google API keys google_pattern = re.compile(r'GOOGLE_API_KEY_\d+$') google_keys = {k: v for k, v in env_vars.items() if google_pattern.match(k) and v.strip()} if not google_keys: default_key = os.getenv('GOOGLE_API_KEY') if default_key and default_key.strip(): api_keys["google"].append(default_key) else: sorted_keys = sorted(google_keys.keys(), key=lambda x: int(x.split('_')[-1])) for key_name in sorted_keys: api_key = google_keys[key_name] if api_key and api_key.strip(): api_keys["google"].append(api_key) # Load XAI API keys xai_pattern = re.compile(r'XAI_API_KEY_\d+$') xai_keys = {k: v for k, v in env_vars.items() if xai_pattern.match(k) and v.strip()} if not xai_keys: default_key = os.getenv('XAI_API_KEY') if default_key and default_key.strip(): api_keys["xai"].append(default_key) else: sorted_keys = sorted(xai_keys.keys(), key=lambda x: int(x.split('_')[-1])) for key_name in sorted_keys: api_key = xai_keys[key_name] if api_key and api_key.strip(): api_keys["xai"].append(api_key) # Load Anthropic API keys anthropic_pattern = re.compile(r'ANTHROPIC_API_KEY_\d+$') anthropic_keys = {k: v for k, v in env_vars.items() if anthropic_pattern.match(k) and v.strip()} if not anthropic_keys: default_key = os.getenv('ANTHROPIC_API_KEY') if default_key and default_key.strip(): api_keys["anthropic"].append(default_key) else: sorted_keys = sorted(anthropic_keys.keys(), key=lambda x: int(x.split('_')[-1])) for key_name in sorted_keys: api_key = anthropic_keys[key_name] if api_key and api_key.strip(): api_keys["anthropic"].append(api_key) if not any(api_keys.values()): raise Exception("No valid API keys found in environment variables") for provider, keys in api_keys.items(): if keys: print(f"Loaded {len(keys)} {provider} API keys for rotation") return api_keys def get_next_api_key(self, provider: ModelProvider) -> str: """Get the next API key in round-robin fashion for the specified provider.""" with self._lock: if not self._api_keys.get(provider) or len(self._api_keys[provider]) == 0: raise Exception(f"No API key found for {provider}") if provider not in self._current_indices: self._current_indices[provider] = 0 current_key = self._api_keys[provider][self._current_indices[provider]] self._current_indices[provider] = (self._current_indices[provider] + 1) % len(self._api_keys[provider]) return current_key def _get_provider_for_model(self) -> ModelProvider: """Determine the provider based on the model name.""" load_dotenv(override=True) # to refresh in case .env changed provider_env = os.getenv("MODEL_PROVIDER", "openai").lower().strip() if provider_env not in self.SUPPORTED_MODELS: raise Exception( f"Invalid or missing MODEL_PROVIDER in env: '{provider_env}'. " f"Must be one of: {list(self.SUPPORTED_MODELS.keys())}" ) # check if user-chosen model is in that provider’s list if self.model_name not in self.SUPPORTED_MODELS[provider_env]: available = self.SUPPORTED_MODELS[provider_env] raise Exception( f"Model '{self.model_name}' is not available under provider '{provider_env}'. " f"Available: {available}" ) return provider_env def _initialize_llm( self, model_name: Optional[str] = None, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[int] = None, streaming: bool = False ): """Initialize LLM with the next API key in rotation.""" load_dotenv(override=True) # refresh .env in case it changed provider = self._get_provider_for_model() model_name = model_name if model_name else self.model_name temperature = temperature if temperature else self.temperature top_p = top_p if top_p else self.top_p api_key = self.get_next_api_key(provider) print(f"Using provider={provider}, model_name={model_name}, " f"temperature={temperature}, top_p={top_p}, key={api_key}") kwargs = { "model": model_name, "temperature": temperature, "top_p": top_p, "max_retries": 0, "streaming": streaming, "api_key": api_key, } if max_tokens is not None: kwargs["max_tokens"] = max_tokens if provider == "openai": self._llm = ChatOpenAI(**kwargs) elif provider == "google": self._llm = ChatGoogleGenerativeAI(**kwargs) elif provider == "anthropic": self._llm = ChatAnthropic(**kwargs) else: self._llm = ChatXAI(**kwargs) self._current_provider = provider def get_llm( self, model_name: Optional[str] = None, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[int] = None, streaming: bool = False ) -> Union[ChatOpenAI, ChatGoogleGenerativeAI, ChatAnthropic, ChatXAI]: """Get LLM instance with the current API key.""" provider = self._get_provider_for_model() model_name = model_name if model_name else self.model_name temperature = temperature if temperature else self.temperature top_p = top_p if top_p else self.top_p if self._llm is None or provider != self._current_provider: self._initialize_llm(model_name, temperature, top_p, max_tokens, streaming) return self._llm def rotate_key(self, provider: Optional[ModelProvider] = None, streaming: bool = False) -> None: """Manually rotate to the next API key.""" if provider is None: provider = self._current_provider self._initialize_llm(streaming=streaming) def get_all_api_keys(self, provider: Optional[ModelProvider] = None) -> Union[Dict[str, List[str]], List[str]]: """Get all available API keys.""" if provider: return self._api_keys[provider].copy() return {k: v.copy() for k, v in self._api_keys.items()} def get_key_count(self, provider: Optional[ModelProvider] = None) -> Union[Dict[str, int], int]: """Get the total number of available API keys.""" if provider: return len(self._api_keys[provider]) return {k: len(v) for k, v in self._api_keys.items()} def __len__(self) -> Dict[str, int]: """Get the number of active API keys for each provider.""" return self.get_key_count() def __bool__(self) -> bool: """Check if there are any API keys available.""" return any(bool(keys) for keys in self._api_keys.values()) def with_api_manager( model_name: Optional[str] = None, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[int] = None, streaming: bool = False, delay_on_timeout: int = 20, max_token_reduction_attempts: int = 0 ): """Decorator for automatic key rotation on error with delay on timeout.""" manager = APIKeyManager() provider = manager._get_provider_for_model() model_name = model_name if model_name else manager.model_name temperature = temperature if temperature else manager.temperature top_p = top_p if top_p else manager.top_p key_count = manager.get_key_count(provider) def decorator(func): if asyncio.iscoroutinefunction(func): @wraps(func) async def wrapper(*args, **kwargs): if key_count > 1: all_keys = manager.get_all_api_keys(provider) tried_keys = set() current_max_tokens = max_tokens token_reduction_attempts = 0 while len(tried_keys) < len(all_keys): try: llm = manager.get_llm( model_name=model_name, temperature=temperature, top_p=top_p, max_tokens=current_max_tokens, streaming=streaming ) result = await func(*args, **kwargs, llm=llm) return result except (RateLimitError, ResourceExhausted, AnthropicRateLimitError) as e: current_key = manager._api_keys[provider][(manager._current_indices[provider] - 1) % len(all_keys)] print(f"Rate limit error with {provider} API key {current_key}: {str(e)}") tried_keys.add(current_key) if len(tried_keys) < len(all_keys): manager.rotate_key(provider=provider, streaming=streaming) print(f"Using next available {provider} API key") else: if delay_on_timeout > 0: print(f"Waiting for {delay_on_timeout} seconds before retrying with the first key...") time.sleep(delay_on_timeout) manager._current_indices[provider] = 0 else: print(f"All {provider} API keys failed due to rate limits: {str(e)}") raise except (OpenAIError, AnthropicAPIError, BadRequest, InvalidArgument) as e: error_str = str(e) if "token" in error_str.lower() or "context length" in error_str.lower(): print(f"Token limit error encountered: {error_str}") if max_token_reduction_attempts > 0 and max_tokens is not None and token_reduction_attempts < max_token_reduction_attempts: current_max_tokens = int(current_max_tokens * 0.8) # Reduce the local variable token_reduction_attempts += 1 print(f"Retrying with reduced max_tokens: {current_max_tokens}") continue # Retry with reduced max_tokens else: print("Max token reduction attempts reached or token reduction disabled. Proceeding with key rotation.") current_key = manager._api_keys[provider][(manager._current_indices[provider] - 1) % len(all_keys)] tried_keys.add(current_key) if len(tried_keys) < len(all_keys): manager.rotate_key(provider=provider, streaming=streaming) print(f"Using next available {provider} API key after token limit error.") else: raise # All keys tried, raise the token limit error else: # Re-raise other API errors raise # Attempt one final time after trying all keys (for rate limits with delay) try: llm = manager.get_llm( model_name=model_name, temperature=temperature, top_p=top_p, max_tokens=current_max_tokens, # Use the current value streaming=streaming ) result = await func(*args, **kwargs, llm=llm) return result except (RateLimitError, ResourceExhausted, AnthropicRateLimitError, OpenAIError, AnthropicAPIError, BadRequest, InvalidArgument) as e: print(f"Error after retrying all {provider} API keys: {str(e)}") raise elif key_count == 1: @retry( wait=wait_random_exponential(min=10, max=60), stop=stop_after_attempt(6), retry=retry_if_exception_type(( RateLimitError, ResourceExhausted, AnthropicRateLimitError, OpenAIError, AnthropicAPIError, BadRequest, InvalidArgument)) ) async def attempt_function_call(): llm = manager.get_llm( model_name=model_name, temperature=temperature, top_p=top_p, max_tokens=max_tokens, streaming=streaming ) return await func(*args, **kwargs, llm=llm) try: return await attempt_function_call() except (RateLimitError, ResourceExhausted, AnthropicRateLimitError, OpenAIError, AnthropicAPIError, BadRequest, InvalidArgument) as e: print(f"Error encountered for {provider} after multiple retries: {str(e)}") raise else: print(f"No API keys found for provider: {provider}") raise else: @wraps(func) def wrapper(*args, **kwargs): if key_count > 1: all_keys = manager.get_all_api_keys(provider) tried_keys = set() current_max_tokens = max_tokens token_reduction_attempts = 0 while len(tried_keys) < len(all_keys): try: llm = manager.get_llm( model_name=model_name, temperature=temperature, top_p=top_p, max_tokens=current_max_tokens, streaming=streaming ) result = func(*args, **kwargs, llm=llm) return result except (RateLimitError, ResourceExhausted, AnthropicRateLimitError) as e: current_key = manager._api_keys[provider][(manager._current_indices[provider] - 1) % len(all_keys)] print(f"Rate limit error with {provider} API key {current_key}: {str(e)}") tried_keys.add(current_key) if len(tried_keys) < len(all_keys): manager.rotate_key(provider=provider, streaming=streaming) print(f"Using next available {provider} API key") else: if delay_on_timeout > 0: print(f"Waiting for {delay_on_timeout} seconds before retrying with the first key...") time.sleep(delay_on_timeout) manager._current_indices[provider] = 0 else: print(f"All {provider} API keys failed due to rate limits: {str(e)}") raise except (OpenAIError, AnthropicAPIError, BadRequest, InvalidArgument) as e: error_str = str(e) if "token" in error_str.lower() or "context length" in error_str.lower(): print(f"Token limit error encountered: {error_str}") if max_token_reduction_attempts > 0 and max_tokens is not None and token_reduction_attempts < max_token_reduction_attempts: current_max_tokens = int(current_max_tokens * 0.8) token_reduction_attempts += 1 print(f"Retrying with reduced max_tokens: {current_max_tokens}") continue # Retry with reduced max_tokens else: print("Max token reduction attempts reached or token reduction disabled. Proceeding with key rotation.") current_key = manager._api_keys[provider][(manager._current_indices[provider] - 1) % len(all_keys)] tried_keys.add(current_key) if len(tried_keys) < len(all_keys): manager.rotate_key(provider=provider, streaming=streaming) print(f"Using next available {provider} API key after token limit error.") else: raise # All keys tried, raise the token limit error else: # Re-raise other API errors raise # Attempt one final time after trying all keys (for rate limits with delay) try: llm = manager.get_llm( model_name=model_name, temperature=temperature, top_p=top_p, max_tokens=current_max_tokens, streaming=streaming ) result = func(*args, **kwargs, llm=llm) return result except (RateLimitError, ResourceExhausted, AnthropicRateLimitError, OpenAIError, AnthropicAPIError, BadRequest, InvalidArgument) as e: print(f"Error after retrying all {provider} API keys: {str(e)}") raise elif key_count == 1: @retry( wait=wait_random_exponential(min=10, max=60), stop=stop_after_attempt(6), retry=retry_if_exception_type(( RateLimitError, ResourceExhausted, AnthropicRateLimitError, OpenAIError, AnthropicAPIError, BadRequest, InvalidArgument)) ) def attempt_function_call(): llm = manager.get_llm( model_name=model_name, temperature=temperature, top_p=top_p, max_tokens=max_tokens, streaming=streaming ) return func(*args, **kwargs, llm=llm) try: return attempt_function_call() except (RateLimitError, ResourceExhausted, AnthropicRateLimitError, OpenAIError, AnthropicAPIError, BadRequest, InvalidArgument) as e: print(f"Error encountered for {provider} after multiple retries: {str(e)}") raise else: print(f"No API keys found for provider: {provider}") raise return wrapper return decorator if __name__ == "__main__": import asyncio prompt = "What is the capital of France?" # Test key rotation async def test_load_balancing(prompt: str, test_count: int = 10, stream: bool = False): @with_api_manager(streaming=stream) async def test(prompt: str, test_count: int = 10, *, llm): print("="*50) for i in range(test_count): try: print(f"\nTest {i+1} of {test_count}") if stream: async for chunk in llm.astream(prompt): print(chunk.content, end="", flush=True) print("\n" + "-"*50 if i != test_count - 1 else "\n" + "="*50) else: response = await llm.ainvoke(prompt) print(f"Response: {response.content.strip()}") print("-"*50) if i != test_count - 1 else print("="*50) except (RateLimitError, ResourceExhausted, AnthropicRateLimitError) as e: print(f"Error: {str(e)}") raise await test(prompt, test_count=test_count) # Test without load balancing def test_without_load_balancing(model_name: str, prompt: str, test_count: int = 10): manager = APIKeyManager() print(f"Using model: {model_name}") print("="*50) i = 0 while i < test_count: try: print(f"Test {i+1} of {test_count}") llm = manager.get_llm(model_name=model_name) response = llm.invoke(prompt) print(f"Response: {response.content.strip()}") print("-"*50) if i != test_count - 1 else print("="*50) i += 1 except Exception as e: raise Exception(f"Error with {model_name}: {str(e)}") # test_without_load_balancing(model_name="gemini-exp-1121", prompt=prompt, test_count=50) asyncio.run(test_load_balancing(prompt=prompt, test_count=100, stream=True))