|
import asyncio
|
|
import time
|
|
import uuid
|
|
import os
|
|
import json
|
|
from typing import Dict, List, Optional, Union, Any
|
|
from fastapi import FastAPI, HTTPException, Depends, Request, status, Body
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import JSONResponse, StreamingResponse
|
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|
from pydantic import BaseModel, Field, EmailStr
|
|
from slowapi import Limiter, _rate_limit_exceeded_handler
|
|
from slowapi.util import get_remote_address
|
|
from slowapi.errors import RateLimitExceeded
|
|
import uvicorn
|
|
|
|
from db_helper import MongoDBHelper
|
|
from deepinfra_client import DeepInfraClient
|
|
from hf_utils import HuggingFaceSpaceHelper
|
|
|
|
|
|
# Helper that reports whether we run inside a Hugging Face Space and exposes
# Space-specific settings used below (metadata, MongoDB URI, hostname).
hf_helper = HuggingFaceSpaceHelper()

# Inside a Space, ask the helper to install the runtime packages.
# NOTE(review): fastapi, slowapi, uvicorn and pydantic (and the local
# db_helper / deepinfra_client modules) are imported at the top of this file,
# before this runs, so installing them here cannot rescue a missing import.
# It can only help packages that are loaded lazily later on (presumably the
# email validator behind ``EmailStr``) - confirm whether the list can shrink.
if hf_helper.is_in_space:
    _space_packages = [
        "pymongo", "python-dotenv", "fastapi", "uvicorn", "slowapi",
        "fake-useragent", "requests-ip-rotator", "pydantic[email]",
    ]
    hf_helper.install_dependencies(_space_packages)
|
|
|
|
|
|
# --- Application object ------------------------------------------------------
app = FastAPI(
    title="PyScoutAI API",
    description=(
        "An OpenAI-compatible API that provides access to DeepInfra models "
        "with enhanced features"
    ),
    version="1.0.0",
)

# --- Per-client-address rate limiting (slowapi) ------------------------------
# Requests are keyed by remote address. Endpoints opt in with
# ``@limiter.limit(...)``. slowapi finds the limiter on ``app.state`` and
# converts RateLimitExceeded into an HTTP response via the handler below.
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# --- CORS ---------------------------------------------------------------------
# NOTE(review): wildcard origins together with allow_credentials=True makes
# Starlette echo back the caller's Origin on credentialed requests, i.e. any
# site may make credentialed calls. Auth here is header/API-key based, so
# credentials are probably unnecessary - confirm before tightening.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# --- Bearer-token extraction --------------------------------------------------
# auto_error=False: a missing or non-Bearer Authorization header yields None
# instead of an automatic error, so get_api_key can try other key sources.
security = HTTPBearer(auto_error=False)
|
|
|
|
|
|
# MongoDB backs API-key validation, rate limiting and usage logging. If the
# helper cannot be created, the service keeps running with ``db = None``; the
# request handlers check ``db`` and fall back (e.g. get_user_info accepts any
# non-empty key as a development user).
db = None
try:
    db = MongoDBHelper(hf_helper.get_mongodb_uri())
except Exception as e:
    print(f"Warning: MongoDB connection failed: {e}")
    print("API key authentication will not work!")
|
|
|
|
|
|
class Message(BaseModel):
    """One entry of an OpenAI-style chat ``messages`` list."""

    # Free-form role string (e.g. "system"/"user"/"assistant"); not validated here.
    role: str
    # Optional: create_chat_completion drops messages whose content is None.
    content: Optional[str] = None
    # Accepted for OpenAI compatibility; create_chat_completion does not forward it.
    name: Optional[str] = None
|
|
|
|
class ChatCompletionRequest(BaseModel):
    """Request body for ``POST /v1/chat/completions`` (OpenAI-compatible subset).

    NOTE(review): ``n`` and ``user`` are accepted but create_chat_completion does
    not forward them to the upstream client.
    """

    # Upstream model identifier, passed through unchanged.
    model: str
    # Conversation history; see Message.
    messages: List[Message]
    # Sampling parameters below are forwarded to the upstream client as-is.
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    # When true, the endpoint answers with a text/event-stream response.
    stream: Optional[bool] = False
    # None leaves the limit to the upstream client.
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    user: Optional[str] = None
|
|
|
|
class CompletionRequest(BaseModel):
    """Request body for ``POST /v1/completions`` (OpenAI-compatible subset).

    NOTE(review): ``n`` and ``user`` are accepted but create_completion does not
    forward them to the upstream client.
    """

    # Upstream model identifier, passed through unchanged.
    model: str
    # A single prompt, or a list of prompts; create_completion only uses the
    # first element of a list.
    prompt: Union[str, List[str]]
    # Sampling parameters below are forwarded to the upstream client as-is.
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    # When true, the endpoint answers with a text/event-stream response.
    stream: Optional[bool] = False
    # None leaves the limit to the upstream client.
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    user: Optional[str] = None
|
|
|
|
class UserCreate(BaseModel):
    """Payload describing a new user.

    NOTE(review): not referenced by any endpoint in the visible code.
    ``EmailStr`` presumably needs the ``email-validator`` extra
    (``pydantic[email]``) to be installed - confirm.
    """

    email: EmailStr
    name: str
    organization: Optional[str] = None
|
|
|
|
class APIKeyCreate(BaseModel):
    """Request body for ``POST /v1/api_keys``."""

    # Human-readable label stored with the key.
    name: str = "Default API Key"
    # Owner of the new key, taken from the request body as supplied.
    user_id: str
|
|
|
|
class APIKeyResponse(BaseModel):
    """Response model for ``POST /v1/api_keys``."""

    # The newly generated key string.
    key: str
    # Label stored with the key.
    name: str
    # Creation time, produced by create_api_key via ``.isoformat()``.
    created_at: str
|
|
|
|
|
|
# Upstream clients cached per caller API key: populated lazily by get_client()
# and iterated by cleanup_clients() on shutdown. Entries are never evicted.
clients: Dict[str, DeepInfraClient] = {}
|
|
|
|
|
|
async def get_api_key(
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
) -> Optional[str]:
    """Extract the caller's API key from the request, if one was supplied.

    Sources are tried in this order:

    1. Credentials parsed by the ``HTTPBearer`` dependency.
    2. A raw ``Authorization`` header starting with ``"Bearer "`` (fallback for
       when ``HTTPBearer`` produced nothing).
    3. The ``x-api-key`` header.
    4. The ``api_key`` query parameter.

    Returns:
        The key string, or ``None`` when no source provided one. Validation is
        left to ``get_user_info``.
    """
    if credentials:
        return credentials.credentials

    bearer_prefix = "Bearer "
    auth = request.headers.get("Authorization")
    if auth is not None and auth.startswith(bearer_prefix):
        # Remove only the leading scheme. str.replace() would also strip any
        # "Bearer " sequence occurring inside the token itself.
        return auth[len(bearer_prefix):]

    if "x-api-key" in request.headers:
        return request.headers["x-api-key"]

    # NOTE(review): keys sent in the query string tend to end up in proxy/access
    # logs and browser history; header-based auth should be preferred.
    api_key = request.query_params.get("api_key")
    if api_key:
        return api_key

    return None
|
|
|
|
|
|
async def get_user_info(api_key: Optional[str] = Depends(get_api_key)) -> Dict[str, Any]:
    """Authenticate the caller and enforce the per-key rate limit.

    Returns:
        The record returned by ``db.validate_api_key``. When the database is
        unavailable, returns a synthetic ``{"user_id": "development", "key": ...}``
        record instead.

    Raises:
        HTTPException: 401 when the key is missing, does not start with
            ``"PyScoutAI-"``, or is unknown. 429 when ``db.check_rate_limit``
            disallows the request (its ``reason`` becomes the detail).

    NOTE(review): the database calls are made synchronously inside this async
    dependency; if MongoDBHelper is blocking they stall the event loop - confirm.
    """

    def unauthorized(reason: str) -> HTTPException:
        # Fresh headers dict for every exception, as in the inline form.
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=reason,
            headers={"WWW-Authenticate": "Bearer"},
        )

    if not api_key:
        raise unauthorized("API key is required")

    if not db:
        # NOTE(review): fail-open - without a database every non-empty key is
        # accepted with no format check, validation or rate limiting.
        return {"user_id": "development", "key": api_key}

    if not api_key.startswith("PyScoutAI-"):
        raise unauthorized("Invalid API key format")

    user_info = db.validate_api_key(api_key)
    if not user_info:
        raise unauthorized("Invalid API key")

    verdict = db.check_rate_limit(api_key)
    if verdict["allowed"]:
        return user_info
    raise HTTPException(
        status_code=status.HTTP_429_TOO_MANY_REQUESTS,
        detail=verdict["reason"],
    )
|
|
|
|
|
|
def get_client(api_key: str) -> DeepInfraClient:
    """Return the cached DeepInfraClient for *api_key*, creating it on first use.

    The API key is only the cache key; it is not passed to the client. Every
    client is built with random user agents, proxy rotation and IP rotation
    enabled.

    NOTE(review): the ``clients`` cache is never pruned, so each distinct key
    (any key at all while ``db`` is unavailable) adds a client for the life of
    the process.
    """
    try:
        return clients[api_key]
    except KeyError:
        pass

    fresh = DeepInfraClient(
        use_random_user_agent=True,
        use_proxy_rotation=True,
        use_ip_rotation=True,
    )
    clients[api_key] = fresh
    return fresh
|
|
|
|
@app.get("/")
async def root():
    """Landing endpoint: greeting, docs link, environment and endpoint list.

    Keys from ``hf_helper.get_hf_metadata()`` are merged in last, so they win
    on a name clash.
    """
    metadata = hf_helper.get_hf_metadata()
    environment = "Hugging Face Space" if hf_helper.is_in_space else "Local"

    payload = {
        "message": "Welcome to PyScoutAI API",
        "documentation": "/docs",
        "environment": environment,
        "endpoints": ["/v1/models", "/v1/chat/completions", "/v1/completions"],
    }
    payload.update(metadata)
    return payload
|
|
|
|
@app.get("/v1/models")
@limiter.limit("20/minute")
async def list_models(
    request: Request,
    user_info: Dict[str, Any] = Depends(get_user_info)
):
    """Return the upstream model list and log a zero-token usage entry.

    ``request`` is not read here; slowapi's ``@limiter.limit`` requires it in
    the signature.

    Raises:
        HTTPException: 500 wrapping any failure while listing or logging.
    """
    api_key = user_info["key"]
    client = get_client(api_key)

    try:
        # The client call is synchronous, so run it in a worker thread.
        models = await asyncio.to_thread(client.models.list)
        if db:
            db.log_api_usage(api_key, "/v1/models", 0)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error listing models: {str(e)}")

    return models
|
|
|
|
@app.post("/v1/chat/completions")
@limiter.limit("60/minute")
async def create_chat_completion(
    request: Request,
    body: ChatCompletionRequest,
    user_info: Dict[str, Any] = Depends(get_user_info)
):
    """OpenAI-compatible chat completion endpoint (streaming and non-streaming).

    ``request`` is not read here; slowapi's ``@limiter.limit`` requires it in
    the signature. Messages whose ``content`` is ``None`` are dropped, and only
    ``role``/``content`` are forwarded for each message.

    Non-streaming: the upstream response is returned unchanged. Its
    ``usage.total_tokens`` is logged when present and non-null.

    Streaming: each upstream chunk is relayed as a Server-Sent Event
    (``data: <json>``). ``usage.total_tokens`` values found in chunks are
    summed and logged, and the stream ends with ``data: [DONE]``. If the
    upstream fails mid-stream, an ``{"error": {...}}`` event is sent instead,
    because the HTTP status has already been committed by then.

    Raises:
        HTTPException: 500 if the upstream call fails before a response starts.
    """
    api_key = user_info["key"]
    client = get_client(api_key)

    messages = [
        {"role": msg.role, "content": msg.content}
        for msg in body.messages
        if msg.content is not None
    ]
    kwargs = {
        "model": body.model,
        "temperature": body.temperature,
        "max_tokens": body.max_tokens,
        "stream": body.stream,
        "top_p": body.top_p,
        "presence_penalty": body.presence_penalty,
        "frequency_penalty": body.frequency_penalty,
    }

    async def generate_stream(response_stream):
        # Relay upstream chunks as SSE events and log summed usage at the end.
        total_tokens = 0
        end_of_stream = object()
        try:
            chunks = iter(response_stream)
            while True:
                # Advancing the upstream iterator may block on network I/O, so
                # do it in a worker thread rather than on the event loop.
                chunk = await asyncio.to_thread(next, chunks, end_of_stream)
                if chunk is end_of_stream:
                    break
                if 'usage' in chunk and chunk['usage']:
                    total_tokens += chunk['usage'].get('total_tokens', 0)
                yield f"data: {json.dumps(chunk)}\n\n"
        except Exception as e:
            # Headers (status 200) are already sent; an HTTPException is no
            # longer possible, so report the failure in-band and stop.
            error = {
                "error": {
                    "message": f"Error generating chat completion: {str(e)}",
                    "type": "server_error",
                }
            }
            yield f"data: {json.dumps(error)}\n\n"
            return

        if db is not None:
            db.log_api_usage(api_key, "/v1/chat/completions", total_tokens, body.model)
        yield "data: [DONE]\n\n"

    try:
        if body.stream:
            # Open the upstream stream before the response starts, so setup
            # failures (bad model, network) still become an HTTP 500 below.
            response_stream = await asyncio.to_thread(
                client.chat.create,
                messages=messages,
                **kwargs
            )
            return StreamingResponse(
                generate_stream(response_stream),
                media_type="text/event-stream"
            )

        response = await asyncio.to_thread(
            client.chat.create,
            messages=messages,
            **kwargs
        )
        # Guard against a present-but-null "usage" (as the streaming path
        # does). Indexing None used to turn a successful completion into a 500.
        if db is not None and 'usage' in response and response['usage']:
            total_tokens = response['usage'].get('total_tokens', 0)
            db.log_api_usage(api_key, "/v1/chat/completions", total_tokens, body.model)
        return response

    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error generating chat completion: {str(e)}"
        ) from e
|
|
|
|
@app.post("/v1/completions")
@limiter.limit("60/minute")
async def create_completion(
    request: Request,
    body: CompletionRequest,
    user_info: Dict[str, Any] = Depends(get_user_info)
):
    """OpenAI-compatible text completion endpoint (streaming and non-streaming).

    ``request`` is not read here; slowapi's ``@limiter.limit`` requires it in
    the signature. When ``prompt`` is a list, only its first element is sent
    upstream.

    Non-streaming: the upstream response is returned unchanged. Its
    ``usage.total_tokens`` is logged when present and non-null.

    Streaming: each upstream chunk is relayed as a Server-Sent Event
    (``data: <json>``). ``usage.total_tokens`` values found in chunks are
    summed and logged, and the stream ends with ``data: [DONE]``. If the
    upstream fails mid-stream, an ``{"error": {...}}`` event is sent instead.

    Raises:
        HTTPException: 400 for an empty prompt list; 500 if the upstream call
            fails before a response starts.
    """
    api_key = user_info["key"]
    client = get_client(api_key)

    prompt = body.prompt
    if isinstance(prompt, list):
        if not prompt:
            # Previously an IndexError that surfaced as a 500; this is a client error.
            raise HTTPException(status_code=400, detail="prompt must not be an empty list")
        # Batched prompts are not supported; any extra prompts are ignored.
        prompt = prompt[0]

    kwargs = {
        "model": body.model,
        "temperature": body.temperature,
        "max_tokens": body.max_tokens,
        "stream": body.stream,
        "top_p": body.top_p,
        "presence_penalty": body.presence_penalty,
        "frequency_penalty": body.frequency_penalty,
    }

    async def generate_stream(response_stream):
        # Relay upstream chunks as SSE events and log summed usage at the end.
        total_tokens = 0
        end_of_stream = object()
        try:
            chunks = iter(response_stream)
            while True:
                # Advancing the upstream iterator may block on network I/O, so
                # do it in a worker thread rather than on the event loop.
                chunk = await asyncio.to_thread(next, chunks, end_of_stream)
                if chunk is end_of_stream:
                    break
                if 'usage' in chunk and chunk['usage']:
                    total_tokens += chunk['usage'].get('total_tokens', 0)
                yield f"data: {json.dumps(chunk)}\n\n"
        except Exception as e:
            # Headers (status 200) are already sent; report the failure in-band.
            error = {
                "error": {
                    "message": f"Error generating completion: {str(e)}",
                    "type": "server_error",
                }
            }
            yield f"data: {json.dumps(error)}\n\n"
            return

        if db is not None:
            db.log_api_usage(api_key, "/v1/completions", total_tokens, body.model)
        yield "data: [DONE]\n\n"

    try:
        if body.stream:
            # Open the upstream stream before the response starts, so setup
            # failures still become an HTTP 500 below.
            response_stream = await asyncio.to_thread(
                client.completions.create,
                prompt=prompt,
                **kwargs
            )
            return StreamingResponse(
                generate_stream(response_stream),
                media_type="text/event-stream"
            )

        response = await asyncio.to_thread(
            client.completions.create,
            prompt=prompt,
            **kwargs
        )
        # Guard against a present-but-null "usage", as the streaming path does.
        if db is not None and 'usage' in response and response['usage']:
            total_tokens = response['usage'].get('total_tokens', 0)
            db.log_api_usage(api_key, "/v1/completions", total_tokens, body.model)
        return response

    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error generating completion: {str(e)}"
        ) from e
|
|
|
|
@app.get("/health")
async def health_check():
    """Report API, database and environment status.

    The database probe is a single ``find_one({})`` on the API-key collection.
    NOTE(review): this endpoint is unauthenticated and echoes database
    exception text to the caller.
    """
    report = {"api": "ok"}

    if not db:
        report["database"] = "not configured"
    else:
        try:
            db.api_keys_collection.find_one({})
        except Exception as e:
            report["database"] = f"error: {str(e)}"
        else:
            report["database"] = "ok"

    if not hf_helper.is_in_space:
        report["environment"] = "Local"
    else:
        report["environment"] = "Hugging Face Space"
        report["space_name"] = hf_helper.space_name

    return report
|
|
|
|
|
|
@app.post("/v1/api_keys", response_model=APIKeyResponse)
async def create_api_key(body: APIKeyCreate):
    """Generate a key for ``body.user_id`` and return it with its metadata.

    NOTE(review): no authentication is applied - any caller can mint a key for
    any user_id.

    Raises:
        HTTPException: 500 when the database is not configured, or wrapping
            any failure while generating or reading back the key.
    """
    if not db:
        raise HTTPException(status_code=500, detail="Database not configured")

    try:
        new_key = db.generate_api_key(body.user_id, body.name)
        record = db.validate_api_key(new_key)
        # Read "name" before "created_at" so a failure reports the same field first.
        label = record["name"]
        created = record["created_at"].isoformat()
        return {"key": new_key, "name": label, "created_at": created}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error creating API key: {str(e)}")
|
|
|
|
@app.get("/v1/api_keys")
async def list_api_keys(user_id: str):
    """List the API-key records of *user_id*, with timestamps as ISO strings.

    ``created_at`` and ``last_used`` are converted only when they hold a
    date/datetime. Missing, null or already-string values are left as-is;
    previously a null or string ``created_at`` raised AttributeError.

    NOTE(review): no authentication is applied, and the records presumably
    include the key values themselves - confirm what get_user_api_keys returns.

    Raises:
        HTTPException: 500 when the database is not configured.
    """
    if db is None:
        raise HTTPException(status_code=500, detail="Database not configured")

    keys = db.get_user_api_keys(user_id)
    for key in keys:
        for field in ("created_at", "last_used"):
            value = key.get(field)
            if hasattr(value, "isoformat"):
                key[field] = value.isoformat()

    return {"keys": keys}
|
|
|
|
@app.post("/v1/api_keys/revoke")
async def revoke_api_key(api_key: str):
    """Revoke *api_key* (read from the query string).

    Raises:
        HTTPException: 500 when the database is not configured; 404 when
            ``db.revoke_api_key`` reports failure.
    """
    if not db:
        raise HTTPException(status_code=500, detail="Database not configured")

    if db.revoke_api_key(api_key):
        return {"message": "API key revoked successfully"}
    raise HTTPException(status_code=404, detail="API key not found")
|
|
|
|
|
|
@app.on_event("shutdown")
async def cleanup_clients():
    """Best-effort shutdown of each cached client's IP rotator, if it has one.

    A failure for one client is reported and does not stop the others. Unlike
    the previous bare ``except:``, SystemExit/KeyboardInterrupt are no longer
    swallowed, and failures are no longer silent.
    """
    for client in clients.values():
        try:
            rotator = getattr(client, "ip_rotator", None)
            if rotator:
                rotator.shutdown()
        except Exception as exc:
            # Don't print the API key (the dict key); it is a secret.
            print(f"Warning: failed to shut down IP rotator: {exc}")
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run this module's ``app`` under uvicorn.
    host = hf_helper.get_hostname()
    # Allow the platform to choose the port; default stays 8000 as before.
    port = int(os.environ.get("PORT", "8000"))

    # uvicorn needs an import string (required for reload). Derive it from
    # this file's name instead of hard-coding "pyscout_api", which broke
    # whenever the file was named differently (e.g. app.py on a Space).
    module_name = os.path.splitext(os.path.basename(__file__))[0]

    print(f"Starting PyScoutAI API on http://{host}:{port}")
    print(f"Environment: {'Hugging Face Space' if hf_helper.is_in_space else 'Local'}")

    uvicorn.run(f"{module_name}:app", host=host, port=port, reload=not hf_helper.is_in_space)
|
|
|