Add complete backend application with all dependencies
- .gitignore +10 -0
- Dockerfile +32 -0
- MODELS.csv +14 -0
- apikeys.csv +6 -0
- app.py +126 -0
- core/__init__.py +16 -0
- core/config.py +49 -0
- core/exceptions.py +38 -0
- core/key_manager.py +131 -0
- core/text_generation.py +87 -0
- models/cohere-test.py +12 -0
- models/cohere.py +57 -0
- models/gemini.py +58 -0
- models/groq.py +75 -0
- models/mistral.py +20 -0
- models/sambanova.py +90 -0
- requirements.txt +8 -0
- utils.py +100 -0
.gitignore
ADDED
@@ -0,0 +1,10 @@
+__pycache__/
+*.py[cod]
+*$py.class
+.env
+.venv
+env/
+venv/
+ENV/
+*.log
+.DS_Store
Dockerfile
ADDED
@@ -0,0 +1,32 @@
+FROM python:3.9
+
+WORKDIR /app
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y \
+    build-essential \
+    && rm -rf /var/lib/apt/lists/*
+
+# Create a non-root user and switch to it
+RUN useradd -m -u 1000 user
+USER user
+ENV PATH="/home/user/.local/bin:$PATH"
+ENV PYTHONPATH="/app:$PYTHONPATH"
+
+# Copy requirements first to leverage Docker cache
+COPY --chown=user requirements.txt .
+RUN pip install --no-cache-dir --upgrade -r requirements.txt
+
+# Copy the application code
+COPY --chown=user app.py .
+COPY --chown=user apikeys.csv .
+COPY --chown=user MODELS.csv .
+COPY --chown=user utils.py .
+COPY --chown=user core ./core
+COPY --chown=user models ./models
+
+# Expose the port
+EXPOSE 7860
+
+# Start the application
+CMD ["python", "app.py"]
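The image serves on port 7860, the port Hugging Face Spaces expects. A minimal smoke test, assuming the container is already running and published on localhost:7860 (requests is already pinned in requirements.txt):

import requests

# Hypothetical smoke test against a running container (e.g. docker run -p 7860:7860 <image>).
resp = requests.get("http://localhost:7860/models", timeout=10)
resp.raise_for_status()
for entry in resp.json():
    print(f"{entry['provider']}: {entry['model']}")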
MODELS.csv
ADDED
@@ -0,0 +1,14 @@
+GROQ MODELS,COHERE ,SambaNova,GEMINI
+deepseek-r1-distill-llama-70b,Command R,DeepSeek-R1-Distill-Llama-70B,gemini-2.0-flash
+llama-3.3-70b-versatile,Command R7B,Llama-3.1-Swallow-70B-Instruct-v0.3,gemini-2.0-flash-lite
+llama-3.3-70b-specdec,Command R+,Llama-3.1-Swallow-8B-Instruct-v0.3,gemini-2.0-flash-thinking-exp-01-21
+llama-3.2-1b-preview,,Llama-3.1-Tulu-3-405B,gemini-1.5-flash
+llama-3.2-3b-preview,,Meta-Llama-3.1-405B-Instruct,gemini-1.5-flash-8b
+llama-3.1-8b-instant,,Meta-Llama-3.1-70B-Instruct,
+llama3-70b-8192,,Meta-Llama-3.1-8B-Instruct,
+llama3-8b-8192,,Meta-Llama-3.2-1B-Instruct,
+llama-guard-3-8b,,Meta-Llama-3.2-3B-Instruct,
+mixtral-8x7b-32768,,Meta-Llama-3.3-70B-Instruct,
+gemma2-9b-it,,Qwen2.5-72B-Instruct,
+llama-3.2-11b-vision-preview,,Qwen2.5-Coder-32B-Instruct,
+llama-3.2-90b-vision-preview,,QwQ-32B-Preview,
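MODELS.csv uses one column per provider, with shorter columns padded by empty cells. A short sketch of how this ragged layout reads back with pandas, mirroring what core/key_manager.py does:

import pandas as pd

# Each column is a provider; empty cells come back as NaN and are dropped.
models_df = pd.read_csv("MODELS.csv")
for column in models_df.columns:
    provider = column.strip()
    models = models_df[column].dropna().tolist()
    print(f"{provider}: {len(models)} models")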
apikeys.csv
ADDED
@@ -0,0 +1,6 @@
+GROQ API ,GEMINI API,COHERE AI,Samba api key
+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,AIzaSyA5XHRFKUIAKKmvPyZA50H9bmXGuio22_o,GEHqd8ojtfK7Q2qTSrH1RNbLPceIXgdtKtsaG1Io,d7277c97-5abf-4f00-afc4-a05f2283b1c9
+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,AIzaSyCFWSxJ8Lwd8eeSrLfEyepU00_TuoGr3YE,IAcJ37niDUV3AuIzwzRLYnSTkrgou4VGKhgAWud0,7952d283-92c9-42aa-8511-52729a5375b3
+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,AIzaSyDw2aCCLrBn2fQIkfjo3mHoNVaOMZEmQPg,EN5HVdaRlruyxpQkMhALF2ViwTb2ZOyfN5oBZnvf,68fa8ba4-5329-493a-a733-495d5fa72386
+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,AIzaSyBbHEMA7zPTnfw3VAmYGVsQQcdAxO1O__Q,FfBExN70G0PAun5MRvw29rbq3soCJgwBb4TsVrFe,52de8b5e-4a32-4c6f-983f-5af048d14d96
+,AIzaSyBSjaqgysbmd74zXax-mQMuR4YeHjUH124,w4EAnXvuW0G3j44qPwoGaMqSaMr03EWOugidAGMs,4cb02d31-2b3f-4e94-b12d-4d694b6f0c48
app.py
ADDED
@@ -0,0 +1,126 @@
+from fastapi import FastAPI, HTTPException, Request
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse, JSONResponse
+from fastapi.staticfiles import StaticFiles
+from fastapi.responses import FileResponse
+from pydantic import BaseModel
+import json
+import logging
+import logging.config
+import os
+
+from core.config import API_HOST, API_PORT, CORS_SETTINGS, LOG_CONFIG
+from core.exceptions import APIError, handle_api_error
+from core.text_generation import text_generator
+
+# Configure logging
+logging.config.dictConfig(LOG_CONFIG)
+logger = logging.getLogger(__name__)
+
+app = FastAPI(title="AI Text Generation API",
+              description="API for text generation using multiple AI providers",
+              version="1.0.0")
+
+# Enable CORS with specific headers for SSE
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],  # Update this in production
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+    expose_headers=["Content-Type", "Cache-Control"]
+)
+
+# Mount static files (guarded: the frontend directory is not copied into the Docker image)
+frontend_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'frontend')
+if os.path.isdir(frontend_dir):
+    app.mount("/static", StaticFiles(directory=frontend_dir), name="static")
+
+class PromptRequest(BaseModel):
+    model: str
+    prompt: str
+
+@app.get("/")
+async def read_root():
+    """Serve the frontend HTML."""
+    return FileResponse(os.path.join(frontend_dir, 'index.html'))
+
+@app.get("/models")
+async def get_models():
+    """Get list of all available models."""
+    try:
+        # Return models as a JSON array
+        return JSONResponse(content=text_generator.get_available_models())
+    except APIError as e:
+        error_response = handle_api_error(e)
+        raise HTTPException(
+            status_code=error_response["status_code"],
+            detail=error_response["detail"]
+        )
+    except Exception as e:
+        logger.error(f"Unexpected error in get_models: {str(e)}")
+        raise HTTPException(status_code=500, detail="Internal server error")
+
+async def generate_stream(model: str, prompt: str):
+    """Stream generator for text generation."""
+    try:
+        async for chunk in text_generator.generate_stream(model, prompt):
+            # Add extra newline to ensure proper event separation
+            yield f"data: {json.dumps({'content': chunk})}\n\n"
+    except APIError as e:
+        error_response = handle_api_error(e)
+        yield f"data: {json.dumps({'error': error_response['detail']})}\n\n"
+    except Exception as e:
+        logger.error(f"Unexpected error in generate_stream: {str(e)}")
+        yield f"data: {json.dumps({'error': 'Internal server error'})}\n\n"
+    finally:
+        yield "data: [DONE]\n\n"
+
+@app.get("/generate")
+@app.post("/generate")
+async def generate_response(request: Request):
+    """Generate response using selected model (supports both GET and POST)."""
+    try:
+        # Handle both GET and POST methods
+        if request.method == "GET":
+            params = dict(request.query_params)
+            model = params.get("model")
+            prompt = params.get("prompt")
+        else:
+            body = await request.json()
+            model = body.get("model")
+            prompt = body.get("prompt")
+
+        if not model or not prompt:
+            raise HTTPException(status_code=400, detail="Missing model or prompt parameter")
+
+        logger.info(f"Received {request.method} request for model: {model}")
+
+        headers = {
+            "Cache-Control": "no-cache",
+            "Connection": "keep-alive",
+            "X-Accel-Buffering": "no"  # Disable buffering for nginx
+        }
+
+        return StreamingResponse(
+            generate_stream(model, prompt),
+            media_type="text/event-stream",
+            headers=headers
+        )
+
+    except APIError as e:
+        error_response = handle_api_error(e)
+        raise HTTPException(
+            status_code=error_response["status_code"],
+            detail=error_response["detail"]
+        )
+    except HTTPException:
+        # Re-raise the 400 above instead of masking it as a 500 below
+        raise
+    except Exception as e:
+        logger.error(f"Unexpected error in generate_response: {str(e)}")
+        raise HTTPException(status_code=500, detail="Internal server error")
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host=API_HOST, port=API_PORT)
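/generate streams Server-Sent Events: each chunk arrives as a "data: {...}" line followed by a blank separator line, and the stream ends with a "data: [DONE]" sentinel. A minimal client sketch, assuming the server is reachable on localhost:7860 and using placeholder model and prompt values:

import json
import requests

with requests.get(
    "http://localhost:7860/generate",
    params={"model": "llama-3.1-8b-instant", "prompt": "Say hello"},
    stream=True,
) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue  # skip the blank separator lines between events
        payload = line[len("data: "):]
        if payload == "[DONE]":
            break
        event = json.loads(payload)
        if "error" in event:
            raise RuntimeError(event["error"])
        print(event["content"], end="", flush=True)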
core/__init__.py
ADDED
@@ -0,0 +1,16 @@
+"""
+Core package for the AI Text Generation API.
+
+This package contains the core functionality of the API:
+- Configuration management
+- Key management
+- Text generation
+- Error handling
+"""
+
+from core.config import *
+from core.exceptions import *
+from core.key_manager import key_manager
+from core.text_generation import text_generator
+
+__version__ = "1.0.0"
core/config.py
ADDED
@@ -0,0 +1,49 @@
+from typing import Dict
+
+# API Provider mapping - Maps model provider names to their API key column names
+PROVIDER_MAP: Dict[str, str] = {
+    "GROQ MODELS": "GROQ API ",  # Note the trailing space
+    "COHERE": "COHERE AI",
+    "SambaNova": "Samba api key",
+    "GEMINI": "GEMINI API",
+    # New providers will be automatically detected from MODELS.csv
+    # Their API key column names should match the format: "{PROVIDER} API" or "{PROVIDER} api key"
+}
+
+# API Configuration
+API_HOST = "0.0.0.0"
+API_PORT = 7860  # Must match EXPOSE in the Dockerfile; the non-root container user cannot bind port 80
+
+# CORS Configuration
+CORS_SETTINGS = {
+    "allow_origins": ["*"],
+    "allow_credentials": True,
+    "allow_methods": ["*"],
+    "allow_headers": ["*"]
+}
+
+# Logging Configuration
+LOG_CONFIG = {
+    "version": 1,
+    "disable_existing_loggers": False,
+    "formatters": {
+        "standard": {
+            "format": "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
+        }
+    },
+    "handlers": {
+        "default": {
+            "level": "INFO",
+            "formatter": "standard",
+            "class": "logging.StreamHandler",
+            "stream": "ext://sys.stdout"
+        }
+    },
+    "loggers": {
+        "": {
+            "handlers": ["default"],
+            "level": "INFO",
+            "propagate": True
+        }
+    }
+}
core/exceptions.py
ADDED
@@ -0,0 +1,38 @@
+class APIError(Exception):
+    """Base exception for API related errors."""
+    pass
+
+class APIKeyError(APIError):
+    """Exception raised for API key related errors."""
+    pass
+
+class ModelError(APIError):
+    """Exception raised for model related errors."""
+    pass
+
+class ProviderError(APIError):
+    """Exception raised for provider related errors."""
+    pass
+
+class ValidationError(APIError):
+    """Exception raised for input validation errors."""
+    pass
+
+def handle_api_error(error: APIError) -> dict:
+    """Convert API errors to response dictionaries."""
+    error_types = {
+        APIKeyError: (400, "API Key Error"),
+        ModelError: (400, "Model Error"),
+        ProviderError: (500, "Provider Error"),
+        ValidationError: (400, "Validation Error"),
+        APIError: (500, "Internal Server Error")
+    }
+
+    error_class = type(error)
+    status_code, error_type = error_types.get(error_class, (500, "Unknown Error"))
+
+    return {
+        "status_code": status_code,
+        "error_type": error_type,
+        "detail": str(error)
+    }
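handle_api_error looks errors up by exact type, so a subclass not listed in the table falls through to (500, "Unknown Error"). For example:

from core.exceptions import APIKeyError, handle_api_error

# An APIKeyError maps to a 400 with the exception message as the detail.
print(handle_api_error(APIKeyError("No API keys found for provider: GEMINI")))
# -> {'status_code': 400, 'error_type': 'API Key Error', 'detail': 'No API keys found for provider: GEMINI'}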
core/key_manager.py
ADDED
@@ -0,0 +1,131 @@
+import pandas as pd
+from typing import Dict, List
+import logging
+import os
+from core.config import PROVIDER_MAP
+from core.exceptions import APIKeyError
+
+logger = logging.getLogger(__name__)
+
+class KeyManager:
+    def __init__(self):
+        self.current_indices: Dict[str, int] = {}
+        self.api_keys: Dict[str, List[str]] = {}
+        self.models_map: Dict[str, str] = {}
+        self.provider_map = PROVIDER_MAP.copy()
+        self._load_models()  # Load models first to detect providers
+        self._load_api_keys()
+
+    def _detect_api_key_column(self, provider: str, columns: List[str]) -> str:
+        """
+        Detect the API key column name for a provider.
+        Tries different common formats like '{PROVIDER} API' or '{PROVIDER} api key'.
+        """
+        provider_variants = [
+            f"{provider} API",
+            f"{provider} api",
+            f"{provider} API KEY",
+            f"{provider} api key",
+            provider.upper() + " API",
+            provider.lower() + " api",
+        ]
+
+        for variant in provider_variants:
+            for column in columns:
+                if column.strip().lower() == variant.lower():
+                    return column
+
+        return None
+
+    def _load_api_keys(self) -> None:
+        """Load API keys from CSV file and initialize rotation indices."""
+        try:
+            backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+            keys_path = os.path.join(backend_dir, 'apikeys.csv')
+            keys_df = pd.read_csv(keys_path)
+
+            logger.info("Available API key columns: %s", keys_df.columns.tolist())
+
+            # Process both predefined and auto-detected providers
+            for provider in set(self.models_map.values()):
+                if provider not in self.provider_map:
+                    # Try to detect API key column for new provider
+                    api_column = self._detect_api_key_column(provider, keys_df.columns)
+                    if api_column:
+                        self.provider_map[provider] = api_column
+                        logger.info(f"Auto-detected API column '{api_column}' for provider '{provider}'")
+
+                api_provider = self.provider_map.get(provider)
+                if api_provider and api_provider in keys_df.columns:
+                    valid_keys = [key for key in keys_df[api_provider].dropna()]
+                    logger.info(f"Found {len(valid_keys)} valid keys for {provider}")
+                    if valid_keys:
+                        self.api_keys[provider] = valid_keys
+                        self.current_indices[provider] = 0
+                else:
+                    logger.warning(f"No API keys found for provider: {provider}")
+                    if api_provider:
+                        logger.warning(f"Column '{api_provider}' not found in CSV")
+                    logger.warning(f"Available columns: {keys_df.columns.tolist()}")
+
+        except Exception as e:
+            logger.error(f"Error loading API keys: {str(e)}")
+            raise APIKeyError(f"Error loading API keys: {str(e)}")
+
+    def _load_models(self) -> None:
+        """Load models from CSV file and create model-to-provider mapping."""
+        try:
+            backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+            models_path = os.path.join(backend_dir, 'MODELS.csv')
+            models_df = pd.read_csv(models_path)
+
+            logger.info("Available model providers: %s", models_df.columns.tolist())
+            for column in models_df.columns:
+                provider = column.strip()
+                models = [model for model in models_df[column].dropna()]
+                logger.info(f"Loading {len(models)} models for provider {provider}")
+                for model in models:
+                    self.models_map[model.strip()] = provider
+
+            logger.info(f"Detected providers: {set(self.models_map.values())}")
+
+        except Exception as e:
+            logger.error(f"Error loading models: {str(e)}")
+            raise APIKeyError(f"Error loading models: {str(e)}")
+
+    def get_next_key(self, provider: str) -> str:
+        """Get the next API key for the specified provider using rotation."""
+        if provider not in self.api_keys:
+            logger.error(f"No API keys found for provider: {provider}")
+            logger.error(f"Available providers: {list(self.api_keys.keys())}")
+            raise APIKeyError(f"No API keys found for provider: {provider}")
+
+        keys = self.api_keys[provider]
+        current_idx = self.current_indices[provider]
+        total_keys = len(keys)
+
+        key = keys[current_idx]
+        next_idx = (current_idx + 1) % total_keys
+        self.current_indices[provider] = next_idx
+
+        key_preview = f"{key[:4]}...{key[-4:]}"
+        logger.info(f"### API Key Rotation for {provider} ###")
+        logger.info(f"Using key #{current_idx + 1}/{total_keys} ({key_preview})")
+        logger.info(f"Next request will use key #{next_idx + 1}/{total_keys}")
+
+        return key
+
+    def get_provider_for_model(self, model: str) -> str:
+        """Get the provider name for a given model."""
+        model = model.strip()
+        if model not in self.models_map:
+            raise APIKeyError(f"Model not found: {model}")
+        return self.models_map[model]
+
+    def get_available_models(self) -> List[Dict[str, str]]:
+        """Get list of all available models with their providers."""
+        return [{"model": model, "provider": provider}
+                for model, provider in self.models_map.items()]
+
+# Global instance
+key_manager = KeyManager()
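get_next_key is plain round-robin over the keys loaded for a provider, so consecutive requests are spread evenly across the pool. The rotation in isolation, as a small sketch:

# Minimal stand-in for KeyManager.get_next_key, for a single provider's key pool.
class RoundRobin:
    def __init__(self, keys):
        self.keys = keys
        self.idx = 0

    def next_key(self):
        key = self.keys[self.idx]
        self.idx = (self.idx + 1) % len(self.keys)  # wrap around the pool
        return key

pool = RoundRobin(["key-a", "key-b", "key-c"])
print([pool.next_key() for _ in range(5)])
# -> ['key-a', 'key-b', 'key-c', 'key-a', 'key-b']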
core/text_generation.py
ADDED
@@ -0,0 +1,87 @@
+from typing import AsyncGenerator, Dict, Any
+import importlib
+import os
+import logging
+from core.exceptions import ModelError, ProviderError
+from core.key_manager import key_manager
+
+logger = logging.getLogger(__name__)
+
+class TextGenerator:
+    """Handles text generation across different model providers."""
+
+    @staticmethod
+    async def generate_stream(model: str, prompt: str) -> AsyncGenerator[str, None]:
+        """
+        Generate streaming text responses using the specified model.
+
+        Args:
+            model (str): The name of the model to use
+            prompt (str): The input prompt for text generation
+
+        Yields:
+            str: Generated text chunks
+
+        Raises:
+            ModelError: If the model is not found or invalid
+            ProviderError: If there's an error with the provider
+        """
+        try:
+            # Get provider for the selected model
+            provider = key_manager.get_provider_for_model(model)
+            logger.info(f"Using provider {provider} for model {model}")
+
+            # Get API key
+            api_key = key_manager.get_next_key(provider)
+            logger.info(f"Retrieved API key for provider {provider}")
+
+            # Import provider module
+            try:
+                # Convert provider name to valid module name
+                module_name = provider.lower().split()[0]  # Get first word in lowercase
+                logger.info(f"Importing module: models.{module_name}")
+
+                # Check if module exists
+                backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+                module_path = os.path.join(backend_dir, 'models', f'{module_name}.py')
+
+                if not os.path.exists(module_path):
+                    raise ProviderError(f"Provider module not found at: {module_path}")
+
+                provider_module = importlib.import_module(f"models.{module_name}")
+
+                # Verify required functions exist
+                if not hasattr(provider_module, 'run_model_stream'):
+                    raise ProviderError(f"Provider module {module_name} missing required function: run_model_stream")
+
+            except ImportError as e:
+                raise ProviderError(f"Failed to import provider module: {str(e)}")
+            except Exception as e:
+                raise ProviderError(f"Error loading provider module: {str(e)}")
+
+            # Generate text stream
+            async for chunk in provider_module.run_model_stream(api_key, model, prompt):
+                yield chunk
+
+        except Exception as e:
+            logger.error(f"Error in generate_stream: {str(e)}")
+            raise
+
+    @staticmethod
+    def get_available_models() -> Dict[str, Any]:
+        """
+        Get all available models and their providers.
+
+        Returns:
+            Dict[str, Any]: Dictionary containing model information
+        """
+        try:
+            models = key_manager.get_available_models()
+            # Return just the models array directly since that's what frontend expects
+            return models
+        except Exception as e:
+            logger.error(f"Error getting available models: {str(e)}")
+            raise ModelError(f"Failed to get available models: {str(e)}")
+
+# Global instance
+text_generator = TextGenerator()
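Providers are resolved by convention: the first word of the MODELS.csv column name, lowercased, must name a module under models/ that exposes an async run_model_stream(api_key, model, prompt) generator. A hypothetical skeleton for wiring up a new provider (the file name and the echo body are placeholders, not a real integration):

# models/newprovider.py -- hypothetical skeleton for a "NEWPROVIDER" column in MODELS.csv

async def run_model_stream(api_key: str, model: str, prompt: str):
    """Yield response chunks; the only function core/text_generation.py requires."""
    # Replace this stub with real streaming calls to the provider's API.
    for word in f"echo from {model}: {prompt}".split():
        yield word + " "

async def run_model(api_key: str, model: str, prompt: str) -> str:
    """Optional non-streaming wrapper, matching the other provider modules."""
    return "".join([chunk async for chunk in run_model_stream(api_key, model, prompt)])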
models/cohere-test.py
ADDED
@@ -0,0 +1,12 @@
+import cohere
+
+co = cohere.ClientV2(api_key='NYEQhh6CwVF3dCo57n25032m4s0FurCYI73IYWb5')
+
+response = co.chat_stream(
+    model="command-r-plus-08-2024",
+    messages=[{"role": "user", "content": "write a story about a dragon"}],
+)
+
+for event in response:
+    if event.type == "content-delta":
+        print(event.delta.message.content.text, end="")
models/cohere.py
ADDED
@@ -0,0 +1,57 @@
+import cohere
+import asyncio
+from typing import AsyncGenerator
+
+async def run_model_stream(api_key: str, model: str, prompt: str):
+    """
+    Run the Cohere model with streaming response.
+
+    Args:
+        api_key: The API key to use for this request
+        model: The model name to use
+        prompt: The user's input prompt
+
+    Yields:
+        str: Chunks of the generated response
+    """
+    try:
+        client = cohere.Client(api_key=api_key)
+
+        # Create chat message with streaming
+        response = await asyncio.get_event_loop().run_in_executor(
+            None,
+            lambda: client.chat(
+                chat_history=[],
+                message=prompt,
+                model=model,  # Use model name directly from MODELS.csv
+                stream=True,
+                temperature=0.7
+            )
+        )
+
+        # Process each chunk
+        for event in response:
+            if hasattr(event, 'text') and event.text:
+                # Use asyncio.sleep to prevent blocking
+                await asyncio.sleep(0)
+                yield event.text
+
+    except Exception as e:
+        raise Exception(f"Error with Cohere API: {str(e)}")
+
+async def run_model(api_key: str, model: str, prompt: str) -> str:
+    """
+    Run the Cohere model with the provided API key and prompt (non-streaming).
+
+    Args:
+        api_key: The API key to use for this request
+        model: The model name to use
+        prompt: The user's input prompt
+
+    Returns:
+        str: The generated response
+    """
+    response = ""
+    async for chunk in run_model_stream(api_key, model, prompt):
+        response += chunk
+    return response
models/gemini.py
ADDED
@@ -0,0 +1,58 @@
+import google.generativeai as genai
+import asyncio
+from typing import AsyncGenerator
+
+async def run_model_stream(api_key: str, model: str, prompt: str):
+    """
+    Run the Gemini model with streaming response.
+
+    Args:
+        api_key: The API key to use for this request
+        model: The model name to use
+        prompt: The user's input prompt
+
+    Yields:
+        str: Chunks of the generated response
+    """
+    try:
+        # Configure the Gemini API
+        genai.configure(api_key=api_key)
+
+        # Initialize the model with name from MODELS.csv
+        model_instance = genai.GenerativeModel(model)
+
+        # Start the streaming response using async executor
+        response = await asyncio.get_event_loop().run_in_executor(
+            None,
+            lambda: model_instance.generate_content(
+                prompt,
+                stream=True
+            )
+        )
+
+        # Process chunks with async handling
+        for chunk in response:
+            if chunk.text:
+                # Use asyncio.sleep to prevent blocking
+                await asyncio.sleep(0)
+                yield chunk.text
+
+    except Exception as e:
+        raise Exception(f"Error with Gemini API: {str(e)}")
+
+async def run_model(api_key: str, model: str, prompt: str) -> str:
+    """
+    Run the Gemini model with the provided API key and prompt (non-streaming).
+
+    Args:
+        api_key: The API key to use for this request
+        model: The model name to use
+        prompt: The user's input prompt
+
+    Returns:
+        str: The generated response
+    """
+    response = ""
+    async for chunk in run_model_stream(api_key, model, prompt):
+        response += chunk
+    return response
models/groq.py
ADDED
@@ -0,0 +1,75 @@
+from openai import AsyncOpenAI
+import httpx
+
+async def run_model_stream(api_key: str, model: str, prompt: str):
+    """
+    Run the Groq model with streaming response.
+
+    Args:
+        api_key: The API key to use for this request
+        model: The model name to use
+        prompt: The user's input prompt
+
+    Yields:
+        str: Chunks of the generated response
+    """
+    try:
+        # Initialize AsyncOpenAI client with specific configuration
+        client = AsyncOpenAI(
+            api_key=api_key,
+            base_url="https://api.groq.com/openai/v1",
+            http_client=httpx.AsyncClient(verify=True)  # Async client
+        )
+
+        completion = await client.chat.completions.create(
+            messages=[
+                {
+                    "role": "user",
+                    "content": prompt
+                }
+            ],
+            model=model,
+            stream=True
+        )
+
+        async for chunk in completion:
+            if chunk.choices[0].delta.content is not None:
+                yield chunk.choices[0].delta.content
+
+    except Exception as e:
+        raise Exception(f"Error with Groq API: {str(e)}")
+
+async def run_model(api_key: str, model: str, prompt: str) -> str:
+    """
+    Run the Groq model with the provided API key and prompt (non-streaming).
+
+    Args:
+        api_key: The API key to use for this request
+        model: The model name to use
+        prompt: The user's input prompt
+
+    Returns:
+        str: The generated response
+    """
+    try:
+        # Initialize AsyncOpenAI client with specific configuration
+        client = AsyncOpenAI(
+            api_key=api_key,
+            base_url="https://api.groq.com/openai/v1",
+            http_client=httpx.AsyncClient(verify=True)  # Async client
+        )
+
+        chat_completion = await client.chat.completions.create(
+            messages=[
+                {
+                    "role": "user",
+                    "content": prompt
+                }
+            ],
+            model=model
+        )
+
+        return chat_completion.choices[0].message.content
+
+    except Exception as e:
+        raise Exception(f"Error with Groq API: {str(e)}")
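Both entry points are coroutines, so outside FastAPI they need an event loop. A minimal sketch with placeholder key and model values:

import asyncio
from models.groq import run_model

# Placeholder credentials; in the app these come from key_manager and MODELS.csv.
text = asyncio.run(run_model("gsk_your_key_here", "llama-3.1-8b-instant", "Tell me a one-line joke."))
print(text)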
models/mistral.py
ADDED
@@ -0,0 +1,20 @@
+import os
+from mistralai import Mistral
+
+api_key = "gZkpXZUIvz1ryxpyODGmsommNbryox2s"
+model = "mistral-large-latest"
+
+client = Mistral(api_key=api_key)
+
+stream_response = client.chat.stream(
+    model = model,
+    messages = [
+        {
+            "role": "user",
+            "content": "What is the best French cheese?",
+        },
+    ]
+)
+
+for chunk in stream_response:
+    print(chunk.data.choices[0].delta.content)
models/sambanova.py
ADDED
@@ -0,0 +1,90 @@
+import httpx
+import json
+from typing import AsyncGenerator
+
+async def run_model_stream(api_key: str, model: str, prompt: str):
+    """
+    Run the SambaNova model with streaming response.
+
+    Args:
+        api_key: The API key to use for this request
+        model: The model name to use
+        prompt: The user's input prompt
+
+    Yields:
+        str: Chunks of the generated response
+    """
+    try:
+        # Configure HTTP client with appropriate headers and SSL settings
+        headers = {
+            "Authorization": f"Bearer {api_key}",
+            "Content-Type": "application/json",
+            "Accept": "text/event-stream"
+        }
+
+        async with httpx.AsyncClient(
+            base_url="https://api.sambanova.ai/v1",
+            headers=headers,
+            verify=True,
+            timeout=httpx.Timeout(60.0, read=300.0)
+        ) as client:
+            # Make streaming request
+            async with client.stream(
+                "POST",
+                "/chat/completions",
+                json={
+                    "model": model,  # Use model name directly from MODELS.csv
+                    "messages": [
+                        {
+                            "role": "user",
+                            "content": prompt
+                        }
+                    ],
+                    "stream": True,
+                    "temperature": 0.7,
+                    "max_tokens": 2048
+                }
+            ) as response:
+                response.raise_for_status()
+
+                async for line in response.aiter_lines():
+                    line = line.strip()
+                    if not line:
+                        continue
+
+                    if line.startswith("data: "):
+                        data = line[6:].strip()
+                        if data == "[DONE]":
+                            break
+
+                        try:
+                            chunk_data = json.loads(data)
+                            if chunk_data.get("choices") and chunk_data["choices"][0].get("delta"):
+                                content = chunk_data["choices"][0]["delta"].get("content")
+                                if content:
+                                    yield content
+                        except Exception as e:
+                            print(f"Error parsing chunk: {e}")
+                            continue
+
+    except httpx.HTTPError as e:
+        raise Exception(f"HTTP error with SambaNova API: {str(e)}")
+    except Exception as e:
+        raise Exception(f"Error with SambaNova API: {str(e)}")
+
+async def run_model(api_key: str, model: str, prompt: str) -> str:
+    """
+    Run the SambaNova model with the provided API key and prompt (non-streaming).
+
+    Args:
+        api_key: The API key to use for this request
+        model: The model name to use
+        prompt: The user's input prompt
+
+    Returns:
+        str: The generated response
+    """
+    response_text = ""
+    async for chunk in run_model_stream(api_key, model, prompt):
+        response_text += chunk
+    return response_text
requirements.txt
ADDED
@@ -0,0 +1,8 @@
+fastapi==0.110.0
+uvicorn==0.27.1
+python-multipart==0.0.9
+pandas==2.2.1
+openai==1.13.3
+google-generativeai==0.3.2
+cohere==4.51
+requests==2.31.0
utils.py
ADDED
@@ -0,0 +1,100 @@
+import pandas as pd
+from typing import Dict, List, Optional
+import os
+import logging
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+class KeyRotator:
+    def __init__(self):
+        self.current_indices: Dict[str, int] = {}
+        self.api_keys: Dict[str, List[str]] = {}
+        self.models_map: Dict[str, str] = {}
+        self._load_api_keys()
+        self._load_models()
+
+    def _load_api_keys(self) -> None:
+        """Load API keys from CSV file and initialize rotation indices."""
+        try:
+            keys_df = pd.read_csv('apikeys.csv')
+            # Map model provider names to API key provider names
+            provider_map = {
+                "GROQ MODELS": "GROQ API ",  # Note the trailing space
+                "COHERE": "COHERE AI",
+                "SambaNova": "Samba api key",
+                "GEMINI": "GEMINI API"
+            }
+
+            logger.info("Available API key columns: %s", keys_df.columns.tolist())
+
+            for model_provider, api_provider in provider_map.items():
+                logger.info(f"Processing provider mapping: {model_provider} -> {api_provider}")
+                if api_provider in keys_df.columns:
+                    valid_keys = [key for key in keys_df[api_provider].dropna()]
+                    logger.info(f"Found {len(valid_keys)} valid keys for {model_provider}")
+                    if valid_keys:
+                        self.api_keys[model_provider] = valid_keys
+                        self.current_indices[model_provider] = 0
+                else:
+                    logger.warning(f"API provider column {api_provider} not found in CSV")
+                    logger.warning(f"Available columns: {keys_df.columns.tolist()}")
+        except Exception as e:
+            logger.error(f"Error loading API keys: {str(e)}")
+            raise Exception(f"Error loading API keys: {str(e)}")
+
+    def _load_models(self) -> None:
+        """Load models from CSV file and create model-to-provider mapping."""
+        try:
+            models_df = pd.read_csv('MODELS.csv')
+            logger.info("Available model providers: %s", models_df.columns.tolist())
+            for column in models_df.columns:
+                provider = column.strip()
+                models = [model for model in models_df[column].dropna()]
+                logger.info(f"Loading {len(models)} models for provider {provider}")
+                for model in models:
+                    self.models_map[model.strip()] = provider
+        except Exception as e:
+            logger.error(f"Error loading models: {str(e)}")
+            raise Exception(f"Error loading models: {str(e)}")
+
+    def get_next_key(self, provider: str) -> str:
+        """Get the next API key for the specified provider using rotation."""
+        if provider not in self.api_keys:
+            logger.error(f"No API keys found for provider: {provider}")
+            logger.error(f"Available providers: {list(self.api_keys.keys())}")
+            raise ValueError(f"No API keys found for provider: {provider}")
+
+        keys = self.api_keys[provider]
+        current_idx = self.current_indices[provider]
+        total_keys = len(keys)
+
+        # Get current key and update index for next time
+        key = keys[current_idx]
+        next_idx = (current_idx + 1) % total_keys
+        self.current_indices[provider] = next_idx
+
+        # Log key rotation info with more details
+        key_preview = f"{key[:4]}...{key[-4:]}"
+        logger.info(f"### API Key Rotation for {provider} ###")
+        logger.info(f"Using key #{current_idx + 1}/{total_keys} ({key_preview})")
+        logger.info(f"Next request will use key #{next_idx + 1}/{total_keys}")
+        logger.info(f"#################################")
+
+        return key
+
+    def get_provider_for_model(self, model: str) -> str:
+        """Get the provider name for a given model."""
+        model = model.strip()
+        if model not in self.models_map:
+            raise ValueError(f"Model not found: {model}")
+        return self.models_map[model]
+
+    def get_available_models(self) -> List[Dict[str, str]]:
+        """Get list of all available models with their providers."""
+        return [{"model": model, "provider": provider}
+                for model, provider in self.models_map.items()]
+
+# Global instance
+key_rotator = KeyRotator()