Inferencing / utils.py
Shyamnath's picture
Add complete backend application with all dependencies
6ff1f88
raw
history blame
4.41 kB
import pandas as pd
from typing import Dict, List, Optional
import os
import logging
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class KeyRotator:
def __init__(self):
self.current_indices: Dict[str, int] = {}
self.api_keys: Dict[str, List[str]] = {}
self.models_map: Dict[str, str] = {}
self._load_api_keys()
self._load_models()
def _load_api_keys(self) -> None:
"""Load API keys from CSV file and initialize rotation indices."""
try:
keys_df = pd.read_csv('apikeys.csv')
# Map model provider names to API key provider names
provider_map = {
"GROQ MODELS": "GROQ API ", # Note the trailing space
"COHERE": "COHERE AI",
"SambaNova": "Samba api key",
"GEMINI": "GEMINI API"
}
logger.info("Available API key columns: %s", keys_df.columns.tolist())
for model_provider, api_provider in provider_map.items():
logger.info(f"Processing provider mapping: {model_provider} -> {api_provider}")
if api_provider in keys_df.columns:
valid_keys = [key for key in keys_df[api_provider].dropna()]
logger.info(f"Found {len(valid_keys)} valid keys for {model_provider}")
if valid_keys:
self.api_keys[model_provider] = valid_keys
self.current_indices[model_provider] = 0
else:
logger.warning(f"API provider column {api_provider} not found in CSV")
logger.warning(f"Available columns: {keys_df.columns.tolist()}")
except Exception as e:
logger.error(f"Error loading API keys: {str(e)}")
raise Exception(f"Error loading API keys: {str(e)}")
def _load_models(self) -> None:
"""Load models from CSV file and create model-to-provider mapping."""
try:
models_df = pd.read_csv('MODELS.csv')
logger.info("Available model providers: %s", models_df.columns.tolist())
for column in models_df.columns:
provider = column.strip()
models = [model for model in models_df[column].dropna()]
logger.info(f"Loading {len(models)} models for provider {provider}")
for model in models:
self.models_map[model.strip()] = provider
except Exception as e:
logger.error(f"Error loading models: {str(e)}")
raise Exception(f"Error loading models: {str(e)}")
def get_next_key(self, provider: str) -> str:
"""Get the next API key for the specified provider using rotation."""
if provider not in self.api_keys:
logger.error(f"No API keys found for provider: {provider}")
logger.error(f"Available providers: {list(self.api_keys.keys())}")
raise ValueError(f"No API keys found for provider: {provider}")
keys = self.api_keys[provider]
current_idx = self.current_indices[provider]
total_keys = len(keys)
# Get current key and update index for next time
key = keys[current_idx]
next_idx = (current_idx + 1) % total_keys
self.current_indices[provider] = next_idx
# Log key rotation info with more details
key_preview = f"{key[:4]}...{key[-4:]}"
logger.info(f"### API Key Rotation for {provider} ###")
logger.info(f"Using key #{current_idx + 1}/{total_keys} ({key_preview})")
logger.info(f"Next request will use key #{next_idx + 1}/{total_keys}")
logger.info(f"#################################")
return key
def get_provider_for_model(self, model: str) -> str:
"""Get the provider name for a given model."""
model = model.strip()
if model not in self.models_map:
raise ValueError(f"Model not found: {model}")
return self.models_map[model]
def get_available_models(self) -> List[Dict[str, str]]:
"""Get list of all available models with their providers."""
return [{"model": model, "provider": provider}
for model, provider in self.models_map.items()]
# Global instance
key_rotator = KeyRotator()