SUBHRAJIT MOHANTY committed on
Commit
c489cf0
·
1 Parent(s): c3ff739

security Updated

Browse files
Files changed (1) hide show
  1. app.py +290 -11
app.py CHANGED
@@ -1,5 +1,6 @@
1
- from fastapi import FastAPI, HTTPException, Request, UploadFile, File
2
  from fastapi.responses import StreamingResponse
 
3
  from pydantic import BaseModel, Field
4
  from typing import List, Optional, Dict, Any, AsyncGenerator
5
  import asyncio
@@ -11,6 +12,9 @@ from contextlib import asynccontextmanager
11
  import tempfile
12
  import shutil
13
  import random
 
 
 
14
 
15
  # Third-party imports
16
  from openai import OpenAI, AsyncOpenAI
@@ -75,11 +79,170 @@ class Config:
75
  TOP_K = int(os.getenv("TOP_K", "10"))
76
  SIMILARITY_THRESHOLD = float(os.getenv("SIMILARITY_THRESHOLD", "0.1"))
77
  DEVICE = os.getenv("DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
  # Models
80
  OPENROUTER_MODELS = ["deepseek/deepseek-chat-v3-0324:free", "deepseek/deepseek-r1-0528:free", "qwen/qwen3-235b-a22b:free", "google/gemini-2.0-flash-exp:free"]
81
  GROQ_MODELS = ["llama-3.3-70b-versatile"]
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  class DynamicOpenAIService:
84
  """Service for dynamic OpenAI provider selection"""
85
 
@@ -656,6 +819,19 @@ async def lifespan(app: FastAPI):
656
 
657
  print("🚀 All services initialized successfully!")
658
 
 
 
 
 
 
 
 
 
 
 
 
 
 
659
  yield
660
 
661
  # Shutdown
@@ -681,10 +857,15 @@ app = FastAPI(
681
 
682
  @app.get("/")
683
  async def root():
684
- return {"message": "Enhanced RAG API with Dynamic Provider Selection", "status": "running"}
 
 
 
 
 
685
 
686
  @app.get("/health")
687
- async def health_check():
688
  """Health check endpoint"""
689
  try:
690
  # Test Qdrant connection
@@ -755,7 +936,7 @@ async def health_check():
755
  }
756
 
757
  @app.post("/v1/chat/completions")
758
- async def chat_completions(request: ChatCompletionRequest):
759
  """OpenAI-compatible chat completions endpoint with enhanced RAG and dynamic provider selection"""
760
 
761
  if not app_state.openai_service:
@@ -912,7 +1093,11 @@ async def stream_chat_completion(messages: List[Dict], request: ChatCompletionRe
912
 
913
  # Document management endpoints
914
  @app.post("/v1/documents/upload")
915
- async def upload_document(file: UploadFile = File(...), metadata: str = None):
 
 
 
 
916
  """Upload a PDF document"""
917
  try:
918
  if not app_state.document_manager:
@@ -966,7 +1151,7 @@ async def upload_document(file: UploadFile = File(...), metadata: str = None):
966
  raise HTTPException(status_code=500, detail=f"Error uploading document: {str(e)}")
967
 
968
  @app.post("/v1/documents/search")
969
- async def search_documents(request: DocumentSearchRequest):
970
  """Search for documents"""
971
  try:
972
  if not app_state.document_manager:
@@ -989,7 +1174,7 @@ async def search_documents(request: DocumentSearchRequest):
989
  raise HTTPException(status_code=500, detail=f"Error searching documents: {str(e)}")
990
 
991
  @app.get("/v1/documents/list")
992
- async def list_documents():
993
  """List all documents"""
994
  try:
995
  if not app_state.document_manager:
@@ -1007,7 +1192,7 @@ async def list_documents():
1007
  raise HTTPException(status_code=500, detail=f"Error listing documents: {str(e)}")
1008
 
1009
  @app.delete("/v1/documents/{document_id}")
1010
- async def delete_document(document_id: str):
1011
  """Delete a document"""
1012
  try:
1013
  if not app_state.document_manager:
@@ -1028,7 +1213,7 @@ async def delete_document(document_id: str):
1028
 
1029
  # Legacy compatibility endpoints
1030
  @app.post("/v1/embeddings/add")
1031
- async def add_document_legacy(content: str, metadata: Optional[Dict] = None):
1032
  """Legacy endpoint for adding documents (text content)"""
1033
  try:
1034
  if not app_state.embedding_service or not app_state.qdrant_client:
@@ -1059,7 +1244,7 @@ async def add_document_legacy(content: str, metadata: Optional[Dict] = None):
1059
  raise HTTPException(status_code=500, detail=f"Error adding document: {str(e)}")
1060
 
1061
  @app.get("/v1/collections/info")
1062
- async def get_collection_info():
1063
  """Get information about the collection"""
1064
  try:
1065
  if app_state.qdrant_client is None:
@@ -1079,7 +1264,7 @@ async def get_collection_info():
1079
 
1080
  # New endpoint to get available providers and models
1081
  @app.get("/v1/providers")
1082
- async def get_providers():
1083
  """Get available providers and their models"""
1084
  try:
1085
  if not app_state.openai_service:
@@ -1117,6 +1302,100 @@ async def get_providers():
1117
  except Exception as e:
1118
  raise HTTPException(status_code=500, detail=f"Error getting providers: {str(e)}")
1119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1120
  if __name__ == "__main__":
1121
  import uvicorn
1122
  uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
+ from fastapi import FastAPI, HTTPException, Request, UploadFile, File, Depends, status
2
  from fastapi.responses import StreamingResponse
3
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
4
  from pydantic import BaseModel, Field
5
  from typing import List, Optional, Dict, Any, AsyncGenerator
6
  import asyncio
 
12
  import tempfile
13
  import shutil
14
  import random
15
+ import hashlib
16
+ import secrets
17
+ from functools import wraps
18
 
19
  # Third-party imports
20
  from openai import OpenAI, AsyncOpenAI
 
79
  TOP_K = int(os.getenv("TOP_K", "10"))
80
  SIMILARITY_THRESHOLD = float(os.getenv("SIMILARITY_THRESHOLD", "0.1"))
81
  DEVICE = os.getenv("DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
82
+
83
+ # Security Configuration
84
+ API_KEYS = os.getenv("API_KEYS", "").split(",") if os.getenv("API_KEYS") else []
85
+ MASTER_KEY = os.getenv("MASTER_KEY", "")
86
+ ENABLE_SECURITY = os.getenv("ENABLE_SECURITY", "true").lower() == "true"
87
+ RATE_LIMIT_PER_MINUTE = int(os.getenv("RATE_LIMIT_PER_MINUTE", "60"))
88
+
89
+ @classmethod
90
+ def generate_api_key(cls) -> str:
91
+ """Generate a new API key"""
92
+ return f"sk-{secrets.token_urlsafe(32)}"
93
+
94
+ @classmethod
95
+ def validate_api_key(cls, api_key: str) -> bool:
96
+ """Validate API key"""
97
+ if not cls.ENABLE_SECURITY:
98
+ return True
99
+
100
+ if not api_key:
101
+ return False
102
+
103
+ # Check master key
104
+ if cls.MASTER_KEY and api_key == cls.MASTER_KEY:
105
+ return True
106
+
107
+ # Check configured API keys
108
+ if cls.API_KEYS and api_key in cls.API_KEYS:
109
+ return True
110
+
111
+ return False
112
 
113
  # Models
114
  OPENROUTER_MODELS = ["deepseek/deepseek-chat-v3-0324:free", "deepseek/deepseek-r1-0528:free", "qwen/qwen3-235b-a22b:free", "google/gemini-2.0-flash-exp:free"]
115
  GROQ_MODELS = ["llama-3.3-70b-versatile"]
116
 
117
+ # Security Models
118
+ class APIKeyRequest(BaseModel):
119
+ description: Optional[str] = Field(None, description="Description for the API key")
120
+
121
+ class APIKeyResponse(BaseModel):
122
+ api_key: str
123
+ description: Optional[str] = None
124
+ created_at: str
125
+ status: str = "active"
126
+
127
+ class SecurityInfo(BaseModel):
128
+ security_enabled: bool
129
+ rate_limit_per_minute: int
130
+ has_master_key: bool
131
+ configured_keys_count: int
132
+
133
+ # Rate Limiting
134
+ class RateLimiter:
135
+ def __init__(self):
136
+ self.requests = {}
137
+ self.blocked_ips = set()
138
+
139
+ def is_allowed(self, identifier: str, limit_per_minute: int = Config.RATE_LIMIT_PER_MINUTE) -> bool:
140
+ """Check if request is allowed based on rate limit"""
141
+ if not Config.ENABLE_SECURITY:
142
+ return True
143
+
144
+ if identifier in self.blocked_ips:
145
+ return False
146
+
147
+ now = datetime.now()
148
+ minute_key = now.strftime("%Y-%m-%d %H:%M")
149
+
150
+ if identifier not in self.requests:
151
+ self.requests[identifier] = {}
152
+
153
+ if minute_key not in self.requests[identifier]:
154
+ self.requests[identifier][minute_key] = 0
155
+
156
+ # Clean old entries (keep only last 2 minutes)
157
+ keys_to_remove = []
158
+ for key in self.requests[identifier]:
159
+ try:
160
+ key_time = datetime.strptime(key, "%Y-%m-%d %H:%M")
161
+ if (now - key_time).total_seconds() > 120: # 2 minutes
162
+ keys_to_remove.append(key)
163
+ except ValueError:
164
+ keys_to_remove.append(key)
165
+
166
+ for key in keys_to_remove:
167
+ del self.requests[identifier][key]
168
+
169
+ # Check current minute limit
170
+ current_requests = self.requests[identifier].get(minute_key, 0)
171
+ if current_requests >= limit_per_minute:
172
+ return False
173
+
174
+ self.requests[identifier][minute_key] = current_requests + 1
175
+ return True
176
+
177
+ def block_ip(self, ip: str):
178
+ """Block an IP address"""
179
+ self.blocked_ips.add(ip)
180
+
181
+ def unblock_ip(self, ip: str):
182
+ """Unblock an IP address"""
183
+ self.blocked_ips.discard(ip)
184
+
185
+ # Security Dependencies
186
+ security = HTTPBearer(auto_error=False)
187
+ rate_limiter = RateLimiter()
188
+
189
+ async def verify_api_key(
190
+ request: Request,
191
+ credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
192
+ ) -> str:
193
+ """Verify API key from Authorization header"""
194
+ if not Config.ENABLE_SECURITY:
195
+ return "security_disabled"
196
+
197
+ # Get client IP
198
+ client_ip = request.client.host
199
+
200
+ # Check rate limit
201
+ if not rate_limiter.is_allowed(client_ip):
202
+ raise HTTPException(
203
+ status_code=status.HTTP_429_TOO_MANY_REQUESTS,
204
+ detail="Rate limit exceeded"
205
+ )
206
+
207
+ # Check API key
208
+ if not credentials:
209
+ raise HTTPException(
210
+ status_code=status.HTTP_401_UNAUTHORIZED,
211
+ detail="API key required. Please provide a valid API key in the Authorization header as 'Bearer <your-api-key>'"
212
+ )
213
+
214
+ api_key = credentials.credentials
215
+ if not Config.validate_api_key(api_key):
216
+ raise HTTPException(
217
+ status_code=status.HTTP_401_UNAUTHORIZED,
218
+ detail="Invalid API key"
219
+ )
220
+
221
+ return api_key
222
+
223
+ async def verify_master_key(
224
+ request: Request,
225
+ credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
226
+ ) -> str:
227
+ """Verify master key for admin operations"""
228
+ if not Config.ENABLE_SECURITY:
229
+ return "security_disabled"
230
+
231
+ if not credentials:
232
+ raise HTTPException(
233
+ status_code=status.HTTP_401_UNAUTHORIZED,
234
+ detail="Master key required for admin operations"
235
+ )
236
+
237
+ api_key = credentials.credentials
238
+ if not Config.MASTER_KEY or api_key != Config.MASTER_KEY:
239
+ raise HTTPException(
240
+ status_code=status.HTTP_403_FORBIDDEN,
241
+ detail="Invalid master key"
242
+ )
243
+
244
+ return api_key
245
+
246
  class DynamicOpenAIService:
247
  """Service for dynamic OpenAI provider selection"""
248
 
 
819
 
820
  print("🚀 All services initialized successfully!")
821
 
822
+ # Print security information
823
+ if Config.ENABLE_SECURITY:
824
+ print("\n🔒 Security Configuration:")
825
+ print(f" Security: ENABLED")
826
+ print(f" Rate Limit: {Config.RATE_LIMIT_PER_MINUTE} requests/minute")
827
+ print(f" Master Key: {'✓ Configured' if Config.MASTER_KEY else '✗ Not configured'}")
828
+ print(f" API Keys: {len([k for k in Config.API_KEYS if k.strip()])} configured")
829
+ if not Config.MASTER_KEY and not Config.API_KEYS:
830
+ print(" ⚠️ WARNING: No API keys configured! Set MASTER_KEY or API_KEYS environment variable.")
831
+ else:
832
+ print("\n🔓 Security: DISABLED")
833
+ print(" All endpoints are publicly accessible")
834
+
835
  yield
836
 
837
  # Shutdown
 
857
 
858
  @app.get("/")
859
  async def root():
860
+ return {
861
+ "message": "Enhanced RAG API with Dynamic Provider Selection",
862
+ "status": "running",
863
+ "security_enabled": Config.ENABLE_SECURITY,
864
+ "version": "1.0.0"
865
+ }
866
 
867
  @app.get("/health")
868
+ async def health_check(api_key: str = Depends(verify_api_key)):
869
  """Health check endpoint"""
870
  try:
871
  # Test Qdrant connection
 
936
  }
937
 
938
  @app.post("/v1/chat/completions")
939
+ async def chat_completions(request: ChatCompletionRequest, api_key: str = Depends(verify_api_key)):
940
  """OpenAI-compatible chat completions endpoint with enhanced RAG and dynamic provider selection"""
941
 
942
  if not app_state.openai_service:
 
1093
 
1094
  # Document management endpoints
1095
  @app.post("/v1/documents/upload")
1096
+ async def upload_document(
1097
+ file: UploadFile = File(...),
1098
+ metadata: str = None,
1099
+ api_key: str = Depends(verify_api_key)
1100
+ ):
1101
  """Upload a PDF document"""
1102
  try:
1103
  if not app_state.document_manager:
 
1151
  raise HTTPException(status_code=500, detail=f"Error uploading document: {str(e)}")
1152
 
1153
  @app.post("/v1/documents/search")
1154
+ async def search_documents(request: DocumentSearchRequest, api_key: str = Depends(verify_api_key)):
1155
  """Search for documents"""
1156
  try:
1157
  if not app_state.document_manager:
 
1174
  raise HTTPException(status_code=500, detail=f"Error searching documents: {str(e)}")
1175
 
1176
  @app.get("/v1/documents/list")
1177
+ async def list_documents(api_key: str = Depends(verify_api_key)):
1178
  """List all documents"""
1179
  try:
1180
  if not app_state.document_manager:
 
1192
  raise HTTPException(status_code=500, detail=f"Error listing documents: {str(e)}")
1193
 
1194
  @app.delete("/v1/documents/{document_id}")
1195
+ async def delete_document(document_id: str, api_key: str = Depends(verify_api_key)):
1196
  """Delete a document"""
1197
  try:
1198
  if not app_state.document_manager:
 
1213
 
1214
  # Legacy compatibility endpoints
1215
  @app.post("/v1/embeddings/add")
1216
+ async def add_document_legacy(content: str, metadata: Optional[Dict] = None, api_key: str = Depends(verify_api_key)):
1217
  """Legacy endpoint for adding documents (text content)"""
1218
  try:
1219
  if not app_state.embedding_service or not app_state.qdrant_client:
 
1244
  raise HTTPException(status_code=500, detail=f"Error adding document: {str(e)}")
1245
 
1246
  @app.get("/v1/collections/info")
1247
+ async def get_collection_info(api_key: str = Depends(verify_api_key)):
1248
  """Get information about the collection"""
1249
  try:
1250
  if app_state.qdrant_client is None:
 
1264
 
1265
  # New endpoint to get available providers and models
1266
  @app.get("/v1/providers")
1267
+ async def get_providers(api_key: str = Depends(verify_api_key)):
1268
  """Get available providers and their models"""
1269
  try:
1270
  if not app_state.openai_service:
 
1302
  except Exception as e:
1303
  raise HTTPException(status_code=500, detail=f"Error getting providers: {str(e)}")
1304
 
1305
+ # Security Management Endpoints
1306
+ @app.get("/v1/security/info")
1307
+ async def get_security_info() -> SecurityInfo:
1308
+ """Get security configuration information (public endpoint)"""
1309
+ return SecurityInfo(
1310
+ security_enabled=Config.ENABLE_SECURITY,
1311
+ rate_limit_per_minute=Config.RATE_LIMIT_PER_MINUTE,
1312
+ has_master_key=bool(Config.MASTER_KEY),
1313
+ configured_keys_count=len([k for k in Config.API_KEYS if k.strip()])
1314
+ )
1315
+
1316
+ @app.post("/v1/security/generate-key")
1317
+ async def generate_api_key(
1318
+ request: APIKeyRequest,
1319
+ master_key: str = Depends(verify_master_key)
1320
+ ) -> APIKeyResponse:
1321
+ """Generate a new API key (requires master key)"""
1322
+ try:
1323
+ new_key = Config.generate_api_key()
1324
+
1325
+ return APIKeyResponse(
1326
+ api_key=new_key,
1327
+ description=request.description,
1328
+ created_at=datetime.now().isoformat(),
1329
+ status="active"
1330
+ )
1331
+ except Exception as e:
1332
+ raise HTTPException(status_code=500, detail=f"Error generating API key: {str(e)}")
1333
+
1334
+ @app.post("/v1/security/validate-key")
1335
+ async def validate_api_key_endpoint(
1336
+ api_key: str = Depends(verify_api_key)
1337
+ ) -> Dict[str, Any]:
1338
+ """Validate an API key"""
1339
+ return {
1340
+ "valid": True,
1341
+ "key_type": "master" if api_key == Config.MASTER_KEY else "standard",
1342
+ "validated_at": datetime.now().isoformat()
1343
+ }
1344
+
1345
+ @app.get("/v1/security/rate-limit-status")
1346
+ async def get_rate_limit_status(
1347
+ request: Request,
1348
+ api_key: str = Depends(verify_api_key)
1349
+ ) -> Dict[str, Any]:
1350
+ """Get current rate limit status"""
1351
+ client_ip = request.client.host
1352
+
1353
+ # Get current minute requests
1354
+ now = datetime.now()
1355
+ minute_key = now.strftime("%Y-%m-%d %H:%M")
1356
+
1357
+ current_requests = 0
1358
+ if client_ip in rate_limiter.requests:
1359
+ current_requests = rate_limiter.requests[client_ip].get(minute_key, 0)
1360
+
1361
+ return {
1362
+ "client_ip": client_ip,
1363
+ "current_requests": current_requests,
1364
+ "limit_per_minute": Config.RATE_LIMIT_PER_MINUTE,
1365
+ "remaining_requests": max(0, Config.RATE_LIMIT_PER_MINUTE - current_requests),
1366
+ "reset_at": f"{minute_key}:00",
1367
+ "is_blocked": client_ip in rate_limiter.blocked_ips
1368
+ }
1369
+
1370
+ # Admin endpoints for IP management
1371
+ @app.post("/v1/admin/block-ip/{ip}")
1372
+ async def block_ip(
1373
+ ip: str,
1374
+ master_key: str = Depends(verify_master_key)
1375
+ ) -> Dict[str, str]:
1376
+ """Block an IP address (requires master key)"""
1377
+ rate_limiter.block_ip(ip)
1378
+ return {"message": f"IP {ip} has been blocked", "blocked_at": datetime.now().isoformat()}
1379
+
1380
+ @app.post("/v1/admin/unblock-ip/{ip}")
1381
+ async def unblock_ip(
1382
+ ip: str,
1383
+ master_key: str = Depends(verify_master_key)
1384
+ ) -> Dict[str, str]:
1385
+ """Unblock an IP address (requires master key)"""
1386
+ rate_limiter.unblock_ip(ip)
1387
+ return {"message": f"IP {ip} has been unblocked", "unblocked_at": datetime.now().isoformat()}
1388
+
1389
+ @app.get("/v1/admin/blocked-ips")
1390
+ async def get_blocked_ips(
1391
+ master_key: str = Depends(verify_master_key)
1392
+ ) -> Dict[str, Any]:
1393
+ """Get list of blocked IPs (requires master key)"""
1394
+ return {
1395
+ "blocked_ips": list(rate_limiter.blocked_ips),
1396
+ "count": len(rate_limiter.blocked_ips)
1397
+ }
1398
+
1399
  if __name__ == "__main__":
1400
  import uvicorn
1401
  uvicorn.run(app, host="0.0.0.0", port=8000)