Spaces:
Running
Running
SUBHRAJIT MOHANTY
commited on
Commit
Β·
c489cf0
1
Parent(s):
c3ff739
security Updated
Browse files
app.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
-
from fastapi import FastAPI, HTTPException, Request, UploadFile, File
|
2 |
from fastapi.responses import StreamingResponse
|
|
|
3 |
from pydantic import BaseModel, Field
|
4 |
from typing import List, Optional, Dict, Any, AsyncGenerator
|
5 |
import asyncio
|
@@ -11,6 +12,9 @@ from contextlib import asynccontextmanager
|
|
11 |
import tempfile
|
12 |
import shutil
|
13 |
import random
|
|
|
|
|
|
|
14 |
|
15 |
# Third-party imports
|
16 |
from openai import OpenAI, AsyncOpenAI
|
@@ -75,11 +79,170 @@ class Config:
|
|
75 |
TOP_K = int(os.getenv("TOP_K", "10"))
|
76 |
SIMILARITY_THRESHOLD = float(os.getenv("SIMILARITY_THRESHOLD", "0.1"))
|
77 |
DEVICE = os.getenv("DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
|
79 |
# Models
|
80 |
OPENROUTER_MODELS = ["deepseek/deepseek-chat-v3-0324:free", "deepseek/deepseek-r1-0528:free", "qwen/qwen3-235b-a22b:free", "google/gemini-2.0-flash-exp:free"]
|
81 |
GROQ_MODELS = ["llama-3.3-70b-versatile"]
|
82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
class DynamicOpenAIService:
|
84 |
"""Service for dynamic OpenAI provider selection"""
|
85 |
|
@@ -656,6 +819,19 @@ async def lifespan(app: FastAPI):
|
|
656 |
|
657 |
print("π All services initialized successfully!")
|
658 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
659 |
yield
|
660 |
|
661 |
# Shutdown
|
@@ -681,10 +857,15 @@ app = FastAPI(
|
|
681 |
|
682 |
@app.get("/")
|
683 |
async def root():
|
684 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
685 |
|
686 |
@app.get("/health")
|
687 |
-
async def health_check():
|
688 |
"""Health check endpoint"""
|
689 |
try:
|
690 |
# Test Qdrant connection
|
@@ -755,7 +936,7 @@ async def health_check():
|
|
755 |
}
|
756 |
|
757 |
@app.post("/v1/chat/completions")
|
758 |
-
async def chat_completions(request: ChatCompletionRequest):
|
759 |
"""OpenAI-compatible chat completions endpoint with enhanced RAG and dynamic provider selection"""
|
760 |
|
761 |
if not app_state.openai_service:
|
@@ -912,7 +1093,11 @@ async def stream_chat_completion(messages: List[Dict], request: ChatCompletionRe
|
|
912 |
|
913 |
# Document management endpoints
|
914 |
@app.post("/v1/documents/upload")
|
915 |
-
async def upload_document(
|
|
|
|
|
|
|
|
|
916 |
"""Upload a PDF document"""
|
917 |
try:
|
918 |
if not app_state.document_manager:
|
@@ -966,7 +1151,7 @@ async def upload_document(file: UploadFile = File(...), metadata: str = None):
|
|
966 |
raise HTTPException(status_code=500, detail=f"Error uploading document: {str(e)}")
|
967 |
|
968 |
@app.post("/v1/documents/search")
|
969 |
-
async def search_documents(request: DocumentSearchRequest):
|
970 |
"""Search for documents"""
|
971 |
try:
|
972 |
if not app_state.document_manager:
|
@@ -989,7 +1174,7 @@ async def search_documents(request: DocumentSearchRequest):
|
|
989 |
raise HTTPException(status_code=500, detail=f"Error searching documents: {str(e)}")
|
990 |
|
991 |
@app.get("/v1/documents/list")
|
992 |
-
async def list_documents():
|
993 |
"""List all documents"""
|
994 |
try:
|
995 |
if not app_state.document_manager:
|
@@ -1007,7 +1192,7 @@ async def list_documents():
|
|
1007 |
raise HTTPException(status_code=500, detail=f"Error listing documents: {str(e)}")
|
1008 |
|
1009 |
@app.delete("/v1/documents/{document_id}")
|
1010 |
-
async def delete_document(document_id: str):
|
1011 |
"""Delete a document"""
|
1012 |
try:
|
1013 |
if not app_state.document_manager:
|
@@ -1028,7 +1213,7 @@ async def delete_document(document_id: str):
|
|
1028 |
|
1029 |
# Legacy compatibility endpoints
|
1030 |
@app.post("/v1/embeddings/add")
|
1031 |
-
async def add_document_legacy(content: str, metadata: Optional[Dict] = None):
|
1032 |
"""Legacy endpoint for adding documents (text content)"""
|
1033 |
try:
|
1034 |
if not app_state.embedding_service or not app_state.qdrant_client:
|
@@ -1059,7 +1244,7 @@ async def add_document_legacy(content: str, metadata: Optional[Dict] = None):
|
|
1059 |
raise HTTPException(status_code=500, detail=f"Error adding document: {str(e)}")
|
1060 |
|
1061 |
@app.get("/v1/collections/info")
|
1062 |
-
async def get_collection_info():
|
1063 |
"""Get information about the collection"""
|
1064 |
try:
|
1065 |
if app_state.qdrant_client is None:
|
@@ -1079,7 +1264,7 @@ async def get_collection_info():
|
|
1079 |
|
1080 |
# New endpoint to get available providers and models
|
1081 |
@app.get("/v1/providers")
|
1082 |
-
async def get_providers():
|
1083 |
"""Get available providers and their models"""
|
1084 |
try:
|
1085 |
if not app_state.openai_service:
|
@@ -1117,6 +1302,100 @@ async def get_providers():
|
|
1117 |
except Exception as e:
|
1118 |
raise HTTPException(status_code=500, detail=f"Error getting providers: {str(e)}")
|
1119 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1120 |
if __name__ == "__main__":
|
1121 |
import uvicorn
|
1122 |
uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
|
1 |
+
from fastapi import FastAPI, HTTPException, Request, UploadFile, File, Depends, status
|
2 |
from fastapi.responses import StreamingResponse
|
3 |
+
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
4 |
from pydantic import BaseModel, Field
|
5 |
from typing import List, Optional, Dict, Any, AsyncGenerator
|
6 |
import asyncio
|
|
|
12 |
import tempfile
|
13 |
import shutil
|
14 |
import random
|
15 |
+
import hashlib
|
16 |
+
import secrets
|
17 |
+
from functools import wraps
|
18 |
|
19 |
# Third-party imports
|
20 |
from openai import OpenAI, AsyncOpenAI
|
|
|
79 |
TOP_K = int(os.getenv("TOP_K", "10"))
|
80 |
SIMILARITY_THRESHOLD = float(os.getenv("SIMILARITY_THRESHOLD", "0.1"))
|
81 |
DEVICE = os.getenv("DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
|
82 |
+
|
83 |
+
# Security Configuration
|
84 |
+
API_KEYS = os.getenv("API_KEYS", "").split(",") if os.getenv("API_KEYS") else []
|
85 |
+
MASTER_KEY = os.getenv("MASTER_KEY", "")
|
86 |
+
ENABLE_SECURITY = os.getenv("ENABLE_SECURITY", "true").lower() == "true"
|
87 |
+
RATE_LIMIT_PER_MINUTE = int(os.getenv("RATE_LIMIT_PER_MINUTE", "60"))
|
88 |
+
|
89 |
+
@classmethod
|
90 |
+
def generate_api_key(cls) -> str:
|
91 |
+
"""Generate a new API key"""
|
92 |
+
return f"sk-{secrets.token_urlsafe(32)}"
|
93 |
+
|
94 |
+
@classmethod
|
95 |
+
def validate_api_key(cls, api_key: str) -> bool:
|
96 |
+
"""Validate API key"""
|
97 |
+
if not cls.ENABLE_SECURITY:
|
98 |
+
return True
|
99 |
+
|
100 |
+
if not api_key:
|
101 |
+
return False
|
102 |
+
|
103 |
+
# Check master key
|
104 |
+
if cls.MASTER_KEY and api_key == cls.MASTER_KEY:
|
105 |
+
return True
|
106 |
+
|
107 |
+
# Check configured API keys
|
108 |
+
if cls.API_KEYS and api_key in cls.API_KEYS:
|
109 |
+
return True
|
110 |
+
|
111 |
+
return False
|
112 |
|
113 |
# Models
|
114 |
OPENROUTER_MODELS = ["deepseek/deepseek-chat-v3-0324:free", "deepseek/deepseek-r1-0528:free", "qwen/qwen3-235b-a22b:free", "google/gemini-2.0-flash-exp:free"]
|
115 |
GROQ_MODELS = ["llama-3.3-70b-versatile"]
|
116 |
|
117 |
+
# Security Models
|
118 |
+
class APIKeyRequest(BaseModel):
|
119 |
+
description: Optional[str] = Field(None, description="Description for the API key")
|
120 |
+
|
121 |
+
class APIKeyResponse(BaseModel):
|
122 |
+
api_key: str
|
123 |
+
description: Optional[str] = None
|
124 |
+
created_at: str
|
125 |
+
status: str = "active"
|
126 |
+
|
127 |
+
class SecurityInfo(BaseModel):
|
128 |
+
security_enabled: bool
|
129 |
+
rate_limit_per_minute: int
|
130 |
+
has_master_key: bool
|
131 |
+
configured_keys_count: int
|
132 |
+
|
133 |
+
# Rate Limiting
|
134 |
+
class RateLimiter:
|
135 |
+
def __init__(self):
|
136 |
+
self.requests = {}
|
137 |
+
self.blocked_ips = set()
|
138 |
+
|
139 |
+
def is_allowed(self, identifier: str, limit_per_minute: int = Config.RATE_LIMIT_PER_MINUTE) -> bool:
|
140 |
+
"""Check if request is allowed based on rate limit"""
|
141 |
+
if not Config.ENABLE_SECURITY:
|
142 |
+
return True
|
143 |
+
|
144 |
+
if identifier in self.blocked_ips:
|
145 |
+
return False
|
146 |
+
|
147 |
+
now = datetime.now()
|
148 |
+
minute_key = now.strftime("%Y-%m-%d %H:%M")
|
149 |
+
|
150 |
+
if identifier not in self.requests:
|
151 |
+
self.requests[identifier] = {}
|
152 |
+
|
153 |
+
if minute_key not in self.requests[identifier]:
|
154 |
+
self.requests[identifier][minute_key] = 0
|
155 |
+
|
156 |
+
# Clean old entries (keep only last 2 minutes)
|
157 |
+
keys_to_remove = []
|
158 |
+
for key in self.requests[identifier]:
|
159 |
+
try:
|
160 |
+
key_time = datetime.strptime(key, "%Y-%m-%d %H:%M")
|
161 |
+
if (now - key_time).total_seconds() > 120: # 2 minutes
|
162 |
+
keys_to_remove.append(key)
|
163 |
+
except ValueError:
|
164 |
+
keys_to_remove.append(key)
|
165 |
+
|
166 |
+
for key in keys_to_remove:
|
167 |
+
del self.requests[identifier][key]
|
168 |
+
|
169 |
+
# Check current minute limit
|
170 |
+
current_requests = self.requests[identifier].get(minute_key, 0)
|
171 |
+
if current_requests >= limit_per_minute:
|
172 |
+
return False
|
173 |
+
|
174 |
+
self.requests[identifier][minute_key] = current_requests + 1
|
175 |
+
return True
|
176 |
+
|
177 |
+
def block_ip(self, ip: str):
|
178 |
+
"""Block an IP address"""
|
179 |
+
self.blocked_ips.add(ip)
|
180 |
+
|
181 |
+
def unblock_ip(self, ip: str):
|
182 |
+
"""Unblock an IP address"""
|
183 |
+
self.blocked_ips.discard(ip)
|
184 |
+
|
185 |
+
# Security Dependencies
|
186 |
+
security = HTTPBearer(auto_error=False)
|
187 |
+
rate_limiter = RateLimiter()
|
188 |
+
|
189 |
+
async def verify_api_key(
|
190 |
+
request: Request,
|
191 |
+
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
|
192 |
+
) -> str:
|
193 |
+
"""Verify API key from Authorization header"""
|
194 |
+
if not Config.ENABLE_SECURITY:
|
195 |
+
return "security_disabled"
|
196 |
+
|
197 |
+
# Get client IP
|
198 |
+
client_ip = request.client.host
|
199 |
+
|
200 |
+
# Check rate limit
|
201 |
+
if not rate_limiter.is_allowed(client_ip):
|
202 |
+
raise HTTPException(
|
203 |
+
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
204 |
+
detail="Rate limit exceeded"
|
205 |
+
)
|
206 |
+
|
207 |
+
# Check API key
|
208 |
+
if not credentials:
|
209 |
+
raise HTTPException(
|
210 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
211 |
+
detail="API key required. Please provide a valid API key in the Authorization header as 'Bearer <your-api-key>'"
|
212 |
+
)
|
213 |
+
|
214 |
+
api_key = credentials.credentials
|
215 |
+
if not Config.validate_api_key(api_key):
|
216 |
+
raise HTTPException(
|
217 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
218 |
+
detail="Invalid API key"
|
219 |
+
)
|
220 |
+
|
221 |
+
return api_key
|
222 |
+
|
223 |
+
async def verify_master_key(
|
224 |
+
request: Request,
|
225 |
+
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
|
226 |
+
) -> str:
|
227 |
+
"""Verify master key for admin operations"""
|
228 |
+
if not Config.ENABLE_SECURITY:
|
229 |
+
return "security_disabled"
|
230 |
+
|
231 |
+
if not credentials:
|
232 |
+
raise HTTPException(
|
233 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
234 |
+
detail="Master key required for admin operations"
|
235 |
+
)
|
236 |
+
|
237 |
+
api_key = credentials.credentials
|
238 |
+
if not Config.MASTER_KEY or api_key != Config.MASTER_KEY:
|
239 |
+
raise HTTPException(
|
240 |
+
status_code=status.HTTP_403_FORBIDDEN,
|
241 |
+
detail="Invalid master key"
|
242 |
+
)
|
243 |
+
|
244 |
+
return api_key
|
245 |
+
|
246 |
class DynamicOpenAIService:
|
247 |
"""Service for dynamic OpenAI provider selection"""
|
248 |
|
|
|
819 |
|
820 |
print("π All services initialized successfully!")
|
821 |
|
822 |
+
# Print security information
|
823 |
+
if Config.ENABLE_SECURITY:
|
824 |
+
print("\nπ Security Configuration:")
|
825 |
+
print(f" Security: ENABLED")
|
826 |
+
print(f" Rate Limit: {Config.RATE_LIMIT_PER_MINUTE} requests/minute")
|
827 |
+
print(f" Master Key: {'β Configured' if Config.MASTER_KEY else 'β Not configured'}")
|
828 |
+
print(f" API Keys: {len([k for k in Config.API_KEYS if k.strip()])} configured")
|
829 |
+
if not Config.MASTER_KEY and not Config.API_KEYS:
|
830 |
+
print(" β οΈ WARNING: No API keys configured! Set MASTER_KEY or API_KEYS environment variable.")
|
831 |
+
else:
|
832 |
+
print("\nπ Security: DISABLED")
|
833 |
+
print(" All endpoints are publicly accessible")
|
834 |
+
|
835 |
yield
|
836 |
|
837 |
# Shutdown
|
|
|
857 |
|
858 |
@app.get("/")
|
859 |
async def root():
|
860 |
+
return {
|
861 |
+
"message": "Enhanced RAG API with Dynamic Provider Selection",
|
862 |
+
"status": "running",
|
863 |
+
"security_enabled": Config.ENABLE_SECURITY,
|
864 |
+
"version": "1.0.0"
|
865 |
+
}
|
866 |
|
867 |
@app.get("/health")
|
868 |
+
async def health_check(api_key: str = Depends(verify_api_key)):
|
869 |
"""Health check endpoint"""
|
870 |
try:
|
871 |
# Test Qdrant connection
|
|
|
936 |
}
|
937 |
|
938 |
@app.post("/v1/chat/completions")
|
939 |
+
async def chat_completions(request: ChatCompletionRequest, api_key: str = Depends(verify_api_key)):
|
940 |
"""OpenAI-compatible chat completions endpoint with enhanced RAG and dynamic provider selection"""
|
941 |
|
942 |
if not app_state.openai_service:
|
|
|
1093 |
|
1094 |
# Document management endpoints
|
1095 |
@app.post("/v1/documents/upload")
|
1096 |
+
async def upload_document(
|
1097 |
+
file: UploadFile = File(...),
|
1098 |
+
metadata: str = None,
|
1099 |
+
api_key: str = Depends(verify_api_key)
|
1100 |
+
):
|
1101 |
"""Upload a PDF document"""
|
1102 |
try:
|
1103 |
if not app_state.document_manager:
|
|
|
1151 |
raise HTTPException(status_code=500, detail=f"Error uploading document: {str(e)}")
|
1152 |
|
1153 |
@app.post("/v1/documents/search")
|
1154 |
+
async def search_documents(request: DocumentSearchRequest, api_key: str = Depends(verify_api_key)):
|
1155 |
"""Search for documents"""
|
1156 |
try:
|
1157 |
if not app_state.document_manager:
|
|
|
1174 |
raise HTTPException(status_code=500, detail=f"Error searching documents: {str(e)}")
|
1175 |
|
1176 |
@app.get("/v1/documents/list")
|
1177 |
+
async def list_documents(api_key: str = Depends(verify_api_key)):
|
1178 |
"""List all documents"""
|
1179 |
try:
|
1180 |
if not app_state.document_manager:
|
|
|
1192 |
raise HTTPException(status_code=500, detail=f"Error listing documents: {str(e)}")
|
1193 |
|
1194 |
@app.delete("/v1/documents/{document_id}")
|
1195 |
+
async def delete_document(document_id: str, api_key: str = Depends(verify_api_key)):
|
1196 |
"""Delete a document"""
|
1197 |
try:
|
1198 |
if not app_state.document_manager:
|
|
|
1213 |
|
1214 |
# Legacy compatibility endpoints
|
1215 |
@app.post("/v1/embeddings/add")
|
1216 |
+
async def add_document_legacy(content: str, metadata: Optional[Dict] = None, api_key: str = Depends(verify_api_key)):
|
1217 |
"""Legacy endpoint for adding documents (text content)"""
|
1218 |
try:
|
1219 |
if not app_state.embedding_service or not app_state.qdrant_client:
|
|
|
1244 |
raise HTTPException(status_code=500, detail=f"Error adding document: {str(e)}")
|
1245 |
|
1246 |
@app.get("/v1/collections/info")
|
1247 |
+
async def get_collection_info(api_key: str = Depends(verify_api_key)):
|
1248 |
"""Get information about the collection"""
|
1249 |
try:
|
1250 |
if app_state.qdrant_client is None:
|
|
|
1264 |
|
1265 |
# New endpoint to get available providers and models
|
1266 |
@app.get("/v1/providers")
|
1267 |
+
async def get_providers(api_key: str = Depends(verify_api_key)):
|
1268 |
"""Get available providers and their models"""
|
1269 |
try:
|
1270 |
if not app_state.openai_service:
|
|
|
1302 |
except Exception as e:
|
1303 |
raise HTTPException(status_code=500, detail=f"Error getting providers: {str(e)}")
|
1304 |
|
1305 |
+
# Security Management Endpoints
|
1306 |
+
@app.get("/v1/security/info")
|
1307 |
+
async def get_security_info() -> SecurityInfo:
|
1308 |
+
"""Get security configuration information (public endpoint)"""
|
1309 |
+
return SecurityInfo(
|
1310 |
+
security_enabled=Config.ENABLE_SECURITY,
|
1311 |
+
rate_limit_per_minute=Config.RATE_LIMIT_PER_MINUTE,
|
1312 |
+
has_master_key=bool(Config.MASTER_KEY),
|
1313 |
+
configured_keys_count=len([k for k in Config.API_KEYS if k.strip()])
|
1314 |
+
)
|
1315 |
+
|
1316 |
+
@app.post("/v1/security/generate-key")
|
1317 |
+
async def generate_api_key(
|
1318 |
+
request: APIKeyRequest,
|
1319 |
+
master_key: str = Depends(verify_master_key)
|
1320 |
+
) -> APIKeyResponse:
|
1321 |
+
"""Generate a new API key (requires master key)"""
|
1322 |
+
try:
|
1323 |
+
new_key = Config.generate_api_key()
|
1324 |
+
|
1325 |
+
return APIKeyResponse(
|
1326 |
+
api_key=new_key,
|
1327 |
+
description=request.description,
|
1328 |
+
created_at=datetime.now().isoformat(),
|
1329 |
+
status="active"
|
1330 |
+
)
|
1331 |
+
except Exception as e:
|
1332 |
+
raise HTTPException(status_code=500, detail=f"Error generating API key: {str(e)}")
|
1333 |
+
|
1334 |
+
@app.post("/v1/security/validate-key")
|
1335 |
+
async def validate_api_key_endpoint(
|
1336 |
+
api_key: str = Depends(verify_api_key)
|
1337 |
+
) -> Dict[str, Any]:
|
1338 |
+
"""Validate an API key"""
|
1339 |
+
return {
|
1340 |
+
"valid": True,
|
1341 |
+
"key_type": "master" if api_key == Config.MASTER_KEY else "standard",
|
1342 |
+
"validated_at": datetime.now().isoformat()
|
1343 |
+
}
|
1344 |
+
|
1345 |
+
@app.get("/v1/security/rate-limit-status")
|
1346 |
+
async def get_rate_limit_status(
|
1347 |
+
request: Request,
|
1348 |
+
api_key: str = Depends(verify_api_key)
|
1349 |
+
) -> Dict[str, Any]:
|
1350 |
+
"""Get current rate limit status"""
|
1351 |
+
client_ip = request.client.host
|
1352 |
+
|
1353 |
+
# Get current minute requests
|
1354 |
+
now = datetime.now()
|
1355 |
+
minute_key = now.strftime("%Y-%m-%d %H:%M")
|
1356 |
+
|
1357 |
+
current_requests = 0
|
1358 |
+
if client_ip in rate_limiter.requests:
|
1359 |
+
current_requests = rate_limiter.requests[client_ip].get(minute_key, 0)
|
1360 |
+
|
1361 |
+
return {
|
1362 |
+
"client_ip": client_ip,
|
1363 |
+
"current_requests": current_requests,
|
1364 |
+
"limit_per_minute": Config.RATE_LIMIT_PER_MINUTE,
|
1365 |
+
"remaining_requests": max(0, Config.RATE_LIMIT_PER_MINUTE - current_requests),
|
1366 |
+
"reset_at": f"{minute_key}:00",
|
1367 |
+
"is_blocked": client_ip in rate_limiter.blocked_ips
|
1368 |
+
}
|
1369 |
+
|
1370 |
+
# Admin endpoints for IP management
|
1371 |
+
@app.post("/v1/admin/block-ip/{ip}")
|
1372 |
+
async def block_ip(
|
1373 |
+
ip: str,
|
1374 |
+
master_key: str = Depends(verify_master_key)
|
1375 |
+
) -> Dict[str, str]:
|
1376 |
+
"""Block an IP address (requires master key)"""
|
1377 |
+
rate_limiter.block_ip(ip)
|
1378 |
+
return {"message": f"IP {ip} has been blocked", "blocked_at": datetime.now().isoformat()}
|
1379 |
+
|
1380 |
+
@app.post("/v1/admin/unblock-ip/{ip}")
|
1381 |
+
async def unblock_ip(
|
1382 |
+
ip: str,
|
1383 |
+
master_key: str = Depends(verify_master_key)
|
1384 |
+
) -> Dict[str, str]:
|
1385 |
+
"""Unblock an IP address (requires master key)"""
|
1386 |
+
rate_limiter.unblock_ip(ip)
|
1387 |
+
return {"message": f"IP {ip} has been unblocked", "unblocked_at": datetime.now().isoformat()}
|
1388 |
+
|
1389 |
+
@app.get("/v1/admin/blocked-ips")
|
1390 |
+
async def get_blocked_ips(
|
1391 |
+
master_key: str = Depends(verify_master_key)
|
1392 |
+
) -> Dict[str, Any]:
|
1393 |
+
"""Get list of blocked IPs (requires master key)"""
|
1394 |
+
return {
|
1395 |
+
"blocked_ips": list(rate_limiter.blocked_ips),
|
1396 |
+
"count": len(rate_limiter.blocked_ips)
|
1397 |
+
}
|
1398 |
+
|
1399 |
if __name__ == "__main__":
|
1400 |
import uvicorn
|
1401 |
uvicorn.run(app, host="0.0.0.0", port=8000)
|