PyScoutAI committed on
Commit
5bfdf79
·
verified ·
1 Parent(s): f03e4f0

Update pyscout_api.py

Browse files
Files changed (1) hide show
  1. pyscout_api.py +442 -438
pyscout_api.py CHANGED
@@ -1,438 +1,442 @@
1
- import asyncio
2
- import time
3
- import uuid
4
- import os
5
- import json
6
- from typing import Dict, List, Optional, Union, Any
7
- from fastapi import FastAPI, HTTPException, Depends, Request, status, Body
8
- from fastapi.middleware.cors import CORSMiddleware
9
- from fastapi.responses import JSONResponse, StreamingResponse
10
- from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
11
- from pydantic import BaseModel, Field, EmailStr
12
- from slowapi import Limiter, _rate_limit_exceeded_handler
13
- from slowapi.util import get_remote_address
14
- from slowapi.errors import RateLimitExceeded
15
- import uvicorn
16
-
17
- from db_helper import MongoDBHelper
18
- from deepinfra_client import DeepInfraClient
19
- from hf_utils import HuggingFaceSpaceHelper
20
-
21
- # Initialize Hugging Face Space helper
22
- hf_helper = HuggingFaceSpaceHelper()
23
-
24
- # Install required packages for HF Spaces if needed
25
- if hf_helper.is_in_space:
26
- hf_helper.install_dependencies([
27
- "pymongo", "python-dotenv", "fastapi", "uvicorn", "slowapi",
28
- "fake-useragent", "requests-ip-rotator", "pydantic[email]"
29
- ])
30
-
31
- # Initialize FastAPI app
32
- app = FastAPI(
33
- title="PyScoutAI API",
34
- description="An OpenAI-compatible API that provides access to DeepInfra models with enhanced features",
35
- version="1.0.0"
36
- )
37
-
38
- # Setup rate limiting
39
- limiter = Limiter(key_func=get_remote_address)
40
- app.state.limiter = limiter
41
- app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
42
-
43
- # Set up CORS
44
- app.add_middleware(
45
- CORSMiddleware,
46
- allow_origins=["*"],
47
- allow_credentials=True,
48
- allow_methods=["*"],
49
- allow_headers=["*"],
50
- )
51
-
52
- # Security
53
- security = HTTPBearer(auto_error=False)
54
-
55
- # Database helper
56
- try:
57
- db = MongoDBHelper(hf_helper.get_mongodb_uri())
58
- except Exception as e:
59
- print(f"Warning: MongoDB connection failed: {e}")
60
- print("API key authentication will not work!")
61
- db = None
62
-
63
- # Models for requests and responses
64
- class Message(BaseModel):
65
- role: str
66
- content: Optional[str] = None
67
- name: Optional[str] = None
68
-
69
- class ChatCompletionRequest(BaseModel):
70
- model: str
71
- messages: List[Message]
72
- temperature: Optional[float] = 0.7
73
- top_p: Optional[float] = 1.0
74
- n: Optional[int] = 1
75
- stream: Optional[bool] = False
76
- max_tokens: Optional[int] = None
77
- presence_penalty: Optional[float] = 0.0
78
- frequency_penalty: Optional[float] = 0.0
79
- user: Optional[str] = None
80
-
81
- class CompletionRequest(BaseModel):
82
- model: str
83
- prompt: Union[str, List[str]]
84
- temperature: Optional[float] = 0.7
85
- top_p: Optional[float] = 1.0
86
- n: Optional[int] = 1
87
- stream: Optional[bool] = False
88
- max_tokens: Optional[int] = None
89
- presence_penalty: Optional[float] = 0.0
90
- frequency_penalty: Optional[float] = 0.0
91
- user: Optional[str] = None
92
-
93
- class UserCreate(BaseModel):
94
- email: EmailStr
95
- name: str
96
- organization: Optional[str] = None
97
-
98
- class APIKeyCreate(BaseModel):
99
- name: str = "Default API Key"
100
- user_id: str
101
-
102
- class APIKeyResponse(BaseModel):
103
- key: str
104
- name: str
105
- created_at: str
106
-
107
- # API clients storage (one per API key)
108
- clients: Dict[str, DeepInfraClient] = {}
109
-
110
- # Helper function to get the API key from the request
111
- async def get_api_key(
112
- request: Request,
113
- credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
114
- ) -> Optional[str]:
115
- # Check Authorization header
116
- if credentials:
117
- return credentials.credentials
118
-
119
- # Check for API key in the request headers
120
- if "Authorization" in request.headers:
121
- auth = request.headers["Authorization"]
122
- if auth.startswith("Bearer "):
123
- return auth.replace("Bearer ", "")
124
-
125
- if "x-api-key" in request.headers:
126
- return request.headers["x-api-key"]
127
-
128
- # Check for API key in query parameters
129
- api_key = request.query_params.get("api_key")
130
- if api_key:
131
- return api_key
132
-
133
- # No API key found, return None
134
- return None
135
-
136
- # Helper function to validate a PyScout API key and get user info
137
- async def get_user_info(api_key: Optional[str] = Depends(get_api_key)) -> Dict[str, Any]:
138
- if not api_key:
139
- raise HTTPException(
140
- status_code=status.HTTP_401_UNAUTHORIZED,
141
- detail="API key is required",
142
- headers={"WWW-Authenticate": "Bearer"}
143
- )
144
-
145
- # Skip validation if DB is not connected (development mode)
146
- if not db:
147
- return {"user_id": "development", "key": api_key}
148
-
149
- # Check if key starts with PyScoutAI-
150
- if not api_key.startswith("PyScoutAI-"):
151
- raise HTTPException(
152
- status_code=status.HTTP_401_UNAUTHORIZED,
153
- detail="Invalid API key format",
154
- headers={"WWW-Authenticate": "Bearer"}
155
- )
156
-
157
- # Validate the API key
158
- user_info = db.validate_api_key(api_key)
159
- if not user_info:
160
- raise HTTPException(
161
- status_code=status.HTTP_401_UNAUTHORIZED,
162
- detail="Invalid API key",
163
- headers={"WWW-Authenticate": "Bearer"}
164
- )
165
-
166
- # Check rate limits
167
- rate_limit = db.check_rate_limit(api_key)
168
- if not rate_limit["allowed"]:
169
- raise HTTPException(
170
- status_code=status.HTTP_429_TOO_MANY_REQUESTS,
171
- detail=rate_limit["reason"]
172
- )
173
-
174
- return user_info
175
-
176
- # Helper function to get or create a client
177
- def get_client(api_key: str) -> DeepInfraClient:
178
- if api_key not in clients:
179
- # Create a client with IP rotation and random user agent
180
- clients[api_key] = DeepInfraClient(
181
- use_random_user_agent=True,
182
- use_proxy_rotation=True,
183
- use_ip_rotation=True
184
- )
185
- return clients[api_key]
186
-
187
- @app.get("/")
188
- async def root():
189
- metadata = hf_helper.get_hf_metadata()
190
- return {
191
- "message": "Welcome to PyScoutAI API",
192
- "documentation": "/docs",
193
- "environment": "Hugging Face Space" if hf_helper.is_in_space else "Local",
194
- "endpoints": [
195
- "/v1/models",
196
- "/v1/chat/completions",
197
- "/v1/completions"
198
- ],
199
- **metadata
200
- }
201
-
202
- @app.get("/v1/models")
203
- @limiter.limit("20/minute")
204
- async def list_models(
205
- request: Request,
206
- user_info: Dict[str, Any] = Depends(get_user_info)
207
- ):
208
- api_key = user_info["key"]
209
- client = get_client(api_key)
210
- try:
211
- models = await asyncio.to_thread(client.models.list)
212
- # Log the API usage
213
- if db:
214
- db.log_api_usage(api_key, "/v1/models", 0)
215
- return models
216
- except Exception as e:
217
- raise HTTPException(status_code=500, detail=f"Error listing models: {str(e)}")
218
-
219
- @app.post("/v1/chat/completions")
220
- @limiter.limit("60/minute")
221
- async def create_chat_completion(
222
- request: Request,
223
- body: ChatCompletionRequest,
224
- user_info: Dict[str, Any] = Depends(get_user_info)
225
- ):
226
- api_key = user_info["key"]
227
- client = get_client(api_key)
228
-
229
- try:
230
- # Prepare the messages
231
- messages = [{"role": msg.role, "content": msg.content} for msg in body.messages if msg.content is not None]
232
-
233
- kwargs = {
234
- "model": body.model,
235
- "temperature": body.temperature,
236
- "max_tokens": body.max_tokens,
237
- "stream": body.stream,
238
- "top_p": body.top_p,
239
- "presence_penalty": body.presence_penalty,
240
- "frequency_penalty": body.frequency_penalty,
241
- }
242
-
243
- if body.stream:
244
- async def generate_stream():
245
- response_stream = await asyncio.to_thread(
246
- client.chat.create,
247
- messages=messages,
248
- **kwargs
249
- )
250
-
251
- total_tokens = 0
252
- for chunk in response_stream:
253
- # Track token usage for each chunk if available
254
- if 'usage' in chunk and chunk['usage']:
255
- total_tokens += chunk['usage'].get('total_tokens', 0)
256
-
257
- yield f"data: {json.dumps(chunk)}\n\n"
258
-
259
- # Log API usage at the end of streaming
260
- if db:
261
- db.log_api_usage(api_key, "/v1/chat/completions", total_tokens, body.model)
262
-
263
- yield "data: [DONE]\n\n"
264
-
265
- return StreamingResponse(
266
- generate_stream(),
267
- media_type="text/event-stream"
268
- )
269
- else:
270
- response = await asyncio.to_thread(
271
- client.chat.create,
272
- messages=messages,
273
- **kwargs
274
- )
275
-
276
- # Log the API usage
277
- if db and 'usage' in response:
278
- total_tokens = response['usage'].get('total_tokens', 0)
279
- db.log_api_usage(api_key, "/v1/chat/completions", total_tokens, body.model)
280
-
281
- return response
282
-
283
- except Exception as e:
284
- raise HTTPException(status_code=500, detail=f"Error generating chat completion: {str(e)}")
285
-
286
- @app.post("/v1/completions")
287
- @limiter.limit("60/minute")
288
- async def create_completion(
289
- request: Request,
290
- body: CompletionRequest,
291
- user_info: Dict[str, Any] = Depends(get_user_info)
292
- ):
293
- api_key = user_info["key"]
294
- client = get_client(api_key)
295
-
296
- try:
297
- # Handle different prompt types
298
- prompt = body.prompt
299
- if isinstance(prompt, list):
300
- prompt = prompt[0] # Take the first prompt if it's a list
301
-
302
- kwargs = {
303
- "model": body.model,
304
- "temperature": body.temperature,
305
- "max_tokens": body.max_tokens,
306
- "stream": body.stream,
307
- "top_p": body.top_p,
308
- "presence_penalty": body.presence_penalty,
309
- "frequency_penalty": body.frequency_penalty,
310
- }
311
-
312
- if body.stream:
313
- async def generate_stream():
314
- response_stream = await asyncio.to_thread(
315
- client.completions.create,
316
- prompt=prompt,
317
- **kwargs
318
- )
319
-
320
- total_tokens = 0
321
- for chunk in response_stream:
322
- if 'usage' in chunk and chunk['usage']:
323
- total_tokens += chunk['usage'].get('total_tokens', 0)
324
-
325
- yield f"data: {json.dumps(chunk)}\n\n"
326
-
327
- # Log API usage at the end of streaming
328
- if db:
329
- db.log_api_usage(api_key, "/v1/completions", total_tokens, body.model)
330
-
331
- yield "data: [DONE]\n\n"
332
-
333
- return StreamingResponse(
334
- generate_stream(),
335
- media_type="text/event-stream"
336
- )
337
- else:
338
- response = await asyncio.to_thread(
339
- client.completions.create,
340
- prompt=prompt,
341
- **kwargs
342
- )
343
-
344
- # Log the API usage
345
- if db and 'usage' in response:
346
- total_tokens = response['usage'].get('total_tokens', 0)
347
- db.log_api_usage(api_key, "/v1/completions", total_tokens, body.model)
348
-
349
- return response
350
-
351
- except Exception as e:
352
- raise HTTPException(status_code=500, detail=f"Error generating completion: {str(e)}")
353
-
354
- @app.get("/health")
355
- async def health_check():
356
- status_info = {"api": "ok"}
357
-
358
- # Check MongoDB connection
359
- if db:
360
- try:
361
- # Simple operation to check connection
362
- db.api_keys_collection.find_one({})
363
- status_info["database"] = "ok"
364
- except Exception as e:
365
- status_info["database"] = f"error: {str(e)}"
366
- else:
367
- status_info["database"] = "not configured"
368
-
369
- # Add Hugging Face Space info
370
- if hf_helper.is_in_space:
371
- status_info["environment"] = "Hugging Face Space"
372
- status_info["space_name"] = hf_helper.space_name
373
- else:
374
- status_info["environment"] = "Local"
375
-
376
- return status_info
377
-
378
- # API Key Management Endpoints
379
- @app.post("/v1/api_keys", response_model=APIKeyResponse)
380
- async def create_api_key(body: APIKeyCreate):
381
- if not db:
382
- raise HTTPException(status_code=500, detail="Database not configured")
383
-
384
- try:
385
- api_key = db.generate_api_key(body.user_id, body.name)
386
- key_data = db.validate_api_key(api_key)
387
- return {
388
- "key": api_key,
389
- "name": key_data["name"],
390
- "created_at": key_data["created_at"].isoformat()
391
- }
392
- except Exception as e:
393
- raise HTTPException(status_code=500, detail=f"Error creating API key: {str(e)}")
394
-
395
- @app.get("/v1/api_keys")
396
- async def list_api_keys(user_id: str):
397
- if not db:
398
- raise HTTPException(status_code=500, detail="Database not configured")
399
-
400
- keys = db.get_user_api_keys(user_id)
401
- for key in keys:
402
- if "created_at" in key:
403
- key["created_at"] = key["created_at"].isoformat()
404
- if "last_used" in key and key["last_used"]:
405
- key["last_used"] = key["last_used"].isoformat()
406
-
407
- return {"keys": keys}
408
-
409
- @app.post("/v1/api_keys/revoke")
410
- async def revoke_api_key(api_key: str):
411
- if not db:
412
- raise HTTPException(status_code=500, detail="Database not configured")
413
-
414
- success = db.revoke_api_key(api_key)
415
- if not success:
416
- raise HTTPException(status_code=404, detail="API key not found")
417
-
418
- return {"message": "API key revoked successfully"}
419
-
420
- # Clean up IP rotator clients on shutdown
421
- @app.on_event("shutdown")
422
- async def cleanup_clients():
423
- for client in clients.values():
424
- try:
425
- if hasattr(client, 'ip_rotator') and client.ip_rotator:
426
- client.ip_rotator.shutdown()
427
- except:
428
- pass
429
-
430
- if __name__ == "__main__":
431
- # Get host and port based on environment
432
- host = hf_helper.get_hostname()
433
- port = 8000 # Default port for FastAPI
434
-
435
- print(f"Starting PyScoutAI API on http://{host}:{port}")
436
- print(f"Environment: {'Hugging Face Space' if hf_helper.is_in_space else 'Local'}")
437
-
438
- uvicorn.run("pyscout_api:app", host=host, port=port, reload=not hf_helper.is_in_space)
 
 
 
 
 
1
+ import asyncio
2
+ import time
3
+ import uuid
4
+ import os
5
+ import json
6
+ from typing import Dict, List, Optional, Union, Any
7
+ from fastapi import FastAPI, HTTPException, Depends, Request, status, Body
8
+ from fastapi.middleware.cors import CORSMiddleware
9
+ from fastapi.responses import JSONResponse, StreamingResponse
10
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
11
+ from pydantic import BaseModel, Field, EmailStr
12
+ from slowapi import Limiter, _rate_limit_exceeded_handler
13
+ from slowapi.util import get_remote_address
14
+ from slowapi.errors import RateLimitExceeded
15
+ import uvicorn
16
+
17
+ from db_helper import MongoDBHelper
18
+ from deepinfra_client import DeepInfraClient
19
+ from hf_utils import HuggingFaceSpaceHelper
20
+
21
+ # Initialize Hugging Face Space helper
22
+ hf_helper = HuggingFaceSpaceHelper()
23
+
24
+ # Install required packages for HF Spaces if needed
25
+ if hf_helper.is_in_space:
26
+ hf_helper.install_dependencies([
27
+ "pymongo", "python-dotenv", "fastapi", "uvicorn", "slowapi",
28
+ "fake-useragent", "requests-ip-rotator", "pydantic[email]"
29
+ ])
30
+
31
+ # Initialize FastAPI app
32
+ app = FastAPI(
33
+ title="PyScoutAI API",
34
+ description="An OpenAI-compatible API that provides access to DeepInfra models with enhanced features",
35
+ version="1.0.0"
36
+ )
37
+
38
+ # Setup rate limiting
39
+ limiter = Limiter(key_func=get_remote_address)
40
+ app.state.limiter = limiter
41
+ app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
42
+
43
+ # Set up CORS
44
+ app.add_middleware(
45
+ CORSMiddleware,
46
+ allow_origins=["*"],
47
+ allow_credentials=True,
48
+ allow_methods=["*"],
49
+ allow_headers=["*"],
50
+ )
51
+
52
+ # Security
53
+ security = HTTPBearer(auto_error=False)
54
+
55
+ # Database helper
56
+ try:
57
+ db = MongoDBHelper(hf_helper.get_mongodb_uri())
58
+ except Exception as e:
59
+ print(f"Warning: MongoDB connection failed: {e}")
60
+ print("API key authentication will not work!")
61
+ db = None
62
+
63
+ # Models for requests and responses
64
+ class Message(BaseModel):
65
+ role: str
66
+ content: Optional[str] = None
67
+ name: Optional[str] = None
68
+
69
+ class ChatCompletionRequest(BaseModel):
70
+ model: str
71
+ messages: List[Message]
72
+ temperature: Optional[float] = 0.7
73
+ top_p: Optional[float] = 1.0
74
+ n: Optional[int] = 1
75
+ stream: Optional[bool] = False
76
+ max_tokens: Optional[int] = None
77
+ presence_penalty: Optional[float] = 0.0
78
+ frequency_penalty: Optional[float] = 0.0
79
+ user: Optional[str] = None
80
+
81
+ class CompletionRequest(BaseModel):
82
+ model: str
83
+ prompt: Union[str, List[str]]
84
+ temperature: Optional[float] = 0.7
85
+ top_p: Optional[float] = 1.0
86
+ n: Optional[int] = 1
87
+ stream: Optional[bool] = False
88
+ max_tokens: Optional[int] = None
89
+ presence_penalty: Optional[float] = 0.0
90
+ frequency_penalty: Optional[float] = 0.0
91
+ user: Optional[str] = None
92
+
93
+ class UserCreate(BaseModel):
94
+ email: EmailStr
95
+ name: str
96
+ organization: Optional[str] = None
97
+
98
+ class APIKeyCreate(BaseModel):
99
+ name: str = "Default API Key"
100
+ user_id: str
101
+
102
+ class APIKeyResponse(BaseModel):
103
+ key: str
104
+ name: str
105
+ created_at: str
106
+
107
+ # API clients storage (one per API key)
108
+ clients: Dict[str, DeepInfraClient] = {}
109
+
110
+ # Helper function to get the API key from the request
111
+ async def get_api_key(
112
+ request: Request,
113
+ credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
114
+ ) -> Optional[str]:
115
+ # Check Authorization header
116
+ if credentials:
117
+ return credentials.credentials
118
+
119
+ # Check for API key in the request headers
120
+ if "Authorization" in request.headers:
121
+ auth = request.headers["Authorization"]
122
+ if auth.startswith("Bearer "):
123
+ return auth.replace("Bearer ", "")
124
+
125
+ if "x-api-key" in request.headers:
126
+ return request.headers["x-api-key"]
127
+
128
+ # Check for API key in query parameters
129
+ api_key = request.query_params.get("api_key")
130
+ if api_key:
131
+ return api_key
132
+
133
+ # No API key found, return None
134
+ return None
135
+
136
+ # Helper function to validate a PyScout API key and get user info
137
+ async def get_user_info(api_key: Optional[str] = Depends(get_api_key)) -> Dict[str, Any]:
138
+ if not api_key:
139
+ raise HTTPException(
140
+ status_code=status.HTTP_401_UNAUTHORIZED,
141
+ detail="API key is required",
142
+ headers={"WWW-Authenticate": "Bearer"}
143
+ )
144
+
145
+ # Skip validation if DB is not connected (development mode)
146
+ if not db:
147
+ return {"user_id": "development", "key": api_key}
148
+
149
+ # Check if key starts with PyScoutAI-
150
+ if not api_key.startswith("PyScoutAI-"):
151
+ raise HTTPException(
152
+ status_code=status.HTTP_401_UNAUTHORIZED,
153
+ detail="Invalid API key format",
154
+ headers={"WWW-Authenticate": "Bearer"}
155
+ )
156
+
157
+ # Validate the API key
158
+ user_info = db.validate_api_key(api_key)
159
+ if not user_info:
160
+ raise HTTPException(
161
+ status_code=status.HTTP_401_UNAUTHORIZED,
162
+ detail="Invalid API key",
163
+ headers={"WWW-Authenticate": "Bearer"}
164
+ )
165
+
166
+ # Check rate limits
167
+ rate_limit = db.check_rate_limit(api_key)
168
+ if not rate_limit["allowed"]:
169
+ raise HTTPException(
170
+ status_code=status.HTTP_429_TOO_MANY_REQUESTS,
171
+ detail=rate_limit["reason"]
172
+ )
173
+
174
+ return user_info
175
+
176
+ # Helper function to get or create a client
177
+ def get_client(api_key: str) -> DeepInfraClient:
178
+ if api_key not in clients:
179
+ # Create a client with IP rotation and random user agent
180
+ clients[api_key] = DeepInfraClient(
181
+ use_random_user_agent=True,
182
+ use_proxy_rotation=True,
183
+ use_ip_rotation=True
184
+ )
185
+ return clients[api_key]
186
+
187
+ @app.get("/")
188
+ async def root():
189
+ metadata = hf_helper.get_hf_metadata()
190
+ return {
191
+ "message": "Welcome to PyScoutAI API",
192
+ "documentation": "/docs",
193
+ "environment": "Hugging Face Space" if hf_helper.is_in_space else "Local",
194
+ "endpoints": [
195
+ "/v1/models",
196
+ "/v1/chat/completions",
197
+ "/v1/completions"
198
+ ],
199
+ **metadata
200
+ }
201
+
202
+ @app.get("/v1/models")
203
+ @limiter.limit("20/minute")
204
+ async def list_models(
205
+ request: Request,
206
+ user_info: Dict[str, Any] = Depends(get_user_info)
207
+ ):
208
+ api_key = user_info["key"]
209
+ client = get_client(api_key)
210
+ try:
211
+ models = await asyncio.to_thread(client.models.list)
212
+ # Log the API usage
213
+ if db:
214
+ db.log_api_usage(api_key, "/v1/models", 0)
215
+ return models
216
+ except Exception as e:
217
+ raise HTTPException(status_code=500, detail=f"Error listing models: {str(e)}")
218
+
219
+ @app.post("/v1/chat/completions")
220
+ @limiter.limit("60/minute")
221
+ async def create_chat_completion(
222
+ request: Request,
223
+ body: ChatCompletionRequest,
224
+ user_info: Dict[str, Any] = Depends(get_user_info)
225
+ ):
226
+ api_key = user_info["key"]
227
+ client = get_client(api_key)
228
+
229
+ try:
230
+ # Prepare the messages
231
+ messages = [{"role": msg.role, "content": msg.content} for msg in body.messages if msg.content is not None]
232
+
233
+ kwargs = {
234
+ "model": body.model,
235
+ "temperature": body.temperature,
236
+ "max_tokens": body.max_tokens,
237
+ "stream": body.stream,
238
+ "top_p": body.top_p,
239
+ "presence_penalty": body.presence_penalty,
240
+ "frequency_penalty": body.frequency_penalty,
241
+ }
242
+
243
+ if body.stream:
244
+ async def generate_stream():
245
+ response_stream = await asyncio.to_thread(
246
+ client.chat.create,
247
+ messages=messages,
248
+ **kwargs
249
+ )
250
+
251
+ total_tokens = 0
252
+ for chunk in response_stream:
253
+ # Track token usage for each chunk if available
254
+ if 'usage' in chunk and chunk['usage']:
255
+ total_tokens += chunk['usage'].get('total_tokens', 0)
256
+
257
+ yield f"data: {json.dumps(chunk)}\n\n"
258
+
259
+ # Log API usage at the end of streaming
260
+ if db:
261
+ db.log_api_usage(api_key, "/v1/chat/completions", total_tokens, body.model)
262
+
263
+ yield "data: [DONE]\n\n"
264
+
265
+ return StreamingResponse(
266
+ generate_stream(),
267
+ media_type="text/event-stream"
268
+ )
269
+ else:
270
+ response = await asyncio.to_thread(
271
+ client.chat.create,
272
+ messages=messages,
273
+ **kwargs
274
+ )
275
+
276
+ # Log the API usage
277
+ if db and 'usage' in response:
278
+ total_tokens = response['usage'].get('total_tokens', 0)
279
+ db.log_api_usage(api_key, "/v1/chat/completions", total_tokens, body.model)
280
+
281
+ return response
282
+
283
+ except Exception as e:
284
+ raise HTTPException(status_code=500, detail=f"Error generating chat completion: {str(e)}")
285
+
286
+ @app.post("/v1/completions")
287
+ @limiter.limit("60/minute")
288
+ async def create_completion(
289
+ request: Request,
290
+ body: CompletionRequest,
291
+ user_info: Dict[str, Any] = Depends(get_user_info)
292
+ ):
293
+ api_key = user_info["key"]
294
+ client = get_client(api_key)
295
+
296
+ try:
297
+ # Handle different prompt types
298
+ prompt = body.prompt
299
+ if isinstance(prompt, list):
300
+ prompt = prompt[0] # Take the first prompt if it's a list
301
+
302
+ kwargs = {
303
+ "model": body.model,
304
+ "temperature": body.temperature,
305
+ "max_tokens": body.max_tokens,
306
+ "stream": body.stream,
307
+ "top_p": body.top_p,
308
+ "presence_penalty": body.presence_penalty,
309
+ "frequency_penalty": body.frequency_penalty,
310
+ }
311
+
312
+ if body.stream:
313
+ async def generate_stream():
314
+ response_stream = await asyncio.to_thread(
315
+ client.completions.create,
316
+ prompt=prompt,
317
+ **kwargs
318
+ )
319
+
320
+ total_tokens = 0
321
+ for chunk in response_stream:
322
+ if 'usage' in chunk and chunk['usage']:
323
+ total_tokens += chunk['usage'].get('total_tokens', 0)
324
+
325
+ yield f"data: {json.dumps(chunk)}\n\n"
326
+
327
+ # Log API usage at the end of streaming
328
+ if db:
329
+ db.log_api_usage(api_key, "/v1/completions", total_tokens, body.model)
330
+
331
+ yield "data: [DONE]\n\n"
332
+
333
+ return StreamingResponse(
334
+ generate_stream(),
335
+ media_type="text/event-stream"
336
+ )
337
+ else:
338
+ response = await asyncio.to_thread(
339
+ client.completions.create,
340
+ prompt=prompt,
341
+ **kwargs
342
+ )
343
+
344
+ # Log the API usage
345
+ if db and 'usage' in response:
346
+ total_tokens = response['usage'].get('total_tokens', 0)
347
+ db.log_api_usage(api_key, "/v1/completions", total_tokens, body.model)
348
+
349
+ return response
350
+
351
+ except Exception as e:
352
+ raise HTTPException(status_code=500, detail=f"Error generating completion: {str(e)}")
353
+
354
+ @app.get("/health")
355
+ async def health_check():
356
+ status_info = {"api": "ok"}
357
+
358
+ # Check MongoDB connection
359
+ if db:
360
+ try:
361
+ # Simple operation to check connection
362
+ db.api_keys_collection.find_one({})
363
+ status_info["database"] = "ok"
364
+ except Exception as e:
365
+ status_info["database"] = f"error: {str(e)}"
366
+ else:
367
+ status_info["database"] = "not configured"
368
+
369
+ # Add Hugging Face Space info
370
+ if hf_helper.is_in_space:
371
+ status_info["environment"] = "Hugging Face Space"
372
+ status_info["space_name"] = hf_helper.space_name
373
+ else:
374
+ status_info["environment"] = "Local"
375
+
376
+ return status_info
377
+
378
+ # API Key Management Endpoints
379
+ @app.post("/v1/api_keys", response_model=APIKeyResponse)
380
+ async def create_api_key(body: APIKeyCreate):
381
+ if not db:
382
+ raise HTTPException(status_code=500, detail="Database not configured")
383
+
384
+ try:
385
+ api_key = db.generate_api_key(body.user_id, body.name)
386
+ key_data = db.validate_api_key(api_key)
387
+ return {
388
+ "key": api_key,
389
+ "name": key_data["name"],
390
+ "created_at": key_data["created_at"].isoformat()
391
+ }
392
+ except Exception as e:
393
+ raise HTTPException(status_code=500, detail=f"Error creating API key: {str(e)}")
394
+
395
+ @app.get("/v1/api_keys")
396
+ async def list_api_keys(user_id: str):
397
+ if not db:
398
+ raise HTTPException(status_code=500, detail="Database not configured")
399
+
400
+ keys = db.get_user_api_keys(user_id)
401
+ for key in keys:
402
+ if "created_at" in key:
403
+ key["created_at"] = key["created_at"].isoformat()
404
+ if "last_used" in key and key["last_used"]:
405
+ key["last_used"] = key["last_used"].isoformat()
406
+
407
+ return {"keys": keys}
408
+
409
+ @app.post("/v1/api_keys/revoke")
410
+ async def revoke_api_key(api_key: str):
411
+ if not db:
412
+ raise HTTPException(status_code=500, detail="Database not configured")
413
+
414
+ success = db.revoke_api_key(api_key)
415
+ if not success:
416
+ raise HTTPException(status_code=404, detail="API key not found")
417
+
418
+ return {"message": "API key revoked successfully"}
419
+
420
+ # Clean up IP rotator clients on shutdown
421
+ @app.on_event("shutdown")
422
+ async def cleanup_clients():
423
+ for client in clients.values():
424
+ try:
425
+ if hasattr(client, 'ip_rotator') and client.ip_rotator:
426
+ client.ip_rotator.shutdown()
427
+ except:
428
+ pass
429
+
430
+ if __name__ == "__main__":
431
+ host = os.environ.get("HOST", "0.0.0.0")
432
+ port = int(os.environ.get("PORT", "7860"))
433
+
434
+ print(f"Starting PyScoutAI API on http://{host}:{port}")
435
+ print(f"Environment: {'Hugging Face Space' if hf_helper.is_in_space else 'Local'}")
436
+
437
+ uvicorn.run(
438
+ app,
439
+ host=host,
440
+ port=port,
441
+ reload=not hf_helper.is_in_space
442
+ )