Shyamnath committed
Commit 6ff1f88 · 1 Parent(s): 5f35036

Add complete backend application with all dependencies

.gitignore ADDED
@@ -0,0 +1,10 @@
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ *.log
+ .DS_Store
Dockerfile ADDED
@@ -0,0 +1,32 @@
+ FROM python:3.9
+
+ WORKDIR /app
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y \
+     build-essential \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Create a non-root user and switch to it
+ RUN useradd -m -u 1000 user
+ USER user
+ ENV PATH="/home/user/.local/bin:$PATH"
+ ENV PYTHONPATH="/app:$PYTHONPATH"
+
+ # Copy requirements first to leverage Docker cache
+ COPY --chown=user requirements.txt .
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
+
+ # Copy the application code
+ COPY --chown=user app.py .
+ COPY --chown=user apikeys.csv .
+ COPY --chown=user MODELS.csv .
+ COPY --chown=user utils.py .
+ COPY --chown=user core ./core
+ COPY --chown=user models ./models
+
+ # Expose the port
+ EXPOSE 7860
+
+ # Start the application
+ CMD ["python", "app.py"]
MODELS.csv ADDED
@@ -0,0 +1,14 @@
+ GROQ MODELS,COHERE ,SambaNova,GEMINI
+ deepseek-r1-distill-llama-70b,Command R,DeepSeek-R1-Distill-Llama-70B,gemini-2.0-flash
+ llama-3.3-70b-versatile,Command R7B,Llama-3.1-Swallow-70B-Instruct-v0.3,gemini-2.0-flash-lite
+ llama-3.3-70b-specdec,Command R+,Llama-3.1-Swallow-8B-Instruct-v0.3,gemini-2.0-flash-thinking-exp-01-21
+ llama-3.2-1b-preview,,Llama-3.1-Tulu-3-405B,gemini-1.5-flash
+ llama-3.2-3b-preview,,Meta-Llama-3.1-405B-Instruct,gemini-1.5-flash-8b
+ llama-3.1-8b-instant,,Meta-Llama-3.1-70B-Instruct,
+ llama3-70b-8192,,Meta-Llama-3.1-8B-Instruct,
+ llama3-8b-8192,,Meta-Llama-3.2-1B-Instruct,
+ llama-guard-3-8b,,Meta-Llama-3.2-3B-Instruct,
+ mixtral-8x7b-32768,,Meta-Llama-3.3-70B-Instruct,
+ gemma2-9b-it,,Qwen2.5-72B-Instruct,
+ llama-3.2-11b-vision-preview,,Qwen2.5-Coder-32B-Instruct,
+ llama-3.2-90b-vision-preview,,QwQ-32B-Preview,
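
MODELS.csv is column-oriented: each header names a provider and each column lists that provider's models, with blank cells padding the shorter columns. A minimal sketch of how the file flattens into a model-to-provider lookup, mirroring the loader in core/key_manager.py below (assumes MODELS.csv sits in the working directory):

    import pandas as pd

    # Build a {model_name: provider_name} map from MODELS.csv.
    # Headers are stripped because some (e.g. "COHERE ") carry trailing spaces.
    models_df = pd.read_csv("MODELS.csv")
    models_map = {
        str(model).strip(): column.strip()
        for column in models_df.columns
        for model in models_df[column].dropna()
    }

    print(models_map["llama-3.1-8b-instant"])  # -> "GROQ MODELS"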
apikeys.csv ADDED
@@ -0,0 +1,6 @@
+ GROQ API ,GEMINI API,COHERE AI,Samba api key
+ 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,AIzaSyA5XHRFKUIAKKmvPyZA50H9bmXGuio22_o,GEHqd8ojtfK7Q2qTSrH1RNbLPceIXgdtKtsaG1Io,d7277c97-5abf-4f00-afc4-a05f2283b1c9
+ 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,AIzaSyCFWSxJ8Lwd8eeSrLfEyepU00_TuoGr3YE,IAcJ37niDUV3AuIzwzRLYnSTkrgou4VGKhgAWud0,7952d283-92c9-42aa-8511-52729a5375b3
+ 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,AIzaSyDw2aCCLrBn2fQIkfjo3mHoNVaOMZEmQPg,EN5HVdaRlruyxpQkMhALF2ViwTb2ZOyfN5oBZnvf,68fa8ba4-5329-493a-a733-495d5fa72386
+ 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,AIzaSyBbHEMA7zPTnfw3VAmYGVsQQcdAxO1O__Q,FfBExN70G0PAun5MRvw29rbq3soCJgwBb4TsVrFe,52de8b5e-4a32-4c6f-983f-5af048d14d96
+ ,AIzaSyBSjaqgysbmd74zXax-mQMuR4YeHjUH124,w4EAnXvuW0G3j44qPwoGaMqSaMr03EWOugidAGMs,4cb02d31-2b3f-4e94-b12d-4d694b6f0c48
app.py ADDED
@@ -0,0 +1,122 @@
+ from fastapi import FastAPI, HTTPException, Request
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
+ from fastapi.staticfiles import StaticFiles
+ from pydantic import BaseModel
+ import json
+ import logging
+ import logging.config
+ import os
+
+ from core.config import API_HOST, API_PORT, CORS_SETTINGS, LOG_CONFIG
+ from core.exceptions import APIError, handle_api_error
+ from core.text_generation import text_generator
+
+ # Configure logging
+ logging.config.dictConfig(LOG_CONFIG)
+ logger = logging.getLogger(__name__)
+
+ app = FastAPI(title="AI Text Generation API",
+               description="API for text generation using multiple AI providers",
+               version="1.0.0")
+
+ # Enable CORS with specific headers for SSE
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],  # Update this in production
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+     expose_headers=["Content-Type", "Cache-Control"]
+ )
+
+ # Mount static files
+ frontend_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'frontend')
+ # check_dir=False tolerates a missing frontend/ dir (it is not copied into the Docker image)
+ app.mount("/static", StaticFiles(directory=frontend_dir, check_dir=False), name="static")
+
+ class PromptRequest(BaseModel):
+     model: str
+     prompt: str
+
+ @app.get("/")
+ async def read_root():
+     """Serve the frontend HTML."""
+     return FileResponse(os.path.join(frontend_dir, 'index.html'))
+
+ @app.get("/models")
+ async def get_models():
+     """Get list of all available models."""
+     try:
+         # Return models as a JSON array
+         return JSONResponse(content=text_generator.get_available_models())
+     except APIError as e:
+         error_response = handle_api_error(e)
+         raise HTTPException(
+             status_code=error_response["status_code"],
+             detail=error_response["detail"]
+         )
+     except Exception as e:
+         logger.error(f"Unexpected error in get_models: {str(e)}")
+         raise HTTPException(status_code=500, detail="Internal server error")
+
+ async def generate_stream(model: str, prompt: str):
+     """Stream generator for text generation."""
+     try:
+         async for chunk in text_generator.generate_stream(model, prompt):
+             # Add extra newline to ensure proper event separation
+             yield f"data: {json.dumps({'content': chunk})}\n\n"
+     except APIError as e:
+         error_response = handle_api_error(e)
+         yield f"data: {json.dumps({'error': error_response['detail']})}\n\n"
+     except Exception as e:
+         logger.error(f"Unexpected error in generate_stream: {str(e)}")
+         yield f"data: {json.dumps({'error': 'Internal server error'})}\n\n"
+     finally:
+         yield "data: [DONE]\n\n"
+
+ @app.get("/generate")
+ @app.post("/generate")
+ async def generate_response(request: Request):
+     """Generate response using selected model (supports both GET and POST)."""
+     try:
+         # Handle both GET and POST methods
+         if request.method == "GET":
+             params = dict(request.query_params)
+             model = params.get("model")
+             prompt = params.get("prompt")
+         else:
+             body = await request.json()
+             model = body.get("model")
+             prompt = body.get("prompt")
+
+         if not model or not prompt:
+             raise HTTPException(status_code=400, detail="Missing model or prompt parameter")
+
+         logger.info(f"Received {request.method} request for model: {model}")
+
+         headers = {
+             "Cache-Control": "no-cache",
+             "Connection": "keep-alive",
+             "X-Accel-Buffering": "no"  # Disable buffering for nginx
+         }
+
+         return StreamingResponse(
+             generate_stream(model, prompt),
+             media_type="text/event-stream",
+             headers=headers
+         )
+
+     except APIError as e:
+         error_response = handle_api_error(e)
+         raise HTTPException(
+             status_code=error_response["status_code"],
+             detail=error_response["detail"]
+         )
+     except Exception as e:
+         logger.error(f"Unexpected error in generate_response: {str(e)}")
+         raise HTTPException(status_code=500, detail="Internal server error")
+
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host=API_HOST, port=API_PORT)
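
Because /generate streams Server-Sent Events rather than a single JSON body, clients must read it incrementally. A minimal client sketch using requests (already pinned in requirements.txt); the localhost URL and port are assumptions for a local run:

    import json
    import requests

    # Stream tokens from the /generate SSE endpoint (host/port assumed).
    resp = requests.get(
        "http://localhost:7860/generate",
        params={"model": "llama-3.1-8b-instant", "prompt": "Say hello"},
        stream=True,
    )
    for raw_line in resp.iter_lines(decode_unicode=True):
        if not raw_line or not raw_line.startswith("data: "):
            continue
        data = raw_line[len("data: "):]
        if data == "[DONE]":
            break
        event = json.loads(data)
        if "error" in event:
            raise RuntimeError(event["error"])
        print(event.get("content", ""), end="", flush=True)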
core/__init__.py ADDED
@@ -0,0 +1,16 @@
+ """
+ Core package for the AI Text Generation API.
+
+ This package contains the core functionality of the API:
+ - Configuration management
+ - Key management
+ - Text generation
+ - Error handling
+ """
+
+ from core.config import *
+ from core.exceptions import *
+ from core.key_manager import key_manager
+ from core.text_generation import text_generator
+
+ __version__ = "1.0.0"
core/config.py ADDED
@@ -0,0 +1,49 @@
+ from typing import Dict
+
+ # API Provider mapping - Maps model provider names to their API key column names
+ PROVIDER_MAP: Dict[str, str] = {
+     "GROQ MODELS": "GROQ API ",  # Note the trailing space
+     "COHERE": "COHERE AI",
+     "SambaNova": "Samba api key",
+     "GEMINI": "GEMINI API",
+     # New providers will be automatically detected from MODELS.csv.
+     # Their API key column names should match the format: "{PROVIDER} API" or "{PROVIDER} api key"
+ }
+
+ # API Configuration
+ API_HOST = "0.0.0.0"
+ API_PORT = 7860  # Must match the port exposed in the Dockerfile
+
+ # CORS Configuration
+ CORS_SETTINGS = {
+     "allow_origins": ["*"],
+     "allow_credentials": True,
+     "allow_methods": ["*"],
+     "allow_headers": ["*"]
+ }
+
+ # Logging Configuration
+ LOG_CONFIG = {
+     "version": 1,
+     "disable_existing_loggers": False,
+     "formatters": {
+         "standard": {
+             "format": "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
+         }
+     },
+     "handlers": {
+         "default": {
+             "level": "INFO",
+             "formatter": "standard",
+             "class": "logging.StreamHandler",
+             "stream": "ext://sys.stdout"
+         }
+     },
+     "loggers": {
+         "": {
+             "handlers": ["default"],
+             "level": "INFO",
+             "propagate": True
+         }
+     }
+ }
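
To illustrate the auto-detection convention (the MISTRAL provider here is hypothetical, not part of this commit): if a MISTRAL column were added to MODELS.csv, the key loader in core/key_manager.py below would look for an apikeys.csv column whose stripped, lowercased header equals one of these generated variants:

    # Hypothetical example of the column-name variants tried for a new provider.
    provider = "MISTRAL"
    variants = [
        f"{provider} API",
        f"{provider} api",
        f"{provider} API KEY",
        f"{provider} api key",
    ]
    # "Mistral API ", "mistral api key", etc. all match after the comparison
    # strips whitespace and lowercases both sides.
    print(sorted({v.lower() for v in variants}))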
core/exceptions.py ADDED
@@ -0,0 +1,38 @@
+ class APIError(Exception):
+     """Base exception for API related errors."""
+     pass
+
+ class APIKeyError(APIError):
+     """Exception raised for API key related errors."""
+     pass
+
+ class ModelError(APIError):
+     """Exception raised for model related errors."""
+     pass
+
+ class ProviderError(APIError):
+     """Exception raised for provider related errors."""
+     pass
+
+ class ValidationError(APIError):
+     """Exception raised for input validation errors."""
+     pass
+
+ def handle_api_error(error: APIError) -> dict:
+     """Convert API errors to response dictionaries."""
+     error_types = {
+         APIKeyError: (400, "API Key Error"),
+         ModelError: (400, "Model Error"),
+         ProviderError: (500, "Provider Error"),
+         ValidationError: (400, "Validation Error"),
+         APIError: (500, "Internal Server Error")
+     }
+
+     error_class = type(error)
+     status_code, error_type = error_types.get(error_class, (500, "Unknown Error"))
+
+     return {
+         "status_code": status_code,
+         "error_type": error_type,
+         "detail": str(error)
+     }
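
A quick illustration (not part of the commit) of how the hierarchy and handle_api_error compose; app.py feeds the resulting dict into HTTPException:

    from core.exceptions import ProviderError, handle_api_error

    try:
        raise ProviderError("upstream returned 503")
    except ProviderError as e:
        response = handle_api_error(e)
        # {'status_code': 500, 'error_type': 'Provider Error',
        #  'detail': 'upstream returned 503'}
        print(response)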
core/key_manager.py ADDED
@@ -0,0 +1,131 @@
+ import pandas as pd
+ from typing import Dict, List, Optional
+ import logging
+ import os
+ from core.config import PROVIDER_MAP
+ from core.exceptions import APIKeyError
+
+ logger = logging.getLogger(__name__)
+
+ class KeyManager:
+     def __init__(self):
+         self.current_indices: Dict[str, int] = {}
+         self.api_keys: Dict[str, List[str]] = {}
+         self.models_map: Dict[str, str] = {}
+         self.provider_map = PROVIDER_MAP.copy()
+         self._load_models()  # Load models first to detect providers
+         self._load_api_keys()
+
+     def _detect_api_key_column(self, provider: str, columns: List[str]) -> Optional[str]:
+         """
+         Detect the API key column name for a provider.
+         Tries different common formats like '{PROVIDER} API' or '{PROVIDER} api key'.
+         """
+         provider_variants = [
+             f"{provider} API",
+             f"{provider} api",
+             f"{provider} API KEY",
+             f"{provider} api key",
+             provider.upper() + " API",
+             provider.lower() + " api",
+         ]
+
+         for variant in provider_variants:
+             for column in columns:
+                 if column.strip().lower() == variant.lower():
+                     return column
+
+         return None
+
+     def _load_api_keys(self) -> None:
+         """Load API keys from CSV file and initialize rotation indices."""
+         try:
+             backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+             keys_path = os.path.join(backend_dir, 'apikeys.csv')
+             keys_df = pd.read_csv(keys_path)
+
+             logger.info("Available API key columns: %s", keys_df.columns.tolist())
+
+             # Process both predefined and auto-detected providers
+             for provider in set(self.models_map.values()):
+                 if provider not in self.provider_map:
+                     # Try to detect API key column for new provider
+                     api_column = self._detect_api_key_column(provider, keys_df.columns)
+                     if api_column:
+                         self.provider_map[provider] = api_column
+                         logger.info(f"Auto-detected API column '{api_column}' for provider '{provider}'")
+
+                 api_provider = self.provider_map.get(provider)
+                 if api_provider and api_provider in keys_df.columns:
+                     valid_keys = [key for key in keys_df[api_provider].dropna()]
+                     logger.info(f"Found {len(valid_keys)} valid keys for {provider}")
+                     if valid_keys:
+                         self.api_keys[provider] = valid_keys
+                         self.current_indices[provider] = 0
+                 else:
+                     logger.warning(f"No API keys found for provider: {provider}")
+                     if api_provider:
+                         logger.warning(f"Column '{api_provider}' not found in CSV")
+                     logger.warning(f"Available columns: {keys_df.columns.tolist()}")
+
+         except Exception as e:
+             logger.error(f"Error loading API keys: {str(e)}")
+             raise APIKeyError(f"Error loading API keys: {str(e)}")
+
+     def _load_models(self) -> None:
+         """Load models from CSV file and create model-to-provider mapping."""
+         try:
+             backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+             models_path = os.path.join(backend_dir, 'MODELS.csv')
+             models_df = pd.read_csv(models_path)
+
+             logger.info("Available model providers: %s", models_df.columns.tolist())
+             for column in models_df.columns:
+                 provider = column.strip()
+                 models = [model for model in models_df[column].dropna()]
+                 logger.info(f"Loading {len(models)} models for provider {provider}")
+                 for model in models:
+                     self.models_map[model.strip()] = provider
+
+             logger.info(f"Detected providers: {set(self.models_map.values())}")
+
+         except Exception as e:
+             logger.error(f"Error loading models: {str(e)}")
+             raise APIKeyError(f"Error loading models: {str(e)}")
+
+     def get_next_key(self, provider: str) -> str:
+         """Get the next API key for the specified provider using rotation."""
+         if provider not in self.api_keys:
+             logger.error(f"No API keys found for provider: {provider}")
+             logger.error(f"Available providers: {list(self.api_keys.keys())}")
+             raise APIKeyError(f"No API keys found for provider: {provider}")
+
+         keys = self.api_keys[provider]
+         current_idx = self.current_indices[provider]
+         total_keys = len(keys)
+
+         key = keys[current_idx]
+         next_idx = (current_idx + 1) % total_keys
+         self.current_indices[provider] = next_idx
+
+         key_preview = f"{key[:4]}...{key[-4:]}"
+         logger.info(f"### API Key Rotation for {provider} ###")
+         logger.info(f"Using key #{current_idx + 1}/{total_keys} ({key_preview})")
+         logger.info(f"Next request will use key #{next_idx + 1}/{total_keys}")
+
+         return key
+
+     def get_provider_for_model(self, model: str) -> str:
+         """Get the provider name for a given model."""
+         model = model.strip()
+         if model not in self.models_map:
+             raise APIKeyError(f"Model not found: {model}")
+         return self.models_map[model]
+
+     def get_available_models(self) -> List[Dict[str, str]]:
+         """Get list of all available models with their providers."""
+         return [{"model": model, "provider": provider}
+                 for model, provider in self.models_map.items()]
+
+ # Global instance
+ key_manager = KeyManager()
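
A short sketch of the round-robin rotation (illustrative; assumes the two CSVs above sit next to the package). apikeys.csv holds five GEMINI keys, so the sixth call wraps back to the first:

    from core.key_manager import key_manager

    provider = key_manager.get_provider_for_model("gemini-1.5-flash")  # "GEMINI"

    first = key_manager.get_next_key(provider)
    for _ in range(4):
        key_manager.get_next_key(provider)
    assert key_manager.get_next_key(provider) == first  # index wrapped around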
core/text_generation.py ADDED
@@ -0,0 +1,87 @@
+ from typing import AsyncGenerator, Dict, List
+ import importlib
+ import os
+ import logging
+ from core.exceptions import ModelError, ProviderError
+ from core.key_manager import key_manager
+
+ logger = logging.getLogger(__name__)
+
+ class TextGenerator:
+     """Handles text generation across different model providers."""
+
+     @staticmethod
+     async def generate_stream(model: str, prompt: str) -> AsyncGenerator[str, None]:
+         """
+         Generate streaming text responses using the specified model.
+
+         Args:
+             model (str): The name of the model to use
+             prompt (str): The input prompt for text generation
+
+         Yields:
+             str: Generated text chunks
+
+         Raises:
+             ModelError: If the model is not found or invalid
+             ProviderError: If there's an error with the provider
+         """
+         try:
+             # Get provider for the selected model
+             provider = key_manager.get_provider_for_model(model)
+             logger.info(f"Using provider {provider} for model {model}")
+
+             # Get API key
+             api_key = key_manager.get_next_key(provider)
+             logger.info(f"Retrieved API key for provider {provider}")
+
+             # Import provider module
+             try:
+                 # Convert provider name to valid module name
+                 module_name = provider.lower().split()[0]  # Get first word in lowercase
+                 logger.info(f"Importing module: models.{module_name}")
+
+                 # Check if module exists
+                 backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+                 module_path = os.path.join(backend_dir, 'models', f'{module_name}.py')
+
+                 if not os.path.exists(module_path):
+                     raise ProviderError(f"Provider module not found at: {module_path}")
+
+                 provider_module = importlib.import_module(f"models.{module_name}")
+
+                 # Verify required functions exist
+                 if not hasattr(provider_module, 'run_model_stream'):
+                     raise ProviderError(f"Provider module {module_name} missing required function: run_model_stream")
+
+             except ImportError as e:
+                 raise ProviderError(f"Failed to import provider module: {str(e)}")
+             except Exception as e:
+                 raise ProviderError(f"Error loading provider module: {str(e)}")
+
+             # Generate text stream
+             async for chunk in provider_module.run_model_stream(api_key, model, prompt):
+                 yield chunk
+
+         except Exception as e:
+             logger.error(f"Error in generate_stream: {str(e)}")
+             raise
+
+     @staticmethod
+     def get_available_models() -> List[Dict[str, str]]:
+         """
+         Get all available models and their providers.
+
+         Returns:
+             List[Dict[str, str]]: List of models with their providers
+         """
+         try:
+             models = key_manager.get_available_models()
+             # Return just the models array directly since that's what frontend expects
+             return models
+         except Exception as e:
+             logger.error(f"Error getting available models: {str(e)}")
+             raise ModelError(f"Failed to get available models: {str(e)}")
+
+ # Global instance
+ text_generator = TextGenerator()
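
The dynamic import above implies a provider contract: a file models/<first word of the provider name, lowercased>.py exposing an async generator run_model_stream(api_key, model, prompt). A minimal template for a hypothetical new provider (the echo body is a stand-in for a real API call):

    # models/example.py - hypothetical module satisfying the loader's contract
    import asyncio
    from typing import AsyncGenerator

    async def run_model_stream(api_key: str, model: str, prompt: str) -> AsyncGenerator[str, None]:
        """Yield response chunks; a real module would call the provider's API here."""
        for word in f"echo({model}): {prompt}".split():
            await asyncio.sleep(0)  # yield control, as the real modules do
            yield word + " "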
models/cohere-test.py ADDED
@@ -0,0 +1,13 @@
+ import cohere
+
+ # Note: cohere.ClientV2 ships with the v5+ SDK, while requirements.txt pins cohere==4.51
+ co = cohere.ClientV2(api_key='NYEQhh6CwVF3dCo57n25032m4s0FurCYI73IYWb5')
+
+ response = co.chat_stream(
+     model="command-r-plus-08-2024",
+     messages=[{"role": "user", "content": "write a story about a dragon"}],
+ )
+
+ for event in response:
+     if event.type == "content-delta":
+         print(event.delta.message.content.text, end="")
models/cohere.py ADDED
@@ -0,0 +1,57 @@
+ import cohere
+ import asyncio
+ from typing import AsyncGenerator
+
+ async def run_model_stream(api_key: str, model: str, prompt: str) -> AsyncGenerator[str, None]:
+     """
+     Run the Cohere model with streaming response.
+
+     Args:
+         api_key: The API key to use for this request
+         model: The model name to use
+         prompt: The user's input prompt
+
+     Yields:
+         str: Chunks of the generated response
+     """
+     try:
+         client = cohere.Client(api_key=api_key)
+
+         # Create chat message with streaming
+         response = await asyncio.get_event_loop().run_in_executor(
+             None,
+             lambda: client.chat(
+                 chat_history=[],
+                 message=prompt,
+                 model=model,  # Use model name directly from MODELS.csv
+                 stream=True,
+                 temperature=0.7
+             )
+         )
+
+         # Process each chunk
+         for event in response:
+             if hasattr(event, 'text') and event.text:
+                 # Use asyncio.sleep to prevent blocking
+                 await asyncio.sleep(0)
+                 yield event.text
+
+     except Exception as e:
+         raise Exception(f"Error with Cohere API: {str(e)}")
+
+ async def run_model(api_key: str, model: str, prompt: str) -> str:
+     """
+     Run the Cohere model with the provided API key and prompt (non-streaming).
+
+     Args:
+         api_key: The API key to use for this request
+         model: The model name to use
+         prompt: The user's input prompt
+
+     Returns:
+         str: The generated response
+     """
+     response = ""
+     async for chunk in run_model_stream(api_key, model, prompt):
+         response += chunk
+     return response
models/gemini.py ADDED
@@ -0,0 +1,58 @@
+ import google.generativeai as genai
+ import asyncio
+ from typing import AsyncGenerator
+
+ async def run_model_stream(api_key: str, model: str, prompt: str) -> AsyncGenerator[str, None]:
+     """
+     Run the Gemini model with streaming response.
+
+     Args:
+         api_key: The API key to use for this request
+         model: The model name to use
+         prompt: The user's input prompt
+
+     Yields:
+         str: Chunks of the generated response
+     """
+     try:
+         # Configure the Gemini API
+         genai.configure(api_key=api_key)
+
+         # Initialize the model with name from MODELS.csv
+         model_instance = genai.GenerativeModel(model)
+
+         # Start the streaming response using async executor
+         response = await asyncio.get_event_loop().run_in_executor(
+             None,
+             lambda: model_instance.generate_content(
+                 prompt,
+                 stream=True
+             )
+         )
+
+         # Process chunks with async handling
+         for chunk in response:
+             if chunk.text:
+                 # Use asyncio.sleep to prevent blocking
+                 await asyncio.sleep(0)
+                 yield chunk.text
+
+     except Exception as e:
+         raise Exception(f"Error with Gemini API: {str(e)}")
+
+ async def run_model(api_key: str, model: str, prompt: str) -> str:
+     """
+     Run the Gemini model with the provided API key and prompt (non-streaming).
+
+     Args:
+         api_key: The API key to use for this request
+         model: The model name to use
+         prompt: The user's input prompt
+
+     Returns:
+         str: The generated response
+     """
+     response = ""
+     async for chunk in run_model_stream(api_key, model, prompt):
+         response += chunk
+     return response
models/groq.py ADDED
@@ -0,0 +1,75 @@
+ from openai import AsyncOpenAI
+ import httpx
+
+ async def run_model_stream(api_key: str, model: str, prompt: str):
+     """
+     Run the Groq model with streaming response.
+
+     Args:
+         api_key: The API key to use for this request
+         model: The model name to use
+         prompt: The user's input prompt
+
+     Yields:
+         str: Chunks of the generated response
+     """
+     try:
+         # Initialize AsyncOpenAI client with specific configuration
+         client = AsyncOpenAI(
+             api_key=api_key,
+             base_url="https://api.groq.com/openai/v1",
+             http_client=httpx.AsyncClient(verify=True)  # Async client
+         )
+
+         completion = await client.chat.completions.create(
+             messages=[
+                 {
+                     "role": "user",
+                     "content": prompt
+                 }
+             ],
+             model=model,
+             stream=True
+         )
+
+         async for chunk in completion:
+             if chunk.choices[0].delta.content is not None:
+                 yield chunk.choices[0].delta.content
+
+     except Exception as e:
+         raise Exception(f"Error with Groq API: {str(e)}")
+
+ async def run_model(api_key: str, model: str, prompt: str) -> str:
+     """
+     Run the Groq model with the provided API key and prompt (non-streaming).
+
+     Args:
+         api_key: The API key to use for this request
+         model: The model name to use
+         prompt: The user's input prompt
+
+     Returns:
+         str: The generated response
+     """
+     try:
+         # Initialize AsyncOpenAI client with specific configuration
+         client = AsyncOpenAI(
+             api_key=api_key,
+             base_url="https://api.groq.com/openai/v1",
+             http_client=httpx.AsyncClient(verify=True)  # Async client
+         )
+
+         chat_completion = await client.chat.completions.create(
+             messages=[
+                 {
+                     "role": "user",
+                     "content": prompt
+                 }
+             ],
+             model=model
+         )
+
+         return chat_completion.choices[0].message.content
+
+     except Exception as e:
+         raise Exception(f"Error with Groq API: {str(e)}")
models/mistral.py ADDED
@@ -0,0 +1,20 @@
+ import os
+ from mistralai import Mistral  # Note: the mistralai package is not listed in requirements.txt
+
+ api_key = "gZkpXZUIvz1ryxpyODGmsommNbryox2s"
+ model = "mistral-large-latest"
+
+ client = Mistral(api_key=api_key)
+
+ stream_response = client.chat.stream(
+     model=model,
+     messages=[
+         {
+             "role": "user",
+             "content": "What is the best French cheese?",
+         },
+     ]
+ )
+
+ for chunk in stream_response:
+     print(chunk.data.choices[0].delta.content)
models/sambanova.py ADDED
@@ -0,0 +1,90 @@
+ import httpx
+ import json
+ from typing import AsyncGenerator
+
+ async def run_model_stream(api_key: str, model: str, prompt: str) -> AsyncGenerator[str, None]:
+     """
+     Run the SambaNova model with streaming response.
+
+     Args:
+         api_key: The API key to use for this request
+         model: The model name to use
+         prompt: The user's input prompt
+
+     Yields:
+         str: Chunks of the generated response
+     """
+     try:
+         # Configure HTTP client with appropriate headers and SSL settings
+         headers = {
+             "Authorization": f"Bearer {api_key}",
+             "Content-Type": "application/json",
+             "Accept": "text/event-stream"
+         }
+
+         async with httpx.AsyncClient(
+             base_url="https://api.sambanova.ai/v1",
+             headers=headers,
+             verify=True,
+             timeout=httpx.Timeout(60.0, read=300.0)
+         ) as client:
+             # Make streaming request
+             async with client.stream(
+                 "POST",
+                 "/chat/completions",
+                 json={
+                     "model": model,  # Use model name directly from MODELS.csv
+                     "messages": [
+                         {
+                             "role": "user",
+                             "content": prompt
+                         }
+                     ],
+                     "stream": True,
+                     "temperature": 0.7,
+                     "max_tokens": 2048
+                 }
+             ) as response:
+                 response.raise_for_status()
+
+                 async for line in response.aiter_lines():
+                     line = line.strip()
+                     if not line:
+                         continue
+
+                     if line.startswith("data: "):
+                         data = line[6:].strip()
+                         if data == "[DONE]":
+                             break
+
+                         try:
+                             chunk_data = json.loads(data)
+                             if chunk_data.get("choices") and chunk_data["choices"][0].get("delta"):
+                                 content = chunk_data["choices"][0]["delta"].get("content")
+                                 if content:
+                                     yield content
+                         except Exception as e:
+                             print(f"Error parsing chunk: {e}")
+                             continue
+
+     except httpx.HTTPError as e:
+         raise Exception(f"HTTP error with SambaNova API: {str(e)}")
+     except Exception as e:
+         raise Exception(f"Error with SambaNova API: {str(e)}")
+
+ async def run_model(api_key: str, model: str, prompt: str) -> str:
+     """
+     Run the SambaNova model with the provided API key and prompt (non-streaming).
+
+     Args:
+         api_key: The API key to use for this request
+         model: The model name to use
+         prompt: The user's input prompt
+
+     Returns:
+         str: The generated response
+     """
+     response_text = ""
+     async for chunk in run_model_stream(api_key, model, prompt):
+         response_text += chunk
+     return response_text
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ fastapi==0.110.0
+ uvicorn==0.27.1
+ python-multipart==0.0.9
+ pandas==2.2.1
+ openai==1.13.3
+ google-generativeai==0.3.2
+ cohere==4.51
+ requests==2.31.0
+ httpx  # imported directly by models/groq.py and models/sambanova.py
utils.py ADDED
@@ -0,0 +1,100 @@
+ import pandas as pd
+ from typing import Dict, List, Optional
+ import os
+ import logging
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ class KeyRotator:
+     def __init__(self):
+         self.current_indices: Dict[str, int] = {}
+         self.api_keys: Dict[str, List[str]] = {}
+         self.models_map: Dict[str, str] = {}
+         self._load_api_keys()
+         self._load_models()
+
+     def _load_api_keys(self) -> None:
+         """Load API keys from CSV file and initialize rotation indices."""
+         try:
+             keys_df = pd.read_csv('apikeys.csv')
+             # Map model provider names to API key provider names
+             provider_map = {
+                 "GROQ MODELS": "GROQ API ",  # Note the trailing space
+                 "COHERE": "COHERE AI",
+                 "SambaNova": "Samba api key",
+                 "GEMINI": "GEMINI API"
+             }
+
+             logger.info("Available API key columns: %s", keys_df.columns.tolist())
+
+             for model_provider, api_provider in provider_map.items():
+                 logger.info(f"Processing provider mapping: {model_provider} -> {api_provider}")
+                 if api_provider in keys_df.columns:
+                     valid_keys = [key for key in keys_df[api_provider].dropna()]
+                     logger.info(f"Found {len(valid_keys)} valid keys for {model_provider}")
+                     if valid_keys:
+                         self.api_keys[model_provider] = valid_keys
+                         self.current_indices[model_provider] = 0
+                 else:
+                     logger.warning(f"API provider column {api_provider} not found in CSV")
+                     logger.warning(f"Available columns: {keys_df.columns.tolist()}")
+         except Exception as e:
+             logger.error(f"Error loading API keys: {str(e)}")
+             raise Exception(f"Error loading API keys: {str(e)}")
+
+     def _load_models(self) -> None:
+         """Load models from CSV file and create model-to-provider mapping."""
+         try:
+             models_df = pd.read_csv('MODELS.csv')
+             logger.info("Available model providers: %s", models_df.columns.tolist())
+             for column in models_df.columns:
+                 provider = column.strip()
+                 models = [model for model in models_df[column].dropna()]
+                 logger.info(f"Loading {len(models)} models for provider {provider}")
+                 for model in models:
+                     self.models_map[model.strip()] = provider
+         except Exception as e:
+             logger.error(f"Error loading models: {str(e)}")
+             raise Exception(f"Error loading models: {str(e)}")
+
+     def get_next_key(self, provider: str) -> str:
+         """Get the next API key for the specified provider using rotation."""
+         if provider not in self.api_keys:
+             logger.error(f"No API keys found for provider: {provider}")
+             logger.error(f"Available providers: {list(self.api_keys.keys())}")
+             raise ValueError(f"No API keys found for provider: {provider}")
+
+         keys = self.api_keys[provider]
+         current_idx = self.current_indices[provider]
+         total_keys = len(keys)
+
+         # Get current key and update index for next time
+         key = keys[current_idx]
+         next_idx = (current_idx + 1) % total_keys
+         self.current_indices[provider] = next_idx
+
+         # Log key rotation info with more details
+         key_preview = f"{key[:4]}...{key[-4:]}"
+         logger.info(f"### API Key Rotation for {provider} ###")
+         logger.info(f"Using key #{current_idx + 1}/{total_keys} ({key_preview})")
+         logger.info(f"Next request will use key #{next_idx + 1}/{total_keys}")
+         logger.info("#################################")
+
+         return key
+
+     def get_provider_for_model(self, model: str) -> str:
+         """Get the provider name for a given model."""
+         model = model.strip()
+         if model not in self.models_map:
+             raise ValueError(f"Model not found: {model}")
+         return self.models_map[model]
+
+     def get_available_models(self) -> List[Dict[str, str]]:
+         """Get list of all available models with their providers."""
+         return [{"model": model, "provider": provider}
+                 for model, provider in self.models_map.items()]
+
+ # Global instance
+ key_rotator = KeyRotator()