File size: 6,013 Bytes
ca56e6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import json
from threading import Lock
from typing import (
    Optional,
    Union,
    Iterator,
    Dict,
    Any,
    AsyncIterator,
)

import anyio
from anyio.streams.memory import MemoryObjectSendStream
from fastapi import Depends, HTTPException, Request
from fastapi.responses import JSONResponse
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
from loguru import logger
from pydantic import BaseModel
from starlette.concurrency import iterate_in_threadpool

from api.config import SETTINGS
from api.utils.compat import model_json, model_dump
from api.utils.constants import ErrorCode
from api.utils.protocol import (
    ChatCompletionCreateParams,
    CompletionCreateParams,
    ErrorResponse,
)

# Two independent module-wide threading locks.
# `llama_outer_lock` is consulted by get_event_publisher: when
# SETTINGS.interrupt_requests is enabled and this lock is held, the current
# stream is ended early. `llama_inner_lock` is not used in this section.
# NOTE(review): presumably both serialize access to a single llama.cpp model
# across requests — confirm against the code that acquires them.
llama_outer_lock, llama_inner_lock = Lock(), Lock()


async def check_api_key(
    auth: Optional[HTTPAuthorizationCredentials] = Depends(HTTPBearer(auto_error=False)),
) -> Optional[str]:
    """Validate the request's bearer token against ``SETTINGS.api_keys``.

    Intended as a FastAPI dependency. With ``HTTPBearer(auto_error=False)``,
    ``auth`` is ``None`` when no bearer credentials are supplied instead of the
    security scheme raising on its own. Both the missing-token and the
    unknown-token cases are therefore reported by the 401 below.

    Returns:
        ``None`` when no API keys are configured (authentication disabled),
        otherwise the accepted token string.

    Raises:
        HTTPException: 401 with an OpenAI-style ``invalid_api_key`` error body
            when keys are configured and the token is missing or not listed.
    """
    if not SETTINGS.api_keys:
        # api_keys not set; allow all
        return None

    token = auth.credentials if auth is not None else None
    if token is None or token not in SETTINGS.api_keys:
        raise HTTPException(
            status_code=401,
            detail={
                "error": {
                    # Non-empty so clients can tell why the request was rejected.
                    "message": "Incorrect API key provided.",
                    "type": "invalid_request_error",
                    "param": None,
                    "code": "invalid_api_key",
                }
            },
        )
    return token


def create_error_response(code: int, message: str, status_code: int = 500) -> JSONResponse:
    """Build a JSON error response whose body is a serialized ``ErrorResponse``.

    Args:
        code: Application-level error code placed in the body (e.g. an
            ``ErrorCode`` member). It is not used as the HTTP status.
        message: Human-readable description of the error.
        status_code: HTTP status of the response. Defaults to 500, the value this
            helper previously hard-coded, so existing callers are unaffected.
            Pass e.g. 400 for client-side validation errors.

    Returns:
        A ``JSONResponse`` carrying ``ErrorResponse(message=message, code=code)``.
    """
    return JSONResponse(
        model_dump(ErrorResponse(message=message, code=code)),
        status_code=status_code,
    )


async def handle_request(
    request: Union[CompletionCreateParams, ChatCompletionCreateParams],
    stop: Optional[Dict[str, Any]] = None,
    chat: bool = True,
) -> Union[CompletionCreateParams, ChatCompletionCreateParams, JSONResponse]:
    """Validate a completion request and normalise its stop/sampling settings.

    The request object is modified in place.

    Args:
        request: Parsed completion or chat-completion request.
        stop: Optional extra stop configuration. ``stop["strings"]`` holds stop
            strings and ``stop["token_ids"]`` stop token ids. Missing keys and
            ``None`` values are treated as empty.
        chat: Whether ``request`` is a chat request. Only chat requests can get
            the extra ``"Observation:"`` stop string for Qwen models.

    Returns:
        The same ``request`` after normalisation, or the ``JSONResponse`` error
        returned by ``check_requests`` when a parameter is invalid.
    """
    error_check_ret = check_requests(request)
    if error_check_ret is not None:
        return error_check_ret

    # stop settings; `or []` also covers keys that are present with a None value
    _stop, _stop_token_ids = [], []
    if stop is not None:
        _stop_token_ids = stop.get("token_ids") or []
        _stop = stop.get("strings") or []

    # check_requests allows `stop` to be None, a single string, or a list.
    request.stop = request.stop or []
    if isinstance(request.stop, str):
        request.stop = [request.stop]

    if chat and ("qwen" in SETTINGS.model_name.lower() and request.functions):
        # NOTE(review): presumably the marker Qwen's function-calling prompt uses
        # before tool output, so generation halts for the tool call — confirm.
        request.stop.append("Observation:")

    # Merge and de-duplicate. dict.fromkeys keeps first-seen order, so the
    # result is deterministic (set iteration order of str varies between runs).
    request.stop = list(dict.fromkeys(_stop + request.stop))
    request.stop_token_ids = request.stop_token_ids or []
    request.stop_token_ids = list(dict.fromkeys(_stop_token_ids + request.stop_token_ids))

    # check_requests accepts None for top_p/temperature, so only normalise
    # values that are actually set instead of failing with a TypeError.
    if request.top_p is not None:
        # Clamp away from 0. NOTE(review): presumably some backends reject
        # top_p == 0 — confirm.
        request.top_p = max(request.top_p, 1e-5)
    if request.temperature is not None and request.temperature <= 1e-5:
        request.top_p = 1.0

    return request


def check_requests(request: Union[CompletionCreateParams, ChatCompletionCreateParams]) -> Optional[JSONResponse]:
    """Validate the generation parameters of a completion request.

    Each numeric parameter is only checked when it is not ``None``. The enforced
    bounds are:
    - ``max_tokens`` >= 1
    - ``n`` >= 1
    - 0 <= ``temperature`` <= 2
    - 0 <= ``top_p`` <= 1
    - ``stop`` must be ``None``, a string or a list.

    Returns:
        ``None`` if every check passes, otherwise the error response built by
        ``create_error_response`` for the first violated constraint.
    """
    # Check all params
    if request.max_tokens is not None and request.max_tokens <= 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.max_tokens} is less than the minimum of 1 - 'max_tokens'",
        )
    if request.n is not None and request.n <= 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.n} is less than the minimum of 1 - 'n'",
        )
    if request.temperature is not None and request.temperature < 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.temperature} is less than the minimum of 0 - 'temperature'",
        )
    if request.temperature is not None and request.temperature > 2:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.temperature} is greater than the maximum of 2 - 'temperature'",
        )
    if request.top_p is not None and request.top_p < 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.top_p} is less than the minimum of 0 - 'top_p'",
        )
    if request.top_p is not None and request.top_p > 1:
        # The message names the offending parameter (previously it said 'temperature').
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.top_p} is greater than the maximum of 1 - 'top_p'",
        )
    if request.stop is not None and not isinstance(request.stop, (str, list)):
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.stop} is not valid under any of the given schemas - 'stop'",
        )
    return None


async def get_event_publisher(
    request: Request,
    inner_send_chan: MemoryObjectSendStream,
    iterator: Union[Iterator, AsyncIterator],
) -> None:
    """Forward generated chunks from ``iterator`` into ``inner_send_chan``.

    Every chunk is sent as ``{"data": <payload>}``, and a stream that runs to
    completion ends with ``{"data": "[DONE]"}``. NOTE(review): presumably the
    receiving end of the channel feeds a server-sent-events response — confirm
    against the caller.

    Args:
        request: Incoming HTTP request. It is polled after each chunk to detect
            a client disconnect.
        inner_send_chan: Send side of an anyio memory object stream. The
            ``async with`` closes it when this coroutine exits, including on
            cancellation.
        iterator: Source of chunks. For engines other than ``vllm``/``tgi`` it
            is consumed through ``iterate_in_threadpool``, so it is expected to
            be a synchronous iterator. For ``vllm``/``tgi`` it is consumed with
            ``async for``.

    Raises:
        The anyio cancellation exception, in either of two cases:
        - the client disconnects;
        - ``SETTINGS.interrupt_requests`` is enabled and ``llama_outer_lock``
          is held (non-vllm/tgi engines only).
    """
    async with inner_send_chan:
        try:
            if SETTINGS.engine not in ["vllm", "tgi"]:
                # Blocking iterator: each next() runs in a worker thread so the
                # event loop is not blocked while the model generates.
                async for chunk in iterate_in_threadpool(iterator):
                    # Pydantic models and dicts are serialized to JSON (dicts keep
                    # non-ASCII characters). Anything else is sent unchanged.
                    if isinstance(chunk, BaseModel):
                        chunk = model_json(chunk)
                    elif isinstance(chunk, dict):
                        chunk = json.dumps(chunk, ensure_ascii=False)

                    await inner_send_chan.send(dict(data=chunk))

                    # Stop generating once the client has gone away.
                    if await request.is_disconnected():
                        raise anyio.get_cancelled_exc_class()()

                    # Interrupt mode: end this stream early, after a final "[DONE]",
                    # whenever llama_outer_lock is held. NOTE(review): presumably the
                    # lock is held by another request waiting for the model — confirm.
                    if SETTINGS.interrupt_requests and llama_outer_lock.locked():
                        await inner_send_chan.send(dict(data="[DONE]"))
                        raise anyio.get_cancelled_exc_class()()
            else:
                # Async engines. NOTE(review): unlike the branch above, every chunk
                # goes through model_json, so chunks must be pydantic models here.
                async for chunk in iterator:
                    chunk = model_json(chunk)
                    await inner_send_chan.send(dict(data=chunk))
                    if await request.is_disconnected():
                        raise anyio.get_cancelled_exc_class()()
            await inner_send_chan.send(dict(data="[DONE]"))
        except anyio.get_cancelled_exc_class() as e:
            # Reached on client disconnect, on external cancellation and on the
            # interrupt path above. NOTE(review): all of them are logged as a
            # client disconnect. The iterator is not explicitly closed here.
            logger.info("disconnected")
            # Shielded (up to 1 second) so the logging itself is not cancelled;
            # the cancellation is then re-raised to the caller.
            with anyio.move_on_after(1, shield=True):
                logger.info(f"Disconnected from client (via refresh/close) {request.client}")
                raise e