Niansuh committed on
Commit bb03ef5 · verified · 1 Parent(s): cf13b1c

Update main.py

Files changed (1): main.py (+39 −39)
main.py CHANGED
@@ -473,38 +473,6 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
         except Exception as e:
             yield f"Unexpected error during /chat/{chat_id} request: {str(e)}"
 
-# FastAPI app setup
-app = FastAPI()
-
-# Add the cleanup task when the app starts
-@app.on_event("startup")
-async def startup_event():
-    asyncio.create_task(cleanup_rate_limit_stores())
-    logger.info("Started rate limit store cleanup task.")
-
-# Middleware to enhance security and enforce Content-Type for specific endpoints
-@app.middleware("http")
-async def security_middleware(request: Request, call_next):
-    client_ip = request.client.host
-    # Enforce that POST requests to /v1/chat/completions must have Content-Type: application/json
-    if request.method == "POST" and request.url.path == "/v1/chat/completions":
-        content_type = request.headers.get("Content-Type")
-        if content_type != "application/json":
-            logger.warning(f"Invalid Content-Type from IP: {client_ip} for path: {request.url.path}")
-            return JSONResponse(
-                status_code=400,
-                content={
-                    "error": {
-                        "message": "Content-Type must be application/json",
-                        "type": "invalid_request_error",
-                        "param": None,
-                        "code": None
-                    }
-                },
-            )
-    response = await call_next(request)
-    return response
-
 # Request Models
 class Message(BaseModel):
     role: str
@@ -556,6 +524,38 @@ def create_response(content: str, model: str, finish_reason: Optional[str] = Non
         "usage": None, # To be filled in non-streaming responses
     }
 
+# Initialize FastAPI app
+app = FastAPI()
+
+# Add the cleanup task when the app starts
+@app.on_event("startup")
+async def startup_event():
+    asyncio.create_task(cleanup_rate_limit_stores())
+    logger.info("Started rate limit store cleanup task.")
+
+# Middleware to enhance security and enforce Content-Type for specific endpoints
+@app.middleware("http")
+async def security_middleware(request: Request, call_next):
+    client_ip = request.client.host
+    # Enforce that POST requests to /v1/chat/completions must have Content-Type: application/json
+    if request.method == "POST" and request.url.path == "/v1/chat/completions":
+        content_type = request.headers.get("Content-Type")
+        if content_type != "application/json":
+            logger.warning(f"Invalid Content-Type from IP: {client_ip} for path: {request.url.path}")
+            return JSONResponse(
+                status_code=400,
+                content={
+                    "error": {
+                        "message": "Content-Type must be application/json",
+                        "type": "invalid_request_error",
+                        "param": None,
+                        "code": None
+                    }
+                },
+            )
+    response = await call_next(request)
+    return response
+
 # FastAPI Endpoints
 
 @app.post("/v1/chat/completions", dependencies=[Depends(rate_limiter_per_ip)])
@@ -610,7 +610,7 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = De
             yield f"data: {json.dumps(response_chunk)}\n\n"
 
         # After all chunks are sent, send the final message with finish_reason
-        prompt_tokens = sum(len(msg['content'].split()) for msg in request.messages)
+        prompt_tokens = sum(len(msg.content.split()) for msg in request.messages)
         completion_tokens = len(assistant_content.split())
         total_tokens = prompt_tokens + completion_tokens
         estimated_cost = calculate_estimated_cost(prompt_tokens, completion_tokens)
@@ -695,7 +695,7 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = De
         logger.exception(f"An unexpected error occurred while processing the chat completions request from IP: {client_ip}.")
         raise HTTPException(status_code=500, detail=str(e))
 
-# Tokenizer Endpoint
+# Endpoint: POST /v1/tokenizer
 @app.post("/v1/tokenizer", dependencies=[Depends(rate_limiter_per_ip)])
 async def tokenizer(request: TokenizerRequest, req: Request):
     client_ip = req.client.host
@@ -704,14 +704,14 @@ async def tokenizer(request: TokenizerRequest, req: Request):
     logger.info(f"Tokenizer requested from IP: {client_ip} | Text length: {len(text)}")
     return {"text": text, "tokens": token_count}
 
-# Get Models Endpoint
+# Endpoint: GET /v1/models
 @app.get("/v1/models", dependencies=[Depends(rate_limiter_per_ip)])
 async def get_models(req: Request):
     client_ip = req.client.host
     logger.info(f"Fetching available models from IP: {client_ip}")
     return {"data": [{"id": model, "object": "model"} for model in Blackbox.models]}
 
-# Model Status Endpoint
+# Endpoint: GET /v1/models/{model}/status
 @app.get("/v1/models/{model}/status", dependencies=[Depends(rate_limiter_per_ip)])
 async def model_status(model: str, req: Request):
     client_ip = req.client.host
@@ -725,21 +725,21 @@ async def model_status(model: str, req: Request):
     logger.warning(f"Model not found: {model} from IP: {client_ip}")
     raise HTTPException(status_code=404, detail="Model not found")
 
-# Health Check Endpoint
+# Endpoint: GET /v1/health
 @app.get("/v1/health", dependencies=[Depends(rate_limiter_per_ip)])
 async def health_check(req: Request):
     client_ip = req.client.host
     logger.info(f"Health check requested from IP: {client_ip}")
     return {"status": "ok"}
 
-# Redirect GET requests to /v1/chat/completions to 'about:blank'
+# Endpoint: GET /v1/chat/completions (GET method)
 @app.get("/v1/chat/completions")
 async def chat_completions_get(req: Request):
     client_ip = req.client.host
     logger.info(f"GET request made to /v1/chat/completions from IP: {client_ip}, redirecting to 'about:blank'")
     return RedirectResponse(url='about:blank')
 
-# Custom Exception Handler
+# Custom exception handler to match OpenAI's error format
 @app.exception_handler(HTTPException)
 async def http_exception_handler(request: Request, exc: HTTPException):
     client_ip = request.client.host
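Note on the change at old line 613: `request.messages` is a list of Pydantic `Message` models (declared under "Request Models" above), not dicts, so the old `msg['content']` lookup raises a `TypeError` as soon as a streamed response finishes; attribute access is the correct form. A minimal sketch of the failure mode, assuming `Message` also declares a `content: str` field next to the `role: str` shown in the diff:

```python
from pydantic import BaseModel

class Message(BaseModel):
    role: str
    content: str  # assumed field; the diff only shows `role: str`

messages = [Message(role="user", content="hello world")]

# Old form: a BaseModel instance is not subscriptable, so this raises
# TypeError: 'Message' object is not subscriptable.
# prompt_tokens = sum(len(msg['content'].split()) for msg in messages)

# New form from the commit: attribute access works on BaseModel instances.
prompt_tokens = sum(len(msg.content.split()) for msg in messages)
print(prompt_tokens)  # 2
```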
 
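The relocated startup hook still uses `@app.on_event("startup")`, which recent FastAPI releases deprecate in favor of a lifespan context manager. Not required for this commit, but a hedged sketch of the equivalent wiring, with a stand-in for the `cleanup_rate_limit_stores` coroutine that main.py defines elsewhere:

```python
import asyncio
from contextlib import asynccontextmanager

from fastapi import FastAPI

async def cleanup_rate_limit_stores():
    # Stand-in for the coroutine defined elsewhere in main.py.
    while True:
        await asyncio.sleep(60)

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Start the background cleanup task on startup...
    task = asyncio.create_task(cleanup_rate_limit_stores())
    yield
    # ...and cancel it on shutdown, which @app.on_event("startup") never does.
    task.cancel()

app = FastAPI(lifespan=lifespan)
```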
 
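One nuance in the moved middleware: the strict `content_type != "application/json"` comparison also rejects requests carrying a valid `Content-Type: application/json; charset=utf-8`. If that matters, the check could compare only the media type; a sketch with a hypothetical helper:

```python
from typing import Optional

def is_json_content_type(header_value: Optional[str]) -> bool:
    """Hypothetical helper: accept application/json with or without parameters."""
    if header_value is None:
        return False
    # Strip parameters such as "; charset=utf-8" before comparing.
    media_type = header_value.split(";", 1)[0].strip().lower()
    return media_type == "application/json"

assert is_json_content_type("application/json")
assert is_json_content_type("application/json; charset=utf-8")
assert not is_json_content_type("text/plain")
```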
 
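Lastly, both the `prompt_tokens` arithmetic and the `/v1/tokenizer` endpoint count whitespace-separated words, which only approximates model tokens. If OpenAI-comparable counts are wanted, a real tokenizer could be swapped in; a sketch assuming the `tiktoken` package is installed and `cl100k_base` is an acceptable encoding:

```python
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")

def count_tokens(text: str) -> int:
    # BPE token count instead of len(text.split()).
    return len(enc.encode(text))

print(count_tokens("hello world"))  # 2
```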