import json
from threading import Lock
from typing import (
Optional,
Union,
Iterator,
Dict,
Any,
AsyncIterator,
)
import anyio
from anyio.streams.memory import MemoryObjectSendStream
from fastapi import Depends, HTTPException, Request
from fastapi.responses import JSONResponse
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
from loguru import logger
from pydantic import BaseModel
from starlette.concurrency import iterate_in_threadpool
from api.config import SETTINGS
from api.utils.compat import model_json, model_dump
from api.utils.constants import ErrorCode
from api.utils.protocol import (
ChatCompletionCreateParams,
CompletionCreateParams,
ErrorResponse,
)
# Locks coordinating access to a shared llama.cpp-style engine.
# get_event_publisher polls llama_outer_lock.locked() to detect that another
# request is waiting and, when SETTINGS.interrupt_requests is set, ends the
# current stream early. llama_inner_lock is not referenced in this file —
# presumably acquired by the generation code path; verify against callers.
llama_outer_lock = Lock()
llama_inner_lock = Lock()
async def check_api_key(
    auth: Optional[HTTPAuthorizationCredentials] = Depends(HTTPBearer(auto_error=False)),
) -> Optional[str]:
    """FastAPI dependency validating the request's bearer token.

    Returns:
        None when ``SETTINGS.api_keys`` is unset (authentication disabled),
        otherwise the validated token string.

    Raises:
        HTTPException: 401 with an OpenAI-style error body when the
            Authorization header is missing or the token is unknown.
    """
    if not SETTINGS.api_keys:
        # api_keys not set; allow all requests
        return None
    if auth is None or (token := auth.credentials) not in SETTINGS.api_keys:
        raise HTTPException(
            status_code=401,
            detail={
                "error": {
                    # Fix: message was previously an empty string, giving
                    # clients no indication of what went wrong.
                    "message": "Invalid API key provided.",
                    "type": "invalid_request_error",
                    "param": None,
                    "code": "invalid_api_key",
                }
            },
        )
    return token
def create_error_response(code: int, message: str) -> JSONResponse:
    """Wrap *code* and *message* in an ErrorResponse JSON body.

    The HTTP status is always 500; the application-level error code
    travels inside the payload.
    """
    payload = ErrorResponse(message=message, code=code)
    return JSONResponse(content=model_dump(payload), status_code=500)
async def handle_request(
    request: Union[CompletionCreateParams, ChatCompletionCreateParams],
    stop: Optional[Dict[str, Any]] = None,  # fix: default is None, so annotate Optional
    chat: bool = True,
) -> Union[Union[CompletionCreateParams, ChatCompletionCreateParams], JSONResponse]:
    """Validate and normalize a (chat) completion request in place.

    Merges model-template stop strings/token ids (from *stop*, keys
    ``"strings"`` and ``"token_ids"``) into the request, dedupes them,
    and clamps the sampling parameters.

    Returns:
        The mutated request on success, or a JSONResponse describing the
        first validation error found by check_requests.
    """
    error_check_ret = check_requests(request)
    if error_check_ret is not None:
        return error_check_ret

    # stop settings supplied by the model template
    _stop, _stop_token_ids = [], []
    if stop is not None:
        _stop_token_ids = stop.get("token_ids", [])
        _stop = stop.get("strings", [])

    request.stop = request.stop or []
    if isinstance(request.stop, str):
        request.stop = [request.stop]

    if chat and ("qwen" in SETTINGS.model_name.lower() and request.functions):
        # Qwen's function-calling template emits "Observation:" where the
        # tool result should be injected, so generation must stop there.
        request.stop.append("Observation:")

    request.stop = list(set(_stop + request.stop))
    request.stop_token_ids = request.stop_token_ids or []
    request.stop_token_ids = list(set(_stop_token_ids + request.stop_token_ids))

    # Clamp sampling parameters. check_requests allows top_p/temperature to be
    # None; previously that raised a TypeError in max()/comparison here.
    if request.top_p is None:
        request.top_p = 1.0
    else:
        request.top_p = max(request.top_p, 1e-5)
    if request.temperature is not None and request.temperature <= 1e-5:
        # near-greedy decoding: nucleus sampling adds nothing, force top_p=1
        request.top_p = 1.0
    return request
def check_requests(request: Union[CompletionCreateParams, ChatCompletionCreateParams]) -> Optional[JSONResponse]:
# Check all params
if request.max_tokens is not None and request.max_tokens <= 0:
return create_error_response(
ErrorCode.PARAM_OUT_OF_RANGE,
f"{request.max_tokens} is less than the minimum of 1 - 'max_tokens'",
)
if request.n is not None and request.n <= 0:
return create_error_response(
ErrorCode.PARAM_OUT_OF_RANGE,
f"{request.n} is less than the minimum of 1 - 'n'",
)
if request.temperature is not None and request.temperature < 0:
return create_error_response(
ErrorCode.PARAM_OUT_OF_RANGE,
f"{request.temperature} is less than the minimum of 0 - 'temperature'",
)
if request.temperature is not None and request.temperature > 2:
return create_error_response(
ErrorCode.PARAM_OUT_OF_RANGE,
f"{request.temperature} is greater than the maximum of 2 - 'temperature'",
)
if request.top_p is not None and request.top_p < 0:
return create_error_response(
ErrorCode.PARAM_OUT_OF_RANGE,
f"{request.top_p} is less than the minimum of 0 - 'top_p'",
)
if request.top_p is not None and request.top_p > 1:
return create_error_response(
ErrorCode.PARAM_OUT_OF_RANGE,
f"{request.top_p} is greater than the maximum of 1 - 'temperature'",
)
if request.stop is None or isinstance(request.stop, (str, list)):
return None
else:
return create_error_response(
ErrorCode.PARAM_OUT_OF_RANGE,
f"{request.stop} is not valid under any of the given schemas - 'stop'",
)
async def get_event_publisher(
    request: Request,
    inner_send_chan: MemoryObjectSendStream,
    iterator: Union[Iterator, AsyncIterator],
):
    """Pump generation chunks from *iterator* onto *inner_send_chan*.

    Each chunk is serialized to a JSON string and sent as ``{"data": ...}``;
    the stream ends with a ``{"data": "[DONE]"}`` sentinel. Client
    disconnects cancel the producer via anyio's cancellation exception.
    """
    async with inner_send_chan:
        try:
            if SETTINGS.engine not in ["vllm", "tgi"]:
                # Sync iterator path: drive it in a worker thread so the
                # event loop is not blocked while the model generates.
                async for chunk in iterate_in_threadpool(iterator):
                    if isinstance(chunk, BaseModel):
                        chunk = model_json(chunk)
                    elif isinstance(chunk, dict):
                        chunk = json.dumps(chunk, ensure_ascii=False)
                    await inner_send_chan.send(dict(data=chunk))
                    if await request.is_disconnected():
                        # Client went away mid-stream; cancel this task.
                        raise anyio.get_cancelled_exc_class()()
                    if SETTINGS.interrupt_requests and llama_outer_lock.locked():
                        # Another request is waiting on the shared engine lock:
                        # terminate this stream early so it can proceed.
                        await inner_send_chan.send(dict(data="[DONE]"))
                        raise anyio.get_cancelled_exc_class()()
            else:
                # vllm/tgi yield chunks from a native async iterator.
                async for chunk in iterator:
                    chunk = model_json(chunk)
                    await inner_send_chan.send(dict(data=chunk))
                    if await request.is_disconnected():
                        raise anyio.get_cancelled_exc_class()()
            await inner_send_chan.send(dict(data="[DONE]"))
        except anyio.get_cancelled_exc_class() as e:
            logger.info("disconnected")
            # Shielded so the log line still runs while cancellation propagates.
            with anyio.move_on_after(1, shield=True):
                logger.info(f"Disconnected from client (via refresh/close) {request.client}")
            raise e