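"""LLM API key manager with round-robin rotation across providers.

Loads provider API keys from environment variables, exposes a thread-safe
singleton (APIKeyManager) that hands out LangChain chat models using the next
key in rotation, and provides a with_api_manager decorator that rotates keys
automatically on rate-limit and token-limit errors.

A minimal .env sketch (values are placeholders, not real keys):

    MODEL_PROVIDER=openai
    MODEL_NAME=gpt-4o-mini
    MODEL_TEMPERATURE=0
    MODEL_TOP_P=1
    OPENAI_API_KEY_1=sk-...
    OPENAI_API_KEY_2=sk-...
"""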
import asyncio
import os
import re
import time
from functools import wraps
from threading import Lock
from typing import Dict, List, Literal, Optional, Union

from anthropic import RateLimitError as AnthropicRateLimitError, APIError as AnthropicAPIError
from dotenv import load_dotenv
from google.api_core.exceptions import ResourceExhausted, BadRequest, InvalidArgument
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from langchain_xai import ChatXAI
from openai import RateLimitError, OpenAIError
from tenacity import retry, wait_random_exponential, stop_after_attempt, retry_if_exception_type

ModelProvider = Literal["openai", "anthropic", "google", "xai"]

class APIKeyManager:
    _instance = None
    _lock = Lock()

    # Define supported models
    SUPPORTED_MODELS = {
        "openai": [
            "gpt-3.5-turbo",
            "gpt-3.5-turbo-instruct",
            "gpt-3.5-turbo-1106",
            "gpt-3.5-turbo-0125",
            "gpt-4-0314",
            "gpt-4-0613",
            "gpt-4",
            "gpt-4-1106-preview",
            "gpt-4-0125-preview",
            "gpt-4-turbo-preview",
            "gpt-4-turbo-2024-04-09",
            "gpt-4-turbo",
            "o1-mini-2024-09-12",
            "o1-mini",
            "o1-preview-2024-09-12",
            "o1-preview",
            "o1",
            "gpt-4o-mini-2024-07-18",
            "gpt-4o-mini",
            "chatgpt-4o-latest",
            "gpt-4o-2024-05-13",
            "gpt-4o-2024-08-06",
            "gpt-4o-2024-11-20",
            "gpt-4o"
        ],
        "google": [
            "gemini-1.5-flash",
            "gemini-1.5-flash-latest",
            "gemini-1.5-flash-exp-0827",
            "gemini-1.5-flash-001",
            "gemini-1.5-flash-002",
            "gemini-1.5-flash-8b-exp-0924",
            "gemini-1.5-flash-8b-exp-0827",
            "gemini-1.5-flash-8b-001",
            "gemini-1.5-flash-8b",
            "gemini-1.5-flash-8b-latest",
            "gemini-1.5-pro",
            "gemini-1.5-pro-latest",
            "gemini-1.5-pro-001",
            "gemini-1.5-pro-002",
            "gemini-1.5-pro-exp-0827",
            "gemini-1.0-pro",
            "gemini-1.0-pro-latest",
            "gemini-1.0-pro-001",
            "gemini-pro",
            "gemini-exp-1114",
            "gemini-exp-1121",
            "gemini-2.0-pro-exp-02-05",
            "gemini-2.0-flash-lite-preview-02-05",
            "gemini-2.0-flash-exp",
            "gemini-2.0-flash",
            "gemini-2.0-flash-thinking-exp-1219",
        ],
        "xai": [
            "grok-beta",
            "grok-vision-beta",
            "grok-2-vision-1212",
            "grok-2-1212"
        ],
        "anthropic": [
            "claude-3-5-sonnet-20241022",
            "claude-3-5-sonnet-latest",
            "claude-3-5-haiku-20241022",
            "claude-3-5-haiku-latest",
            "claude-3-opus-20240229",
            "claude-3-opus-latest",
            "claude-3-sonnet-20240229",
            "claude-3-haiku-20240307"
        ]
    }
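
    # Note: the lists above are a snapshot; a model name must appear in its
    # provider's list here before _get_provider_for_model() will accept it.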
    def __new__(cls):
        with cls._lock:
            if cls._instance is None:
                cls._instance = super(APIKeyManager, cls).__new__(cls)
                cls._instance._initialized = False
            return cls._instance

    def __init__(self):
        if not self._initialized:
            self._initialized = True
            # 1) Always load env
            load_dotenv(override=True)
            self._current_indices = {
                "openai": 0,
                "anthropic": 0,
                "google": 0,
                "xai": 0
            }
            self._lock = Lock()
            # 2) Load all provider keys from the environment
            self._api_keys = self._load_api_keys()
            self._llm = None
            self._current_provider = None
            # 3) Read the user's chosen provider, model, temperature, top_p from env
            provider_env = os.getenv("MODEL_PROVIDER", "openai").strip().lower()
            self.model_name = os.getenv("MODEL_NAME", "gpt-3.5-turbo").strip()
            temp_str = os.getenv("MODEL_TEMPERATURE", "0")
            topp_str = os.getenv("MODEL_TOP_P", "1")
            try:
                self.temperature = float(temp_str)
            except ValueError:
                self.temperature = 0.0
            try:
                self.top_p = float(topp_str)
            except ValueError:
                self.top_p = 1.0

    def _reinit(self):
        self._initialized = False
        self.__init__()

    def _load_api_keys(self) -> Dict[str, List[str]]:
        """Load API keys from environment variables dynamically."""
        env_prefixes = {
            "openai": "OPENAI_API_KEY",
            "anthropic": "ANTHROPIC_API_KEY",
            "google": "GOOGLE_API_KEY",
            "xai": "XAI_API_KEY",
        }
        api_keys: Dict[str, List[str]] = {provider: [] for provider in env_prefixes}
        # Get all environment variables
        env_vars = dict(os.environ)
        for provider, prefix in env_prefixes.items():
            # Numbered keys (e.g. OPENAI_API_KEY_1, OPENAI_API_KEY_2, ...) take precedence
            pattern = re.compile(rf'{prefix}_\d+$')
            numbered_keys = {k: v for k, v in env_vars.items() if pattern.match(k) and v.strip()}
            if numbered_keys:
                sorted_keys = sorted(numbered_keys.keys(), key=lambda x: int(x.split('_')[-1]))
                for key_name in sorted_keys:
                    api_keys[provider].append(numbered_keys[key_name])
            else:
                # Fall back to the single un-numbered key (e.g. OPENAI_API_KEY)
                default_key = os.getenv(prefix)
                if default_key and default_key.strip():
                    api_keys[provider].append(default_key)
        if not any(api_keys.values()):
            raise Exception("No valid API keys found in environment variables")
        for provider, keys in api_keys.items():
            if keys:
                print(f"Loaded {len(keys)} {provider} API keys for rotation")
        return api_keys

    def get_next_api_key(self, provider: ModelProvider) -> str:
        """Get the next API key in round-robin fashion for the specified provider."""
        with self._lock:
            if not self._api_keys.get(provider):
                raise Exception(f"No API key found for {provider}")
            if provider not in self._current_indices:
                self._current_indices[provider] = 0
            current_key = self._api_keys[provider][self._current_indices[provider]]
            self._current_indices[provider] = (self._current_indices[provider] + 1) % len(self._api_keys[provider])
            return current_key
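
    # Rotation example (hypothetical keys): with OPENAI_API_KEY_1..3 set,
    # successive calls for "openai" return key 1, key 2, key 3, key 1, ...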
    def _get_provider_for_model(self) -> ModelProvider:
        """Read MODEL_PROVIDER from the environment and validate the chosen model against it."""
        load_dotenv(override=True)  # refresh in case .env changed
        provider_env = os.getenv("MODEL_PROVIDER", "openai").lower().strip()
        if provider_env not in self.SUPPORTED_MODELS:
            raise Exception(
                f"Invalid or missing MODEL_PROVIDER in env: '{provider_env}'. "
                f"Must be one of: {list(self.SUPPORTED_MODELS.keys())}"
            )
        # Check that the user-chosen model is in that provider's list
        if self.model_name not in self.SUPPORTED_MODELS[provider_env]:
            available = self.SUPPORTED_MODELS[provider_env]
            raise Exception(
                f"Model '{self.model_name}' is not available under provider '{provider_env}'. "
                f"Available: {available}"
            )
        return provider_env

    def _initialize_llm(
        self,
        model_name: Optional[str] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        max_tokens: Optional[int] = None,
        streaming: bool = False
    ):
        """Initialize the LLM with the next API key in rotation."""
        load_dotenv(override=True)  # refresh .env in case it changed
        provider = self._get_provider_for_model()
        model_name = model_name if model_name else self.model_name
        # Use explicit None checks so an explicit 0.0 is not silently replaced by the defaults
        temperature = temperature if temperature is not None else self.temperature
        top_p = top_p if top_p is not None else self.top_p
        api_key = self.get_next_api_key(provider)
        # Log only a key suffix to avoid leaking the full secret
        print(f"Using provider={provider}, model_name={model_name}, "
              f"temperature={temperature}, top_p={top_p}, key=...{api_key[-4:]}")
        kwargs = {
            "model": model_name,
            "temperature": temperature,
            "top_p": top_p,
            "max_retries": 0,
            "streaming": streaming,
            "api_key": api_key,
        }
        if max_tokens is not None:
            kwargs["max_tokens"] = max_tokens
        if provider == "openai":
            self._llm = ChatOpenAI(**kwargs)
        elif provider == "google":
            self._llm = ChatGoogleGenerativeAI(**kwargs)
        elif provider == "anthropic":
            self._llm = ChatAnthropic(**kwargs)
        else:
            self._llm = ChatXAI(**kwargs)
        self._current_provider = provider

    def get_llm(
        self,
        model_name: Optional[str] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        max_tokens: Optional[int] = None,
        streaming: bool = False
    ) -> Union[ChatOpenAI, ChatGoogleGenerativeAI, ChatAnthropic, ChatXAI]:
        """Get an LLM instance for the current API key.

        The cached client is reused unless none exists yet or the provider has
        changed; call rotate_key() to force re-initialization with the next key.
        """
        provider = self._get_provider_for_model()
        model_name = model_name if model_name else self.model_name
        temperature = temperature if temperature is not None else self.temperature
        top_p = top_p if top_p is not None else self.top_p
        if self._llm is None or provider != self._current_provider:
            self._initialize_llm(model_name, temperature, top_p, max_tokens, streaming)
        return self._llm

    def rotate_key(self, provider: Optional[ModelProvider] = None, streaming: bool = False) -> None:
        """Manually rotate to the next API key."""
        if provider is None:
            provider = self._current_provider
        self._initialize_llm(streaming=streaming)

    def get_all_api_keys(self, provider: Optional[ModelProvider] = None) -> Union[Dict[str, List[str]], List[str]]:
        """Get all available API keys, for one provider or for every provider."""
        if provider:
            return self._api_keys[provider].copy()
        return {k: v.copy() for k, v in self._api_keys.items()}

    def get_key_count(self, provider: Optional[ModelProvider] = None) -> Union[Dict[str, int], int]:
        """Get the number of available API keys, for one provider or per provider."""
        if provider:
            return len(self._api_keys[provider])
        return {k: len(v) for k, v in self._api_keys.items()}

    def __len__(self) -> int:
        """Total number of active API keys across all providers."""
        return sum(len(keys) for keys in self._api_keys.values())

    def __bool__(self) -> bool:
        """Check if there are any API keys available."""
        return any(bool(keys) for keys in self._api_keys.values())
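
# Minimal usage sketch for the manager on its own (assumes MODEL_PROVIDER,
# MODEL_NAME and at least one matching API key are set in .env):
#
#     manager = APIKeyManager()
#     llm = manager.get_llm()
#     print(llm.invoke("What is 2 + 2?").content)
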
def with_api_manager(
    model_name: Optional[str] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    max_tokens: Optional[int] = None,
    streaming: bool = False,
    delay_on_timeout: int = 20,
    max_token_reduction_attempts: int = 0
):
    """Decorator factory for automatic key rotation on provider errors.

    The wrapped function must accept an ``llm`` keyword argument; the wrapper
    builds the LLM and injects it on every call, rotates keys on rate-limit
    errors, optionally waits ``delay_on_timeout`` seconds once all keys are
    exhausted, and can shrink ``max_tokens`` on token-limit errors.
    """
    manager = APIKeyManager()
    provider = manager._get_provider_for_model()
    model_name = model_name if model_name else manager.model_name
    # Explicit None checks so an explicit 0.0 is not replaced by the env defaults
    temperature = temperature if temperature is not None else manager.temperature
    top_p = top_p if top_p is not None else manager.top_p
    key_count = manager.get_key_count(provider)

    def decorator(func):
        if asyncio.iscoroutinefunction(func):
            @wraps(func)
            async def wrapper(*args, **kwargs):
                if key_count > 1:
                    all_keys = manager.get_all_api_keys(provider)
                    tried_keys = set()
                    current_max_tokens = max_tokens
                    token_reduction_attempts = 0
                    while len(tried_keys) < len(all_keys):
                        try:
                            llm = manager.get_llm(
                                model_name=model_name,
                                temperature=temperature,
                                top_p=top_p,
                                max_tokens=current_max_tokens,
                                streaming=streaming
                            )
                            result = await func(*args, **kwargs, llm=llm)
                            return result
                        except (RateLimitError, ResourceExhausted, AnthropicRateLimitError) as e:
                            current_key = manager._api_keys[provider][(manager._current_indices[provider] - 1) % len(all_keys)]
                            print(f"Rate limit error with {provider} API key {current_key}: {str(e)}")
                            tried_keys.add(current_key)
                            if len(tried_keys) < len(all_keys):
                                manager.rotate_key(provider=provider, streaming=streaming)
                                print(f"Using next available {provider} API key")
                            else:
                                if delay_on_timeout > 0:
                                    print(f"Waiting for {delay_on_timeout} seconds before retrying with the first key...")
                                    time.sleep(delay_on_timeout)
                                    # Restart rotation at the first key and drop the cached
                                    # client so the final attempt below re-initializes with it
                                    manager._current_indices[provider] = 0
                                    manager._llm = None
                                else:
                                    print(f"All {provider} API keys failed due to rate limits: {str(e)}")
                                    raise
                        except (OpenAIError, AnthropicAPIError, BadRequest, InvalidArgument) as e:
                            error_str = str(e)
                            if "token" in error_str.lower() or "context length" in error_str.lower():
                                print(f"Token limit error encountered: {error_str}")
                                if max_token_reduction_attempts > 0 and max_tokens is not None and token_reduction_attempts < max_token_reduction_attempts:
                                    current_max_tokens = int(current_max_tokens * 0.8)
                                    token_reduction_attempts += 1
                                    # Drop the cached client so the reduced max_tokens is
                                    # actually applied on the next get_llm() call
                                    manager._llm = None
                                    print(f"Retrying with reduced max_tokens: {current_max_tokens}")
                                    continue  # Retry with reduced max_tokens
                                else:
                                    print("Max token reduction attempts reached or token reduction disabled. Proceeding with key rotation.")
                                    current_key = manager._api_keys[provider][(manager._current_indices[provider] - 1) % len(all_keys)]
                                    tried_keys.add(current_key)
                                    if len(tried_keys) < len(all_keys):
                                        manager.rotate_key(provider=provider, streaming=streaming)
                                        print(f"Using next available {provider} API key after token limit error.")
                                    else:
                                        raise  # All keys tried, raise the token limit error
                            else:
                                # Re-raise other API errors
                                raise
                    # Attempt one final time after trying all keys (for rate limits with delay)
                    try:
                        llm = manager.get_llm(
                            model_name=model_name,
                            temperature=temperature,
                            top_p=top_p,
                            max_tokens=current_max_tokens,  # Use the current value
                            streaming=streaming
                        )
                        result = await func(*args, **kwargs, llm=llm)
                        return result
                    except (RateLimitError, ResourceExhausted, AnthropicRateLimitError,
                            OpenAIError, AnthropicAPIError, BadRequest, InvalidArgument) as e:
                        print(f"Error after retrying all {provider} API keys: {str(e)}")
                        raise
                elif key_count == 1:
                    async def attempt_function_call():
                        llm = manager.get_llm(
                            model_name=model_name,
                            temperature=temperature,
                            top_p=top_p,
                            max_tokens=max_tokens,
                            streaming=streaming
                        )
                        return await func(*args, **kwargs, llm=llm)
                    try:
                        return await attempt_function_call()
                    except (RateLimitError, ResourceExhausted, AnthropicRateLimitError,
                            OpenAIError, AnthropicAPIError, BadRequest, InvalidArgument) as e:
                        print(f"Error encountered for {provider} with the single available key: {str(e)}")
                        raise
                else:
                    raise Exception(f"No API keys found for provider: {provider}")
        else:
            @wraps(func)
            def wrapper(*args, **kwargs):
                if key_count > 1:
                    all_keys = manager.get_all_api_keys(provider)
                    tried_keys = set()
                    current_max_tokens = max_tokens
                    token_reduction_attempts = 0
                    while len(tried_keys) < len(all_keys):
                        try:
                            llm = manager.get_llm(
                                model_name=model_name,
                                temperature=temperature,
                                top_p=top_p,
                                max_tokens=current_max_tokens,
                                streaming=streaming
                            )
                            result = func(*args, **kwargs, llm=llm)
                            return result
                        except (RateLimitError, ResourceExhausted, AnthropicRateLimitError) as e:
                            current_key = manager._api_keys[provider][(manager._current_indices[provider] - 1) % len(all_keys)]
                            print(f"Rate limit error with {provider} API key {current_key}: {str(e)}")
                            tried_keys.add(current_key)
                            if len(tried_keys) < len(all_keys):
                                manager.rotate_key(provider=provider, streaming=streaming)
                                print(f"Using next available {provider} API key")
                            else:
                                if delay_on_timeout > 0:
                                    print(f"Waiting for {delay_on_timeout} seconds before retrying with the first key...")
                                    time.sleep(delay_on_timeout)
                                    # Restart rotation at the first key and drop the cached
                                    # client so the final attempt below re-initializes with it
                                    manager._current_indices[provider] = 0
                                    manager._llm = None
                                else:
                                    print(f"All {provider} API keys failed due to rate limits: {str(e)}")
                                    raise
                        except (OpenAIError, AnthropicAPIError, BadRequest, InvalidArgument) as e:
                            error_str = str(e)
                            if "token" in error_str.lower() or "context length" in error_str.lower():
                                print(f"Token limit error encountered: {error_str}")
                                if max_token_reduction_attempts > 0 and max_tokens is not None and token_reduction_attempts < max_token_reduction_attempts:
                                    current_max_tokens = int(current_max_tokens * 0.8)
                                    token_reduction_attempts += 1
                                    # Drop the cached client so the reduced max_tokens is
                                    # actually applied on the next get_llm() call
                                    manager._llm = None
                                    print(f"Retrying with reduced max_tokens: {current_max_tokens}")
                                    continue  # Retry with reduced max_tokens
                                else:
                                    print("Max token reduction attempts reached or token reduction disabled. Proceeding with key rotation.")
                                    current_key = manager._api_keys[provider][(manager._current_indices[provider] - 1) % len(all_keys)]
                                    tried_keys.add(current_key)
                                    if len(tried_keys) < len(all_keys):
                                        manager.rotate_key(provider=provider, streaming=streaming)
                                        print(f"Using next available {provider} API key after token limit error.")
                                    else:
                                        raise  # All keys tried, raise the token limit error
                            else:
                                # Re-raise other API errors
                                raise
                    # Attempt one final time after trying all keys (for rate limits with delay)
                    try:
                        llm = manager.get_llm(
                            model_name=model_name,
                            temperature=temperature,
                            top_p=top_p,
                            max_tokens=current_max_tokens,
                            streaming=streaming
                        )
                        result = func(*args, **kwargs, llm=llm)
                        return result
                    except (RateLimitError, ResourceExhausted, AnthropicRateLimitError,
                            OpenAIError, AnthropicAPIError, BadRequest, InvalidArgument) as e:
                        print(f"Error after retrying all {provider} API keys: {str(e)}")
                        raise
                elif key_count == 1:
                    def attempt_function_call():
                        llm = manager.get_llm(
                            model_name=model_name,
                            temperature=temperature,
                            top_p=top_p,
                            max_tokens=max_tokens,
                            streaming=streaming
                        )
                        return func(*args, **kwargs, llm=llm)
                    try:
                        return attempt_function_call()
                    except (RateLimitError, ResourceExhausted, AnthropicRateLimitError,
                            OpenAIError, AnthropicAPIError, BadRequest, InvalidArgument) as e:
                        print(f"Error encountered for {provider} with the single available key: {str(e)}")
                        raise
                else:
                    raise Exception(f"No API keys found for provider: {provider}")
        return wrapper
    return decorator
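
# Usage sketch for the decorator (names below are illustrative, not part of the
# module): the decorated function must accept an ``llm`` keyword argument,
# which the wrapper injects with the next key in rotation.
#
#     @with_api_manager(max_tokens=512)
#     async def summarize(text: str, *, llm):
#         response = await llm.ainvoke(f"Summarize: {text}")
#         return response.content
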
if __name__ == "__main__":
    prompt = "What is the capital of France?"

    # Test key rotation / load balancing
    async def test_load_balancing(prompt: str, test_count: int = 10, stream: bool = False):
        @with_api_manager(streaming=stream)
        async def test(prompt: str, test_count: int = 10, *, llm):
            print("=" * 50)
            for i in range(test_count):
                try:
                    print(f"\nTest {i+1} of {test_count}")
                    if stream:
                        async for chunk in llm.astream(prompt):
                            print(chunk.content, end="", flush=True)
                        print("\n" + "-" * 50 if i != test_count - 1 else "\n" + "=" * 50)
                    else:
                        response = await llm.ainvoke(prompt)
                        print(f"Response: {response.content.strip()}")
                        print("-" * 50 if i != test_count - 1 else "=" * 50)
                except (RateLimitError, ResourceExhausted, AnthropicRateLimitError) as e:
                    print(f"Error: {str(e)}")
                    raise
        await test(prompt, test_count=test_count)

    # Test without load balancing
    def test_without_load_balancing(model_name: str, prompt: str, test_count: int = 10):
        manager = APIKeyManager()
        print(f"Using model: {model_name}")
        print("=" * 50)
        i = 0
        while i < test_count:
            try:
                print(f"Test {i+1} of {test_count}")
                llm = manager.get_llm(model_name=model_name)
                response = llm.invoke(prompt)
                print(f"Response: {response.content.strip()}")
                print("-" * 50 if i != test_count - 1 else "=" * 50)
                i += 1
            except Exception as e:
                raise Exception(f"Error with {model_name}: {str(e)}")

    # test_without_load_balancing(model_name="gemini-exp-1121", prompt=prompt, test_count=50)
    asyncio.run(test_load_balancing(prompt=prompt, test_count=100, stream=True))