Niansuh commited on
Commit
3cfd9e7
·
verified ·
1 Parent(s): db33061

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +93 -11
main.py CHANGED
@@ -27,11 +27,18 @@ logger = logging.getLogger(__name__)
27
  # Load environment variables
28
  API_KEYS = os.getenv('API_KEYS', '').split(',') # Comma-separated API keys
29
  RATE_LIMIT = int(os.getenv('RATE_LIMIT', '60')) # Requests per minute
 
30
 
31
  if not API_KEYS or API_KEYS == ['']:
32
  logger.error("No API keys found. Please set the API_KEYS environment variable.")
33
  raise Exception("API_KEYS environment variable not set.")
34
 
 
 
 
 
 
 
35
  # Simple in-memory rate limiter
36
  rate_limit_store = defaultdict(lambda: {"count": 0, "timestamp": time.time()})
37
  ip_rate_limit_store = defaultdict(lambda: {"count": 0, "timestamp": time.time()})
@@ -122,6 +129,10 @@ class Blackbox:
122
  'Niansuh',
123
  ]
124
 
 
 
 
 
125
  agentMode = {
126
  'ImageGeneration': {'mode': True, 'id': "ImageGenerationLV45LJp", 'name': "Image Generation"},
127
  'Niansuh': {'mode': True, 'id': "NiansuhAIk1HgESy", 'name': "Niansuh"},
@@ -194,12 +205,12 @@ class Blackbox:
194
  def get_model(cls, model: str) -> str:
195
  if model in cls.models:
196
  return model
197
- elif model in cls.userSelectedModel:
198
  return model
199
- elif model in cls.model_aliases:
200
  return cls.model_aliases[model]
201
  else:
202
- return cls.default_model
203
 
204
  @classmethod
205
  async def create_async_generator(
@@ -213,6 +224,10 @@ class Blackbox:
213
  **kwargs
214
  ) -> AsyncGenerator[Any, None]:
215
  model = cls.get_model(model)
 
 
 
 
216
  logger.info(f"Selected model: {model}")
217
 
218
  if not cls.working or model not in cls.models:
@@ -477,23 +492,23 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = De
477
  logger.exception("An unexpected error occurred while processing the chat completions request.")
478
  raise HTTPException(status_code=500, detail=str(e))
479
 
480
- @app.get("/v1/models", dependencies=[Depends(rate_limiter)])
481
- async def get_models(api_key: str = Depends(get_api_key)):
482
- logger.info(f"Fetching available models for API key: {api_key}")
483
  return {"data": [{"id": model, "object": "model"} for model in Blackbox.models]}
484
 
485
  # Additional endpoints for better functionality
486
  @app.get("/v1/health", dependencies=[Depends(rate_limiter)])
487
- async def health_check(api_key: str = Depends(get_api_key)):
488
  logger.info(f"Health check requested by API key: {api_key}")
489
  return {"status": "ok"}
490
 
491
- @app.get("/v1/models/{model}/status", dependencies=[Depends(rate_limiter)])
492
- async def model_status(model: str, api_key: str = Depends(get_api_key)):
493
- logger.info(f"Model status requested for '{model}' by API key: {api_key}")
494
  if model in Blackbox.models:
495
  return {"model": model, "status": "available"}
496
- elif model in Blackbox.model_aliases:
497
  actual_model = Blackbox.model_aliases[model]
498
  return {"model": actual_model, "status": "available via alias"}
499
  else:
@@ -515,6 +530,73 @@ async def http_exception_handler(request: Request, exc: HTTPException):
515
  },
516
  )
517
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
518
  if __name__ == "__main__":
519
  import uvicorn
520
  uvicorn.run(app, host="0.0.0.0", port=8000)
 
27
  # Load environment variables
28
  API_KEYS = os.getenv('API_KEYS', '').split(',') # Comma-separated API keys
29
  RATE_LIMIT = int(os.getenv('RATE_LIMIT', '60')) # Requests per minute
30
+ AVAILABLE_MODELS = os.getenv('AVAILABLE_MODELS', '') # Comma-separated available models
31
 
32
  if not API_KEYS or API_KEYS == ['']:
33
  logger.error("No API keys found. Please set the API_KEYS environment variable.")
34
  raise Exception("API_KEYS environment variable not set.")
35
 
36
+ # Process available models
37
+ if AVAILABLE_MODELS:
38
+ AVAILABLE_MODELS = [model.strip() for model in AVAILABLE_MODELS.split(',') if model.strip()]
39
+ else:
40
+ AVAILABLE_MODELS = [] # If empty, all models are available
41
+
42
  # Simple in-memory rate limiter
43
  rate_limit_store = defaultdict(lambda: {"count": 0, "timestamp": time.time()})
44
  ip_rate_limit_store = defaultdict(lambda: {"count": 0, "timestamp": time.time()})
 
129
  'Niansuh',
130
  ]
131
 
132
+ # Filter models based on AVAILABLE_MODELS
133
+ if AVAILABLE_MODELS:
134
+ models = [model for model in models if model in AVAILABLE_MODELS]
135
+
136
  agentMode = {
137
  'ImageGeneration': {'mode': True, 'id': "ImageGenerationLV45LJp", 'name': "Image Generation"},
138
  'Niansuh': {'mode': True, 'id': "NiansuhAIk1HgESy", 'name': "Niansuh"},
 
205
  def get_model(cls, model: str) -> str:
206
  if model in cls.models:
207
  return model
208
+ elif model in cls.userSelectedModel and cls.userSelectedModel[model] in cls.models:
209
  return model
210
+ elif model in cls.model_aliases and cls.model_aliases[model] in cls.models:
211
  return cls.model_aliases[model]
212
  else:
213
+ return cls.default_model if cls.default_model in cls.models else None
214
 
215
  @classmethod
216
  async def create_async_generator(
 
224
  **kwargs
225
  ) -> AsyncGenerator[Any, None]:
226
  model = cls.get_model(model)
227
+ if model is None:
228
+ logger.error(f"Model {model} is not available.")
229
+ raise ModelNotWorkingException(model)
230
+
231
  logger.info(f"Selected model: {model}")
232
 
233
  if not cls.working or model not in cls.models:
 
492
  logger.exception("An unexpected error occurred while processing the chat completions request.")
493
  raise HTTPException(status_code=500, detail=str(e))
494
 
495
+ @app.get("/v1/models")
496
+ async def get_models():
497
+ logger.info("Fetching available models")
498
  return {"data": [{"id": model, "object": "model"} for model in Blackbox.models]}
499
 
500
  # Additional endpoints for better functionality
501
  @app.get("/v1/health", dependencies=[Depends(rate_limiter)])
502
+ async def health_check(req: Request, api_key: str = Depends(get_api_key)):
503
  logger.info(f"Health check requested by API key: {api_key}")
504
  return {"status": "ok"}
505
 
506
+ @app.get("/v1/models/{model}/status")
507
+ async def model_status(model: str):
508
+ logger.info(f"Model status requested for '{model}'")
509
  if model in Blackbox.models:
510
  return {"model": model, "status": "available"}
511
+ elif model in Blackbox.model_aliases and Blackbox.model_aliases[model] in Blackbox.models:
512
  actual_model = Blackbox.model_aliases[model]
513
  return {"model": actual_model, "status": "available via alias"}
514
  else:
 
530
  },
531
  )
532
 
533
+ # New endpoint: /v1/tokenizer to calculate token counts
534
+ class TokenizerRequest(BaseModel):
535
+ text: str
536
+
537
+ @app.post("/v1/tokenizer")
538
+ async def tokenizer(request: TokenizerRequest):
539
+ text = request.text
540
+ token_count = len(text.split())
541
+ return {"text": text, "tokens": token_count}
542
+
543
+ # New endpoint: /v1/completions to support text completions
544
+ class CompletionRequest(BaseModel):
545
+ model: str
546
+ prompt: str
547
+ max_tokens: Optional[int] = 16
548
+ temperature: Optional[float] = 1.0
549
+ top_p: Optional[float] = 1.0
550
+ n: Optional[int] = 1
551
+ stream: Optional[bool] = False
552
+ stop: Optional[Union[str, List[str]]] = None
553
+ logprobs: Optional[int] = None
554
+ echo: Optional[bool] = False
555
+ presence_penalty: Optional[float] = 0.0
556
+ frequency_penalty: Optional[float] = 0.0
557
+ best_of: Optional[int] = 1
558
+ logit_bias: Optional[Dict[str, float]] = None
559
+ user: Optional[str] = None
560
+
561
+ @app.post("/v1/completions", dependencies=[Depends(rate_limiter)])
562
+ async def completions(request: CompletionRequest, req: Request, api_key: str = Depends(get_api_key)):
563
+ logger.info(f"Received completion request from API key: {api_key} | Model: {request.model}")
564
+
565
+ try:
566
+ # Validate that the requested model is available
567
+ if request.model not in Blackbox.models and request.model not in Blackbox.model_aliases:
568
+ logger.warning(f"Attempt to use unavailable model: {request.model}")
569
+ raise HTTPException(status_code=400, detail="Requested model is not available.")
570
+
571
+ # Simulate a simple completion by echoing the prompt
572
+ completion_text = f"{request.prompt} [Completed by {request.model}]"
573
+
574
+ return {
575
+ "id": f"cmpl-{uuid.uuid4()}",
576
+ "object": "text_completion",
577
+ "created": int(datetime.now().timestamp()),
578
+ "model": request.model,
579
+ "choices": [
580
+ {
581
+ "text": completion_text,
582
+ "index": 0,
583
+ "logprobs": None,
584
+ "finish_reason": "length"
585
+ }
586
+ ],
587
+ "usage": {
588
+ "prompt_tokens": len(request.prompt.split()),
589
+ "completion_tokens": len(completion_text.split()),
590
+ "total_tokens": len(request.prompt.split()) + len(completion_text.split())
591
+ }
592
+ }
593
+ except HTTPException as he:
594
+ logger.warning(f"HTTPException: {he.detail}")
595
+ raise he
596
+ except Exception as e:
597
+ logger.exception("An unexpected error occurred while processing the completions request.")
598
+ raise HTTPException(status_code=500, detail=str(e))
599
+
600
  if __name__ == "__main__":
601
  import uvicorn
602
  uvicorn.run(app, host="0.0.0.0", port=8000)