import json
from threading import Lock
from typing import (
    Any,
    AsyncIterator,
    Dict,
    Iterator,
    Optional,
    Union,
)

import anyio
from anyio.streams.memory import MemoryObjectSendStream
from fastapi import Depends, HTTPException, Request
from fastapi.responses import JSONResponse
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
from loguru import logger
from pydantic import BaseModel
from starlette.concurrency import iterate_in_threadpool

from api.config import SETTINGS
from api.utils.compat import model_json, model_dump
from api.utils.constants import ErrorCode
from api.utils.protocol import (
    ChatCompletionCreateParams,
    CompletionCreateParams,
    ErrorResponse,
)

# Locks coordinating access to the underlying llama engine.
# ``llama_outer_lock`` doubles as the "interrupt pending" signal checked by
# streaming responses in ``get_event_publisher`` (via SETTINGS.interrupt_requests).
llama_outer_lock = Lock()
llama_inner_lock = Lock()


async def check_api_key(
    auth: Optional[HTTPAuthorizationCredentials] = Depends(HTTPBearer(auto_error=False)),
) -> Optional[str]:
    """FastAPI dependency that validates the Bearer token against ``SETTINGS.api_keys``.

    Returns:
        The validated token string, or ``None`` when no API keys are configured
        (authentication disabled — all requests allowed).

    Raises:
        HTTPException: 401 with an OpenAI-style error body when the token is
            missing or not in the configured key list.
    """
    if not SETTINGS.api_keys:
        # api_keys not set; allow all
        return None
    if auth is None or (token := auth.credentials) not in SETTINGS.api_keys:
        raise HTTPException(
            status_code=401,
            detail={
                "error": {
                    "message": "",
                    "type": "invalid_request_error",
                    "param": None,
                    "code": "invalid_api_key",
                }
            },
        )
    return token


def create_error_response(code: int, message: str) -> JSONResponse:
    """Build an OpenAI-style error payload.

    NOTE(review): the HTTP status is always 500 while the domain-specific
    ``code`` travels in the response body; presumably this mirrors the
    upstream convention — confirm before changing.
    """
    return JSONResponse(model_dump(ErrorResponse(message=message, code=code)), status_code=500)


async def handle_request(
    request: Union[CompletionCreateParams, ChatCompletionCreateParams],
    stop: Optional[Dict[str, Any]] = None,
    chat: bool = True,
) -> Union[Union[CompletionCreateParams, ChatCompletionCreateParams], JSONResponse]:
    """Validate and normalize a (chat) completion request in place.

    Merges model-level stop strings / stop token ids (``stop`` dict with
    ``"strings"`` and ``"token_ids"`` keys) into the request, clamps sampling
    parameters, and returns either the normalized request or a
    ``JSONResponse`` describing the validation error.

    Args:
        request: The incoming completion / chat-completion parameters (mutated).
        stop: Optional model defaults: ``{"strings": [...], "token_ids": [...]}``.
        chat: Whether this is a chat request (enables Qwen function-call handling).
    """
    error_check_ret = check_requests(request)
    if error_check_ret is not None:
        return error_check_ret

    # stop settings: collect model-level defaults to merge with user values
    _stop, _stop_token_ids = [], []
    if stop is not None:
        _stop_token_ids = stop.get("token_ids", [])
        _stop = stop.get("strings", [])

    request.stop = request.stop or []
    if isinstance(request.stop, str):
        request.stop = [request.stop]

    if chat and ("qwen" in SETTINGS.model_name.lower() and request.functions):
        # Qwen function calling emits "Observation:" when handing control back,
        # so treat it as a stop string.
        request.stop.append("Observation:")

    # De-duplicate merged stop lists (order is not preserved by set()).
    request.stop = list(set(_stop + request.stop))
    request.stop_token_ids = request.stop_token_ids or []
    request.stop_token_ids = list(set(_stop_token_ids + request.stop_token_ids))

    # assumes top_p/temperature are non-None here (pydantic defaults) — TODO confirm;
    # check_requests() permits None, which would make max()/<= raise TypeError.
    request.top_p = max(request.top_p, 1e-5)
    if request.temperature <= 1e-5:
        # Effectively greedy decoding: disable nucleus sampling entirely.
        request.top_p = 1.0

    return request


def check_requests(request: Union[CompletionCreateParams, ChatCompletionCreateParams]) -> Optional[JSONResponse]:
    """Range-check sampling parameters.

    Returns ``None`` when all parameters are valid, otherwise a
    ``JSONResponse`` built by :func:`create_error_response` with
    ``ErrorCode.PARAM_OUT_OF_RANGE``.
    """
    if request.max_tokens is not None and request.max_tokens <= 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.max_tokens} is less than the minimum of 1 - 'max_tokens'",
        )
    if request.n is not None and request.n <= 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.n} is less than the minimum of 1 - 'n'",
        )
    if request.temperature is not None and request.temperature < 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.temperature} is less than the minimum of 0 - 'temperature'",
        )
    if request.temperature is not None and request.temperature > 2:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.temperature} is greater than the maximum of 2 - 'temperature'",
        )
    if request.top_p is not None and request.top_p < 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.top_p} is less than the minimum of 0 - 'top_p'",
        )
    if request.top_p is not None and request.top_p > 1:
        # Fixed: message previously said 'temperature' (copy-paste error).
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.top_p} is greater than the maximum of 1 - 'top_p'",
        )
    if request.stop is None or isinstance(request.stop, (str, list)):
        return None
    else:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.stop} is not valid under any of the given schemas - 'stop'",
        )


async def get_event_publisher(
    request: Request,
    inner_send_chan: MemoryObjectSendStream,
    iterator: Union[Iterator, AsyncIterator],
):
    """Pump chunks from ``iterator`` into an SSE memory channel.

    For non-vllm/tgi engines the (synchronous) iterator is driven through
    ``iterate_in_threadpool``; otherwise the async iterator is consumed
    directly. Each chunk is serialized to JSON and sent as ``{"data": ...}``;
    a final ``[DONE]`` sentinel terminates the stream. Client disconnects
    (and, for local engines, a pending interrupt signalled via
    ``llama_outer_lock``) cancel the task.
    """
    async with inner_send_chan:
        try:
            if SETTINGS.engine not in ["vllm", "tgi"]:
                async for chunk in iterate_in_threadpool(iterator):
                    if isinstance(chunk, BaseModel):
                        chunk = model_json(chunk)
                    elif isinstance(chunk, dict):
                        chunk = json.dumps(chunk, ensure_ascii=False)

                    await inner_send_chan.send(dict(data=chunk))

                    if await request.is_disconnected():
                        raise anyio.get_cancelled_exc_class()()

                    if SETTINGS.interrupt_requests and llama_outer_lock.locked():
                        # Another request wants the model: finish this stream early.
                        await inner_send_chan.send(dict(data="[DONE]"))
                        raise anyio.get_cancelled_exc_class()()
            else:
                async for chunk in iterator:
                    chunk = model_json(chunk)
                    await inner_send_chan.send(dict(data=chunk))
                    if await request.is_disconnected():
                        raise anyio.get_cancelled_exc_class()()
            await inner_send_chan.send(dict(data="[DONE]"))
        except anyio.get_cancelled_exc_class() as e:
            logger.info("disconnected")
            # Shielded so the log line survives the ongoing cancellation.
            with anyio.move_on_after(1, shield=True):
                logger.info(f"Disconnected from client (via refresh/close) {request.client}")
            raise e