import asyncio
from typing import (
    Optional,
    List,
    Dict,
    Any,
    AsyncIterator,
    Union,
)

from fastapi import HTTPException
from loguru import logger
from openai.types.chat import ChatCompletionMessageParam
from transformers import PreTrainedTokenizer
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams

from api.adapter import get_prompt_adapter
from api.generation import build_qwen_chat_input


class VllmEngine:
    def __init__(
        self,
        model: AsyncLLMEngine,
        tokenizer: PreTrainedTokenizer,
        model_name: str,
        prompt_name: Optional[str] = None,
        context_len: Optional[int] = -1,
    ):
        """
        Initializes the VllmEngine object.

        Args:
            model: The AsyncLLMEngine object.
            tokenizer: The PreTrainedTokenizer object.
            model_name: The name of the model.
            prompt_name: The name of the prompt (optional).
            context_len: The length of the context (optional, default=-1).
        """
        self.model = model
        self.model_name = model_name.lower()
        self.tokenizer = tokenizer
        self.prompt_name = prompt_name.lower() if prompt_name is not None else None
        self.prompt_adapter = get_prompt_adapter(self.model_name, prompt_name=self.prompt_name)

        model_config = asyncio.run(self.model.get_model_config())
        if "qwen" in self.model_name:
            # For Qwen models, honor an explicitly supplied positive context_len,
            # otherwise fall back to 8192.
            self.max_model_len = context_len if context_len is not None and context_len > 0 else 8192
        else:
            self.max_model_len = model_config.max_model_len

    def apply_chat_template(
        self,
        messages: List[ChatCompletionMessageParam],
        max_tokens: Optional[int] = 256,
        functions: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
    ) -> Union[str, List[int]]:
        """
        Applies a chat template to the given messages and returns the processed output.

        Args:
            messages: A list of ChatCompletionMessageParam objects representing the chat messages.
            max_tokens: The maximum number of tokens in the output (optional, default=256).
            functions: A dictionary or list of dictionaries representing the functions to be applied (optional).
            tools: A list of dictionaries representing the tools to be used (optional).

        Returns:
            Union[str, List[int]]: The processed output as a string or a list of integers.
        """
        if self.prompt_adapter.function_call_available:
            messages = self.prompt_adapter.postprocess_messages(
                messages, functions, tools,
            )
            if functions or tools:
                logger.debug(f"==== Messages with tools ====\n{messages}")

        if "chatglm3" in self.model_name:
            query, role = messages[-1]["content"], messages[-1]["role"]
            return self.tokenizer.build_chat_input(
                query, history=messages[:-1], role=role
            )["input_ids"][0].tolist()
        elif "qwen" in self.model_name:
            return build_qwen_chat_input(
                self.tokenizer,
                messages,
                self.max_model_len,
                max_tokens,
                functions,
                tools,
            )
        else:
            return self.prompt_adapter.apply_chat_template(messages)

    def convert_to_inputs(
        self,
        prompt: Optional[str] = None,
        token_ids: Optional[List[int]] = None,
        max_tokens: Optional[int] = 256,
    ) -> List[int]:
        # Keep only the most recent tokens so that prompt plus completion
        # fits within the model's maximum context length.
        max_input_tokens = self.max_model_len - max_tokens
        input_ids = token_ids or self.tokenizer(prompt).input_ids
        return input_ids[-max_input_tokens:]

    def generate(self, params: Dict[str, Any], request_id: str) -> AsyncIterator:
        """
        Generates text based on the given parameters and request ID.

        Args:
            params (Dict[str, Any]): A dictionary of parameters for text generation.
            request_id (str): The ID of the request.

        Yields:
            Any: The generated text.
""" max_tokens = params.get("max_tokens", 256) prompt_or_messages = params.get("prompt_or_messages") if isinstance(prompt_or_messages, list): prompt_or_messages = self.apply_chat_template( prompt_or_messages, max_tokens, functions=params.get("functions"), tools=params.get("tools"), ) if isinstance(prompt_or_messages, list): prompt, token_ids = None, prompt_or_messages else: prompt, token_ids = prompt_or_messages, None token_ids = self.convert_to_inputs(prompt, token_ids, max_tokens=max_tokens) try: sampling_params = SamplingParams( n=params.get("n", 1), presence_penalty=params.get("presence_penalty", 0.), frequency_penalty=params.get("frequency_penalty", 0.), temperature=params.get("temperature", 0.9), top_p=params.get("top_p", 0.8), stop=params.get("stop", []), stop_token_ids=params.get("stop_token_ids", []), max_tokens=params.get("max_tokens", 256), repetition_penalty=params.get("repetition_penalty", 1.03), min_p=params.get("min_p", 0.0), best_of=params.get("best_of", 1), ignore_eos=params.get("ignore_eos", False), use_beam_search=params.get("use_beam_search", False), skip_special_tokens=params.get("skip_special_tokens", True), spaces_between_special_tokens=params.get("spaces_between_special_tokens", True), ) result_generator = self.model.generate( prompt_or_messages if isinstance(prompt_or_messages, str) else None, sampling_params, request_id, token_ids, ) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) from e return result_generator @property def stop(self): """ Gets the stop property of the prompt adapter. Returns: The stop property of the prompt adapter, or None if it does not exist. """ return self.prompt_adapter.stop if hasattr(self.prompt_adapter, "stop") else None