import time
import uuid
from functools import partial
from typing import (
List,
Dict,
Any,
AsyncIterator,
Optional,
)
import anyio
from fastapi import APIRouter, Depends
from fastapi import HTTPException, Request
from loguru import logger
from openai.types.completion import Completion
from openai.types.completion_choice import CompletionChoice, Logprobs
from openai.types.completion_usage import CompletionUsage
from sse_starlette import EventSourceResponse
from vllm.outputs import RequestOutput
from api.core.vllm_engine import VllmEngine
from api.models import GENERATE_ENGINE
from api.utils.compat import model_dump
from api.utils.protocol import CompletionCreateParams
from api.utils.request import (
handle_request,
get_event_publisher,
check_api_key
)
completion_router = APIRouter()
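
# FastAPI dependency that yields the globally constructed generation engine.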
def get_engine():
yield GENERATE_ENGINE
@completion_router.post("/completions", dependencies=[Depends(check_api_key)])
async def create_completion(
request: CompletionCreateParams,
raw_request: Request,
engine: VllmEngine = Depends(get_engine),
):
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/completions/create
for the API specification. This API mimics the OpenAI Completion API.
"""
if request.echo:
# We do not support echo since the vLLM engine does not
# currently support getting the logprobs of prompt tokens.
raise HTTPException(status_code=400, detail="echo is not currently supported")
if request.suffix:
# The language models we currently support do not support suffix.
raise HTTPException(status_code=400, detail="suffix is not currently supported")
request.max_tokens = request.max_tokens or 128
request = await handle_request(request, engine.prompt_adapter.stop, chat=False)
if isinstance(request.prompt, list):
request.prompt = request.prompt[0]
params = model_dump(request, exclude={"prompt"})
params.update(dict(prompt_or_messages=request.prompt))
logger.debug(f"==== request ====\n{params}")
request_id: str = f"cmpl-{str(uuid.uuid4())}"
generator = engine.generate(params, request_id)
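    # For streaming, bridge the async generator to SSE through an anyio memory channel:
    # get_event_publisher pumps completion chunks into send_chan while
    # EventSourceResponse reads them back from recv_chan.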
if request.stream:
iterator = create_completion_stream(generator, params, request_id, engine.tokenizer)
send_chan, recv_chan = anyio.create_memory_object_stream(10)
return EventSourceResponse(
recv_chan,
data_sender_callable=partial(
get_event_publisher,
request=raw_request,
inner_send_chan=send_chan,
iterator=iterator,
),
)
else:
# Non-streaming response
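        # Drain the generator; if the client disconnects mid-generation, abort the
        # in-flight vLLM request instead of finishing it.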
        final_res: Optional[RequestOutput] = None
async for res in generator:
if raw_request is not None and await raw_request.is_disconnected():
await engine.model.abort(request_id)
return
final_res = res
assert final_res is not None
choices = []
for output in final_res.outputs:
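            # Drop U+FFFD replacement characters produced by partially decoded tokens.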
output.text = output.text.replace("�", "")
logprobs = None
if params.get("logprobs", None) is not None:
logprobs = create_logprobs(engine.tokenizer, output.token_ids, output.logprobs)
choice = CompletionChoice(
index=output.index,
text=output.text,
finish_reason=output.finish_reason,
logprobs=logprobs,
)
choices.append(choice)
num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = sum(len(output.token_ids) for output in final_res.outputs)
usage = CompletionUsage(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
return Completion(
id=request_id,
choices=choices,
created=int(time.time()),
model=params.get("model", "llm"),
object="text_completion",
usage=usage,
)
def create_logprobs(
tokenizer,
token_ids: List[int],
top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None,
num_output_top_logprobs: Optional[int] = None,
initial_text_offset: int = 0,
) -> Logprobs:
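    """Build an OpenAI-style Logprobs object from token ids and per-step logprob dicts."""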
logprobs = Logprobs(text_offset=[], token_logprobs=[], tokens=[], top_logprobs=None)
last_token_len = 0
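    # top_logprobs entries are only collected when a top-logprob count was requested.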
if num_output_top_logprobs:
logprobs.top_logprobs = []
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is not None:
token_logprob = step_top_logprobs[token_id]
else:
token_logprob = None
token = tokenizer.convert_ids_to_tokens(token_id)
logprobs.tokens.append(token)
logprobs.token_logprobs.append(token_logprob)
if len(logprobs.text_offset) == 0:
logprobs.text_offset.append(initial_text_offset)
else:
logprobs.text_offset.append(logprobs.text_offset[-1] + last_token_len)
last_token_len = len(token)
if num_output_top_logprobs:
logprobs.top_logprobs.append(
{
tokenizer.convert_ids_to_tokens(i): p
for i, p in step_top_logprobs.items()
}
if step_top_logprobs else None
)
return logprobs
async def create_completion_stream(
generator: AsyncIterator, params: Dict[str, Any], request_id: str, tokenizer,
) -> AsyncIterator:
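    """Yield Completion chunks containing only the newly generated text for each choice."""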
n = params.get("n", 1)
previous_texts = [""] * n
previous_num_tokens = [0] * n
async for res in generator:
res: RequestOutput
for output in res.outputs:
i = output.index
output.text = output.text.replace("�", "")
delta_text = output.text[len(previous_texts[i]):]
if params.get("logprobs") is not None:
                logprobs = create_logprobs(
                    tokenizer,
                    output.token_ids[previous_num_tokens[i]:],
                    output.logprobs[previous_num_tokens[i]:],
                    initial_text_offset=len(previous_texts[i]),
                )
else:
logprobs = None
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
choice = CompletionChoice(
index=i,
text=delta_text,
finish_reason="stop",
logprobs=logprobs,
)
yield Completion(
id=request_id,
choices=[choice],
created=int(time.time()),
model=params.get("model", "llm"),
object="text_completion",
)
            if output.finish_reason is not None:
                if params.get("logprobs") is not None:
                    logprobs = Logprobs(
                        text_offset=[], token_logprobs=[], tokens=[], top_logprobs=[]
                    )
                else:
                    logprobs = None
                # Final chunk for this choice: no additional text, report the actual finish reason.
                choice = CompletionChoice(
                    index=i,
                    text="",
                    finish_reason=output.finish_reason,
                    logprobs=logprobs,
                )
yield Completion(
id=request_id,
choices=[choice],
created=int(time.time()),
model=params.get("model", "llm"),
object="text_completion",
)
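

# For reference, a minimal sketch of calling this endpoint with the openai client.
# The /v1 prefix, host/port, and API key below are deployment-specific assumptions,
# not values defined in this module:
#
#   from openai import OpenAI
#   client = OpenAI(base_url="http://localhost:8000/v1", api_key="sk-xxx")
#   completion = client.completions.create(model="llm", prompt="Hello", max_tokens=16)
#   print(completion.choices[0].text)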