import asyncio
from typing import (
    Optional,
    List,
    Dict,
    Any,
    AsyncIterator,
    Union,
)

from fastapi import HTTPException
from loguru import logger
from openai.types.chat import ChatCompletionMessageParam
from transformers import PreTrainedTokenizer
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams

from api.adapter import get_prompt_adapter
from api.generation import build_qwen_chat_input


class VllmEngine:
    def __init__(
        self,
        model: AsyncLLMEngine,
        tokenizer: PreTrainedTokenizer,
        model_name: str,
        prompt_name: Optional[str] = None,
        context_len: Optional[int] = -1,
    ):
"""
Initializes the VLLMEngine object.
Args:
model: The AsyncLLMEngine object.
tokenizer: The PreTrainedTokenizer object.
model_name: The name of the model.
prompt_name: The name of the prompt (optional).
context_len: The length of the context (optional, default=-1).
"""
        self.model = model
        self.model_name = model_name.lower()
        self.tokenizer = tokenizer
        self.prompt_name = prompt_name.lower() if prompt_name is not None else None
        self.prompt_adapter = get_prompt_adapter(self.model_name, prompt_name=self.prompt_name)

        model_config = asyncio.run(self.model.get_model_config())
        if "qwen" in self.model_name:
            self.max_model_len = context_len if context_len > 0 else 8192
        else:
            self.max_model_len = model_config.max_model_len
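
    # A minimal construction sketch (assumed model path and argument values, not part of
    # the original file):
    #
    #   from transformers import AutoTokenizer
    #   from vllm.engine.arg_utils import AsyncEngineArgs
    #
    #   engine_args = AsyncEngineArgs(model="Qwen/Qwen-7B-Chat", trust_remote_code=True)
    #   model = AsyncLLMEngine.from_engine_args(engine_args)
    #   tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
    #   engine = VllmEngine(model, tokenizer, model_name="qwen-7b-chat", context_len=8192)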

    def apply_chat_template(
        self,
        messages: List[ChatCompletionMessageParam],
        max_tokens: Optional[int] = 256,
        functions: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
    ) -> Union[str, List[int]]:
"""
Applies a chat template to the given messages and returns the processed output.
Args:
messages: A list of ChatCompletionMessageParam objects representing the chat messages.
max_tokens: The maximum number of tokens in the output (optional, default=256).
functions: A dictionary or list of dictionaries representing the functions to be applied (optional).
tools: A list of dictionaries representing the tools to be used (optional).
Returns:
Union[str, List[int]]: The processed output as a string or a list of integers.
"""
        if self.prompt_adapter.function_call_available:
            messages = self.prompt_adapter.postprocess_messages(
                messages, functions, tools,
            )
            if functions or tools:
                logger.debug(f"==== Messages with tools ====\n{messages}")

        if "chatglm3" in self.model_name:
            query, role = messages[-1]["content"], messages[-1]["role"]
            return self.tokenizer.build_chat_input(
                query, history=messages[:-1], role=role
            )["input_ids"][0].tolist()
        elif "qwen" in self.model_name:
            return build_qwen_chat_input(
                self.tokenizer,
                messages,
                self.max_model_len,
                max_tokens,
                functions,
                tools,
            )
        else:
            return self.prompt_adapter.apply_chat_template(messages)
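
    # A rough usage sketch (assumed message content, not part of the original file):
    # ChatGLM3 and Qwen models yield prompt token ids, other models a prompt string.
    #
    #   messages = [{"role": "user", "content": "Hello!"}]
    #   prompt_or_ids = engine.apply_chat_template(messages, max_tokens=128)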

    def convert_to_inputs(
        self,
        prompt: Optional[str] = None,
        token_ids: Optional[List[int]] = None,
        max_tokens: Optional[int] = 256,
    ) -> List[int]:
        # Reserve room for generation and keep only the most recent input tokens.
        max_input_tokens = self.max_model_len - max_tokens
        input_ids = token_ids or self.tokenizer(prompt).input_ids
        return input_ids[-max_input_tokens:]
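
    # For example (assumed values, not part of the original file): with max_model_len=8192
    # and max_tokens=256, at most the last 8192 - 256 = 7936 prompt tokens are kept, so an
    # overlong prompt is truncated from the left.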

    def generate(self, params: Dict[str, Any], request_id: str) -> AsyncIterator:
        """
        Generates text based on the given parameters and request ID.

        Args:
            params (Dict[str, Any]): A dictionary of parameters for text generation.
            request_id (str): The ID of the request.

        Returns:
            An async iterator over the generation outputs.
        """
        max_tokens = params.get("max_tokens", 256)
        prompt_or_messages = params.get("prompt_or_messages")
        if isinstance(prompt_or_messages, list):
            prompt_or_messages = self.apply_chat_template(
                prompt_or_messages,
                max_tokens,
                functions=params.get("functions"),
                tools=params.get("tools"),
            )

        if isinstance(prompt_or_messages, list):
            prompt, token_ids = None, prompt_or_messages
        else:
            prompt, token_ids = prompt_or_messages, None

        token_ids = self.convert_to_inputs(prompt, token_ids, max_tokens=max_tokens)

        try:
            sampling_params = SamplingParams(
                n=params.get("n", 1),
                presence_penalty=params.get("presence_penalty", 0.),
                frequency_penalty=params.get("frequency_penalty", 0.),
                temperature=params.get("temperature", 0.9),
                top_p=params.get("top_p", 0.8),
                stop=params.get("stop", []),
                stop_token_ids=params.get("stop_token_ids", []),
                max_tokens=params.get("max_tokens", 256),
                repetition_penalty=params.get("repetition_penalty", 1.03),
                min_p=params.get("min_p", 0.0),
                best_of=params.get("best_of", 1),
                ignore_eos=params.get("ignore_eos", False),
                use_beam_search=params.get("use_beam_search", False),
                skip_special_tokens=params.get("skip_special_tokens", True),
                spaces_between_special_tokens=params.get("spaces_between_special_tokens", True),
            )
            result_generator = self.model.generate(
                prompt_or_messages if isinstance(prompt_or_messages, str) else None,
                sampling_params,
                request_id,
                token_ids,
            )
        except ValueError as e:
            raise HTTPException(status_code=400, detail=str(e)) from e

        return result_generator
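
    # A rough consumption sketch (assumed parameter values, not part of the original file);
    # vLLM's async generator yields RequestOutput objects whose partial text can be read
    # from output.outputs[0].text:
    #
    #   params = {"prompt_or_messages": [{"role": "user", "content": "Hi"}], "max_tokens": 64}
    #   async for output in engine.generate(params, request_id="req-1"):
    #       print(output.outputs[0].text)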

    @property
    def stop(self):
        """
        Gets the stop property of the prompt adapter.

        Returns:
            The stop property of the prompt adapter, or None if it does not exist.
        """
        return self.prompt_adapter.stop if hasattr(self.prompt_adapter, "stop") else None