import time
import uuid
from functools import partial
from typing import (
    List,
    Dict,
    Any,
    AsyncIterator,
    Optional,
)

import anyio
from fastapi import APIRouter, Depends
from fastapi import HTTPException, Request
from loguru import logger
from openai.types.completion import Completion
from openai.types.completion_choice import CompletionChoice, Logprobs
from openai.types.completion_usage import CompletionUsage
from sse_starlette import EventSourceResponse
from vllm.outputs import RequestOutput

from api.core.vllm_engine import VllmEngine
from api.models import GENERATE_ENGINE
from api.utils.compat import model_dump
from api.utils.protocol import CompletionCreateParams
from api.utils.request import (
    handle_request,
    get_event_publisher,
    check_api_key,
)

completion_router = APIRouter()


def get_engine():
    yield GENERATE_ENGINE


@completion_router.post("/completions", dependencies=[Depends(check_api_key)])
async def create_completion(
    request: CompletionCreateParams,
    raw_request: Request,
    engine: VllmEngine = Depends(get_engine),
):
    """Completion API similar to OpenAI's API.

    See https://platform.openai.com/docs/api-reference/completions/create
    for the API specification. This API mimics the OpenAI Completion API.
    """
    if request.echo:
        # We do not support echo since the vLLM engine does not
        # currently support getting the logprobs of prompt tokens.
        raise HTTPException(status_code=400, detail="echo is not currently supported")

    if request.suffix:
        # The language models we currently support do not support suffix.
        raise HTTPException(status_code=400, detail="suffix is not currently supported")

    request.max_tokens = request.max_tokens or 128
    request = await handle_request(request, engine.prompt_adapter.stop, chat=False)

    if isinstance(request.prompt, list):
        request.prompt = request.prompt[0]

    params = model_dump(request, exclude={"prompt"})
    params.update(dict(prompt_or_messages=request.prompt))
    logger.debug(f"==== request ====\n{params}")

    request_id: str = f"cmpl-{str(uuid.uuid4())}"
    generator = engine.generate(params, request_id)

    if request.stream:
        # Streaming response: push chunks through an anyio memory channel as SSE.
        iterator = create_completion_stream(generator, params, request_id, engine.tokenizer)
        send_chan, recv_chan = anyio.create_memory_object_stream(10)
        return EventSourceResponse(
            recv_chan,
            data_sender_callable=partial(
                get_event_publisher,
                request=raw_request,
                inner_send_chan=send_chan,
                iterator=iterator,
            ),
        )
    else:
        # Non-streaming response
        final_res: Optional[RequestOutput] = None
        async for res in generator:
            if raw_request is not None and await raw_request.is_disconnected():
                # Abort the request if the client disconnects.
                await engine.model.abort(request_id)
                return
            final_res = res

        assert final_res is not None
        choices = []
        for output in final_res.outputs:
            output.text = output.text.replace("�", "")

            logprobs = None
            if params.get("logprobs", None) is not None:
                logprobs = create_logprobs(engine.tokenizer, output.token_ids, output.logprobs)

            choice = CompletionChoice(
                index=output.index,
                text=output.text,
                finish_reason=output.finish_reason,
                logprobs=logprobs,
            )
            choices.append(choice)

        num_prompt_tokens = len(final_res.prompt_token_ids)
        num_generated_tokens = sum(len(output.token_ids) for output in final_res.outputs)
        usage = CompletionUsage(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )

        return Completion(
            id=request_id,
            choices=choices,
            created=int(time.time()),
            model=params.get("model", "llm"),
            object="text_completion",
            usage=usage,
        )


def create_logprobs(
    tokenizer,
    token_ids: List[int],
    top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None,
    num_output_top_logprobs: Optional[int] = None,
    initial_text_offset: int = 0,
) -> Logprobs:
    # Convert vLLM token ids and per-step logprob dicts into an OpenAI-style Logprobs object.
    logprobs = Logprobs(text_offset=[], token_logprobs=[], tokens=[], top_logprobs=None)
    last_token_len = 0
    if num_output_top_logprobs:
        logprobs.top_logprobs = []

    for i, token_id in enumerate(token_ids):
        step_top_logprobs = top_logprobs[i]
        if step_top_logprobs is not None:
            token_logprob = step_top_logprobs[token_id]
        else:
            token_logprob = None

        token = tokenizer.convert_ids_to_tokens(token_id)
        logprobs.tokens.append(token)
        logprobs.token_logprobs.append(token_logprob)

        # Track the character offset of each token within the generated text.
        if len(logprobs.text_offset) == 0:
            logprobs.text_offset.append(initial_text_offset)
        else:
            logprobs.text_offset.append(logprobs.text_offset[-1] + last_token_len)
        last_token_len = len(token)

        if num_output_top_logprobs:
            logprobs.top_logprobs.append(
                {
                    tokenizer.convert_ids_to_tokens(i): p
                    for i, p in step_top_logprobs.items()
                }
                if step_top_logprobs
                else None
            )
    return logprobs


async def create_completion_stream(
    generator: AsyncIterator,
    params: Dict[str, Any],
    request_id: str,
    tokenizer,
) -> AsyncIterator:
    n = params.get("n", 1)
    previous_texts = [""] * n
    previous_num_tokens = [0] * n
    async for res in generator:
        res: RequestOutput
        for output in res.outputs:
            i = output.index
            output.text = output.text.replace("�", "")
            # Only send the text generated since the previous chunk.
            delta_text = output.text[len(previous_texts[i]):]

            if params.get("logprobs") is not None:
                logprobs = create_logprobs(
                    tokenizer,
                    output.token_ids[previous_num_tokens[i]:],
                    output.logprobs[previous_num_tokens[i]:],
                    len(previous_texts[i]),
                )
            else:
                logprobs = None

            previous_texts[i] = output.text
            previous_num_tokens[i] = len(output.token_ids)

            choice = CompletionChoice(
                index=i,
                text=delta_text,
                finish_reason="stop",
                logprobs=logprobs,
            )
            yield Completion(
                id=request_id,
                choices=[choice],
                created=int(time.time()),
                model=params.get("model", "llm"),
                object="text_completion",
            )

            # Emit a final chunk once the engine reports a finish reason.
            if output.finish_reason is not None:
                if params.get("logprobs") is not None:
                    logprobs = Logprobs(
                        text_offset=[], token_logprobs=[], tokens=[], top_logprobs=[]
                    )
                else:
                    logprobs = None

                choice = CompletionChoice(
                    index=i,
                    text=delta_text,
                    finish_reason="stop",
                    logprobs=logprobs,
                )
                yield Completion(
                    id=request_id,
                    choices=[choice],
                    created=int(time.time()),
                    model=params.get("model", "llm"),
                    object="text_completion",
                )
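

# --------------------------------------------------------------------------- #
# Usage sketch (assumption, not part of the router): if this router is mounted
# under an OpenAI-compatible prefix such as "/v1" and served at
# http://localhost:8000, the endpoint can be exercised with the official
# ``openai`` client. The base_url, api_key and model name below are
# placeholders and depend on the actual deployment.
# --------------------------------------------------------------------------- #
if __name__ == "__main__":
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    # Non-streaming request: a single Completion object is returned.
    completion = client.completions.create(
        model="llm",
        prompt="Hello, my name is",
        max_tokens=32,
    )
    print(completion.choices[0].text)

    # Streaming request: chunks arrive as server-sent events.
    stream = client.completions.create(
        model="llm",
        prompt="Hello, my name is",
        max_tokens=32,
        stream=True,
    )
    for chunk in stream:
        print(chunk.choices[0].text, end="", flush=True)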