import time
import traceback
import uuid
from functools import partial
from typing import (
    Dict,
    Any,
    AsyncIterator,
    Optional,
)
import anyio
from fastapi import APIRouter, Depends
from fastapi import HTTPException, Request
from loguru import logger
from openai.types.chat import (
    ChatCompletionMessage,
    ChatCompletion,
    ChatCompletionChunk,
)
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from openai.types.chat.chat_completion_message import FunctionCall
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall
from openai.types.completion_usage import CompletionUsage
from sse_starlette import EventSourceResponse
from vllm.outputs import RequestOutput
from api.core.vllm_engine import VllmEngine
from api.models import GENERATE_ENGINE
from api.utils.compat import model_dump, model_parse
from api.utils.protocol import Role, ChatCompletionCreateParams
from api.utils.request import (
    check_api_key,
    handle_request,
    get_event_publisher,
)

chat_router = APIRouter(prefix="/chat")


def get_engine():
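    """Dependency that yields the shared generation engine."""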
    yield GENERATE_ENGINE


@chat_router.post("/completions", dependencies=[Depends(check_api_key)])
async def create_chat_completion(
    request: ChatCompletionCreateParams,
    raw_request: Request,
    engine: VllmEngine = Depends(get_engine),
):
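    """Create a chat completion for the supplied messages (OpenAI-compatible endpoint)."""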
    if (not request.messages) or request.messages[-1]["role"] == Role.ASSISTANT:
        raise HTTPException(status_code=400, detail="Invalid request")
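
    # Normalize the request (stop words, prompt adaptation) and fill in a default max_tokens.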
    request = await handle_request(request, engine.prompt_adapter.stop)
    request.max_tokens = request.max_tokens or 512

    params = model_dump(request, exclude={"messages"})
    params.update(dict(prompt_or_messages=request.messages, echo=False))
    logger.debug(f"==== request ====\n{params}")

    request_id: str = f"chatcmpl-{str(uuid.uuid4())}"
    generator = engine.generate(params, request_id)
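
    # Either stream chunks over SSE or collect the full result into a single JSON response.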
    if request.stream:
        iterator = create_chat_completion_stream(generator, params, request_id)
        send_chan, recv_chan = anyio.create_memory_object_stream(10)
        return EventSourceResponse(
            recv_chan,
            data_sender_callable=partial(
                get_event_publisher,
                request=raw_request,
                inner_send_chan=send_chan,
                iterator=iterator,
            ),
        )
    else:
        # Non-streaming response
        final_res: Optional[RequestOutput] = None
        async for res in generator:
            if raw_request is not None and await raw_request.is_disconnected():
                await engine.model.abort(request_id)
                return
            final_res = res

        assert final_res is not None
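
        # Build one Choice per sampled output, parsing any function/tool call from the text.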
        choices = []
        functions = params.get("functions", None)
        tools = params.get("tools", None)
        for output in final_res.outputs:
            output.text = output.text.replace("�", "")

            finish_reason = output.finish_reason
            function_call = None
            if functions or tools:
                try:
                    res, function_call = engine.prompt_adapter.parse_assistant_response(
                        output.text, functions, tools,
                    )
                    output.text = res
                except Exception as e:
                    traceback.print_exc()
                    logger.warning("Failed to parse tool call")
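
            # A parsed dict with "arguments" maps to the legacy function_call field;
            # one with "function" maps to an OpenAI tool call.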
            if isinstance(function_call, dict) and "arguments" in function_call:
                function_call = FunctionCall(**function_call)
                message = ChatCompletionMessage(
                    role="assistant",
                    content=output.text,
                    function_call=function_call,
                )
                finish_reason = "function_call"
            elif isinstance(function_call, dict) and "function" in function_call:
                finish_reason = "tool_calls"
                tool_calls = [model_parse(ChatCompletionMessageToolCall, function_call)]
                message = ChatCompletionMessage(
                    role="assistant",
                    content=output.text,
                    tool_calls=tool_calls,
                )
            else:
                message = ChatCompletionMessage(role="assistant", content=output.text)

            choices.append(
                Choice(
                    index=output.index,
                    message=message,
                    finish_reason=finish_reason,
                )
            )
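
        # Usage is computed from vLLM's prompt and generated token ids.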
        num_prompt_tokens = len(final_res.prompt_token_ids)
        num_generated_tokens = sum(len(output.token_ids) for output in final_res.outputs)
        usage = CompletionUsage(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )
        return ChatCompletion(
            id=request_id,
            choices=choices,
            created=int(time.time()),
            model=request.model,
            object="chat.completion",
            usage=usage,
        )


async def create_chat_completion_stream(
    generator: AsyncIterator, params: Dict[str, Any], request_id: str
) -> AsyncIterator:
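    """Yield OpenAI-style chat.completion.chunk objects from a vLLM output stream."""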
    n = params.get("n", 1)
    for i in range(n):
        # First chunk with role
        choice = ChunkChoice(
            index=i,
            delta=ChoiceDelta(role="assistant", content=""),
            finish_reason=None,
            logprobs=None,
        )
        yield ChatCompletionChunk(
            id=request_id,
            choices=[choice],
            created=int(time.time()),
            model=params.get("model", "llm"),
            object="chat.completion.chunk",
        )

    previous_texts = [""] * n
    previous_num_tokens = [0] * n
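    # For each engine update, emit only the newly generated text as a delta chunk.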
    async for res in generator:
        res: RequestOutput
        for output in res.outputs:
            i = output.index
            output.text = output.text.replace("�", "")

            delta_text = output.text[len(previous_texts[i]):]
            previous_texts[i] = output.text
            previous_num_tokens[i] = len(output.token_ids)

            choice = ChunkChoice(
                index=i,
                delta=ChoiceDelta(content=delta_text),
                finish_reason=output.finish_reason,
                logprobs=None,
            )
            yield ChatCompletionChunk(
                id=request_id,
                choices=[choice],
                created=int(time.time()),
                model=params.get("model", "llm"),
                object="chat.completion.chunk",
            )
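
            # Close this choice with an explicit finish chunk once generation ends.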
            if output.finish_reason is not None:
                choice = ChunkChoice(
                    index=i,
                    delta=ChoiceDelta(),
                    finish_reason="stop",
                    logprobs=None,
                )
                yield ChatCompletionChunk(
                    id=request_id,
                    choices=[choice],
                    created=int(time.time()),
                    model=params.get("model", "llm"),
                    object="chat.completion.chunk",
                )