Update main.py
main.py
CHANGED
@@ -473,38 +473,6 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
         except Exception as e:
             yield f"Unexpected error during /chat/{chat_id} request: {str(e)}"

-# FastAPI app setup
-app = FastAPI()
-
-# Add the cleanup task when the app starts
-@app.on_event("startup")
-async def startup_event():
-    asyncio.create_task(cleanup_rate_limit_stores())
-    logger.info("Started rate limit store cleanup task.")
-
-# Middleware to enhance security and enforce Content-Type for specific endpoints
-@app.middleware("http")
-async def security_middleware(request: Request, call_next):
-    client_ip = request.client.host
-    # Enforce that POST requests to /v1/chat/completions must have Content-Type: application/json
-    if request.method == "POST" and request.url.path == "/v1/chat/completions":
-        content_type = request.headers.get("Content-Type")
-        if content_type != "application/json":
-            logger.warning(f"Invalid Content-Type from IP: {client_ip} for path: {request.url.path}")
-            return JSONResponse(
-                status_code=400,
-                content={
-                    "error": {
-                        "message": "Content-Type must be application/json",
-                        "type": "invalid_request_error",
-                        "param": None,
-                        "code": None
-                    }
-                },
-            )
-    response = await call_next(request)
-    return response
-
 # Request Models
 class Message(BaseModel):
     role: str
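Note: the startup hook schedules `cleanup_rate_limit_stores`, which is defined elsewhere in main.py and not shown in this diff. A minimal sketch of such a periodic cleanup coroutine, with assumed store layout and window (both hypothetical):

import asyncio
import time

rate_limit_store: dict[str, list[float]] = {}  # assumed shape: IP -> request timestamps

async def cleanup_rate_limit_stores():
    # Periodically drop timestamps that have aged out of the rate-limit window.
    while True:
        cutoff = time.time() - 60  # assumed 60-second window
        for ip in list(rate_limit_store):
            rate_limit_store[ip] = [t for t in rate_limit_store[ip] if t > cutoff]
            if not rate_limit_store[ip]:
                del rate_limit_store[ip]  # forget idle clients entirely
        await asyncio.sleep(60)  # assumed cleanup interval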
@@ -556,6 +524,38 @@ def create_response(content: str, model: str, finish_reason: Optional[str] = Non
         "usage": None, # To be filled in non-streaming responses
     }

+# Initialize FastAPI app
+app = FastAPI()
+
+# Add the cleanup task when the app starts
+@app.on_event("startup")
+async def startup_event():
+    asyncio.create_task(cleanup_rate_limit_stores())
+    logger.info("Started rate limit store cleanup task.")
+
+# Middleware to enhance security and enforce Content-Type for specific endpoints
+@app.middleware("http")
+async def security_middleware(request: Request, call_next):
+    client_ip = request.client.host
+    # Enforce that POST requests to /v1/chat/completions must have Content-Type: application/json
+    if request.method == "POST" and request.url.path == "/v1/chat/completions":
+        content_type = request.headers.get("Content-Type")
+        if content_type != "application/json":
+            logger.warning(f"Invalid Content-Type from IP: {client_ip} for path: {request.url.path}")
+            return JSONResponse(
+                status_code=400,
+                content={
+                    "error": {
+                        "message": "Content-Type must be application/json",
+                        "type": "invalid_request_error",
+                        "param": None,
+                        "code": None
+                    }
+                },
+            )
+    response = await call_next(request)
+    return response
+
 # FastAPI Endpoints

 @app.post("/v1/chat/completions", dependencies=[Depends(rate_limiter_per_ip)])
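Because the middleware compares the header with strict equality, a request sending `Content-Type: application/json; charset=utf-8` is rejected along with genuinely wrong types. A quick client-side check, assuming a local deployment on port 8000 (URL and model name are placeholders):

import requests

base = "http://localhost:8000"  # assumed; the deployment URL is not in the diff

# Wrong Content-Type -> the middleware short-circuits with a 400 error envelope
r = requests.post(
    f"{base}/v1/chat/completions",
    data='{"model": "some-model", "messages": []}',
    headers={"Content-Type": "text/plain"},
)
print(r.status_code)                 # 400
print(r.json()["error"]["message"])  # Content-Type must be application/json

# requests sets Content-Type: application/json exactly when json= is used
r = requests.post(
    f"{base}/v1/chat/completions",
    json={"model": "some-model", "messages": [{"role": "user", "content": "hi"}]},
)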
@@ -610,7 +610,7 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = De
                     yield f"data: {json.dumps(response_chunk)}\n\n"

                 # After all chunks are sent, send the final message with finish_reason
-                prompt_tokens = sum(len(msg
+                prompt_tokens = sum(len(msg.content.split()) for msg in request.messages)
                 completion_tokens = len(assistant_content.split())
                 total_tokens = prompt_tokens + completion_tokens
                 estimated_cost = calculate_estimated_cost(prompt_tokens, completion_tokens)
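The corrected line completes an expression that was previously truncated mid-call: prompt tokens are estimated by whitespace-splitting each message's content, so a prompt like "Hello there, how are you?" counts as 5 tokens rather than a tokenizer-accurate figure. `calculate_estimated_cost` is defined outside this hunk; a minimal sketch under an assumed flat rate:

def calculate_estimated_cost(prompt_tokens: int, completion_tokens: int) -> float:
    # Hypothetical flat pricing; the real rates in main.py are not shown in this diff.
    rate_per_1k_tokens = 0.002  # assumed placeholder rate
    return (prompt_tokens + completion_tokens) / 1000 * rate_per_1k_tokens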
@@ -695,7 +695,7 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = De
         logger.exception(f"An unexpected error occurred while processing the chat completions request from IP: {client_ip}.")
         raise HTTPException(status_code=500, detail=str(e))

-#
+# Endpoint: POST /v1/tokenizer
 @app.post("/v1/tokenizer", dependencies=[Depends(rate_limiter_per_ip)])
 async def tokenizer(request: TokenizerRequest, req: Request):
     client_ip = req.client.host
@@ -704,14 +704,14 @@ async def tokenizer(request: TokenizerRequest, req: Request):
     logger.info(f"Tokenizer requested from IP: {client_ip} | Text length: {len(text)}")
     return {"text": text, "tokens": token_count}

-#
+# Endpoint: GET /v1/models
 @app.get("/v1/models", dependencies=[Depends(rate_limiter_per_ip)])
 async def get_models(req: Request):
     client_ip = req.client.host
     logger.info(f"Fetching available models from IP: {client_ip}")
     return {"data": [{"id": model, "object": "model"} for model in Blackbox.models]}

-#
+# Endpoint: GET /v1/models/{model}/status
 @app.get("/v1/models/{model}/status", dependencies=[Depends(rate_limiter_per_ip)])
 async def model_status(model: str, req: Request):
     client_ip = req.client.host
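Both endpoints return plain JSON. Example calls, assuming a local deployment and that TokenizerRequest carries a single `text` field (the field name and the token-counting logic live on lines not shown in this hunk):

import requests

base = "http://localhost:8000"  # assumed

resp = requests.post(f"{base}/v1/tokenizer", json={"text": "Hello there, how are you?"})
print(resp.json())  # {"text": "...", "tokens": <count>}

# Models are listed straight from Blackbox.models
print(requests.get(f"{base}/v1/models").json())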
@@ -725,21 +725,21 @@ async def model_status(model: str, req: Request):
         logger.warning(f"Model not found: {model} from IP: {client_ip}")
         raise HTTPException(status_code=404, detail="Model not found")

-#
+# Endpoint: GET /v1/health
 @app.get("/v1/health", dependencies=[Depends(rate_limiter_per_ip)])
 async def health_check(req: Request):
     client_ip = req.client.host
     logger.info(f"Health check requested from IP: {client_ip}")
     return {"status": "ok"}

-#
+# Endpoint: GET /v1/chat/completions (GET method)
 @app.get("/v1/chat/completions")
 async def chat_completions_get(req: Request):
     client_ip = req.client.host
     logger.info(f"GET request made to /v1/chat/completions from IP: {client_ip}, redirecting to 'about:blank'")
     return RedirectResponse(url='about:blank')

-#
+# Custom exception handler to match OpenAI's error format
 @app.exception_handler(HTTPException)
 async def http_exception_handler(request: Request, exc: HTTPException):
     client_ip = request.client.host
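The handler body is cut off at the edge of the final hunk. Matching OpenAI's error format, as the new comment describes, would mean returning the same `{"error": {...}}` envelope the middleware uses; a sketch consistent with that shape (the actual body is not shown in the diff):

@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    client_ip = request.client.host
    logger.warning(f"HTTPException {exc.status_code} from IP: {client_ip}: {exc.detail}")
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "error": {
                "message": str(exc.detail),
                "type": "invalid_request_error",
                "param": None,
                "code": None,
            }
        },
    )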