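"""OpenAI-compatible text completion route backed by a vLLM engine."""
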
import time
import uuid
from functools import partial
from typing import (
    Any,
    AsyncIterator,
    Dict,
    List,
    Optional,
)

import anyio
from fastapi import APIRouter, Depends
from fastapi import HTTPException, Request
from loguru import logger
from openai.types.completion import Completion
from openai.types.completion_choice import CompletionChoice, Logprobs
from openai.types.completion_usage import CompletionUsage
from sse_starlette import EventSourceResponse
from vllm.outputs import RequestOutput

from api.core.vllm_engine import VllmEngine
from api.models import GENERATE_ENGINE
from api.utils.compat import model_dump
from api.utils.protocol import CompletionCreateParams
from api.utils.request import (
    check_api_key,
    get_event_publisher,
    handle_request,
)

completion_router = APIRouter()


def get_engine():
    # FastAPI dependency that yields the process-wide generation engine.
    yield GENERATE_ENGINE
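

# OpenAI-style completions route: returns an SSE stream of Completion chunks
# when ``stream`` is true, otherwise a single Completion object.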
@completion_router.post("/completions", dependencies=[Depends(check_api_key)])
async def create_completion(
    request: CompletionCreateParams,
    raw_request: Request,
    engine: VllmEngine = Depends(get_engine),
):
    """Completion API similar to OpenAI's.

    See https://platform.openai.com/docs/api-reference/completions/create
    for the API specification; an example client call is sketched at the
    bottom of this module.
    """
    if request.echo:
        # We do not support echo since the vLLM engine does not
        # currently support getting the logprobs of prompt tokens.
        raise HTTPException(status_code=400, detail="echo is not currently supported")

    if request.suffix:
        # The language models we currently support do not support suffix.
        raise HTTPException(status_code=400, detail="suffix is not currently supported")

    # Default to 128 generated tokens when the client does not set max_tokens.
    request.max_tokens = request.max_tokens or 128
    request = await handle_request(request, engine.prompt_adapter.stop, chat=False)

    # Batched prompts are not supported; only the first one is used.
    if isinstance(request.prompt, list):
        request.prompt = request.prompt[0]

    params = model_dump(request, exclude={"prompt"})
    params.update(dict(prompt_or_messages=request.prompt))
    logger.debug(f"==== request ====\n{params}")

    request_id: str = f"cmpl-{uuid.uuid4()}"
    generator = engine.generate(params, request_id)

    if request.stream:
        # Bridge the chunk iterator to the SSE response through a bounded
        # in-memory channel: get_event_publisher feeds the channel while
        # EventSourceResponse drains it.
        iterator = create_completion_stream(generator, params, request_id, engine.tokenizer)
        send_chan, recv_chan = anyio.create_memory_object_stream(10)
        return EventSourceResponse(
            recv_chan,
            data_sender_callable=partial(
                get_event_publisher,
                request=raw_request,
                inner_send_chan=send_chan,
                iterator=iterator,
            ),
        )
    else:
        # Non-streaming response: drain the generator and keep the last output.
        final_res: Optional[RequestOutput] = None
        async for res in generator:
            if raw_request is not None and await raw_request.is_disconnected():
                # Abort generation when the client disconnects.
                await engine.model.abort(request_id)
                return
            final_res = res
        assert final_res is not None

        choices = []
        for output in final_res.outputs:
            output.text = output.text.replace("�", "")
            logprobs = None
            if params.get("logprobs", None) is not None:
                logprobs = create_logprobs(engine.tokenizer, output.token_ids, output.logprobs)
            choice = CompletionChoice(
                index=output.index,
                text=output.text,
                finish_reason=output.finish_reason,
                logprobs=logprobs,
            )
            choices.append(choice)

        num_prompt_tokens = len(final_res.prompt_token_ids)
        num_generated_tokens = sum(len(output.token_ids) for output in final_res.outputs)
        usage = CompletionUsage(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )
        return Completion(
            id=request_id,
            choices=choices,
            created=int(time.time()),
            model=params.get("model", "llm"),
            object="text_completion",
            usage=usage,
        )


def create_logprobs(
    tokenizer,
    token_ids: List[int],
    top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None,
    num_output_top_logprobs: Optional[int] = None,
    initial_text_offset: int = 0,
) -> Logprobs:
    """Convert per-step vLLM logprobs into an OpenAI-style Logprobs object."""
    logprobs = Logprobs(text_offset=[], token_logprobs=[], tokens=[], top_logprobs=None)
    last_token_len = 0
    if num_output_top_logprobs:
        logprobs.top_logprobs = []

    for i, token_id in enumerate(token_ids):
        step_top_logprobs = top_logprobs[i]
        if step_top_logprobs is not None:
            token_logprob = step_top_logprobs[token_id]
        else:
            token_logprob = None

        token = tokenizer.convert_ids_to_tokens(token_id)
        logprobs.tokens.append(token)
        logprobs.token_logprobs.append(token_logprob)

        # Offsets accumulate: each token starts where the previous token's
        # text ended.
        if len(logprobs.text_offset) == 0:
            logprobs.text_offset.append(initial_text_offset)
        else:
            logprobs.text_offset.append(logprobs.text_offset[-1] + last_token_len)
        last_token_len = len(token)

        if num_output_top_logprobs:
            logprobs.top_logprobs.append(
                {
                    tokenizer.convert_ids_to_tokens(tid): p
                    for tid, p in step_top_logprobs.items()
                }
                if step_top_logprobs else None
            )
    return logprobs
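
# Illustration (the token strings are assumptions; actual tokens depend on the
# tokenizer): for two generated tokens "He" and "llo" with logprobs -0.1 and
# -0.2 and initial_text_offset=0, create_logprobs returns a Logprobs object
# with
#   tokens=["He", "llo"]
#   token_logprobs=[-0.1, -0.2]
#   text_offset=[0, 2]
# i.e. each offset marks where that token's text begins in the overall string.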


async def create_completion_stream(
    generator: AsyncIterator, params: Dict[str, Any], request_id: str, tokenizer,
) -> AsyncIterator:
    n = params.get("n", 1)
    previous_texts = [""] * n
    previous_num_tokens = [0] * n

    async for res in generator:
        res: RequestOutput
        for output in res.outputs:
            i = output.index
            output.text = output.text.replace("�", "")
            # Each RequestOutput carries the full text generated so far, so the
            # chunk to send is whatever extends the previously sent text.
            delta_text = output.text[len(previous_texts[i]):]

            if params.get("logprobs") is not None:
                logprobs = create_logprobs(
                    tokenizer,
                    output.token_ids[previous_num_tokens[i]:],
                    output.logprobs[previous_num_tokens[i]:],
                    initial_text_offset=len(previous_texts[i]),
                )
            else:
                logprobs = None

            previous_texts[i] = output.text
            previous_num_tokens[i] = len(output.token_ids)

            choice = CompletionChoice(
                index=i,
                text=delta_text,
                finish_reason="stop",  # placeholder: the OpenAI type does not accept None
                logprobs=logprobs,
            )
            yield Completion(
                id=request_id,
                choices=[choice],
                created=int(time.time()),
                model=params.get("model", "llm"),
                object="text_completion",
            )

            if output.finish_reason is not None:
                # Send a final, empty chunk carrying the real finish reason.
                if params.get("logprobs") is not None:
                    logprobs = Logprobs(
                        text_offset=[], token_logprobs=[], tokens=[], top_logprobs=[]
                    )
                else:
                    logprobs = None

                choice = CompletionChoice(
                    index=i,
                    text="",
                    finish_reason=output.finish_reason,
                    logprobs=logprobs,
                )
                yield Completion(
                    id=request_id,
                    choices=[choice],
                    created=int(time.time()),
                    model=params.get("model", "llm"),
                    object="text_completion",
                )
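

# ---------------------------------------------------------------------------
# Example usage: a minimal sketch, assuming the app mounts this router under
# /v1 on http://localhost:8000 and that a key accepted by check_api_key is
# supplied. The base URL, API key, and model name are illustrative
# assumptions, not values defined in this module.
#
#   from openai import OpenAI
#
#   client = OpenAI(base_url="http://localhost:8000/v1", api_key="sk-...")
#   completion = client.completions.create(
#       model="llm",
#       prompt="Once upon a time",
#       max_tokens=32,
#       stream=False,
#   )
#   print(completion.choices[0].text)
# ---------------------------------------------------------------------------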