import asyncio
from typing import (
    Optional,
    List,
    Dict,
    Any,
    AsyncIterator,
    Union,
)

from fastapi import HTTPException
from loguru import logger
from openai.types.chat import ChatCompletionMessageParam
from transformers import PreTrainedTokenizer
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams

from api.adapter import get_prompt_adapter
from api.generation import build_qwen_chat_input


class VllmEngine:
    def __init__(
        self,
        model: AsyncLLMEngine,
        tokenizer: PreTrainedTokenizer,
        model_name: str,
        prompt_name: Optional[str] = None,
        context_len: Optional[int] = -1,
    ):
        """
        Initializes the VLLMEngine object.

        Args:
            model: The AsyncLLMEngine object.
            tokenizer: The PreTrainedTokenizer object.
            model_name: The name of the model.
            prompt_name: The name of the prompt (optional).
            context_len: The length of the context (optional, default=-1).
        """
        self.model = model
        self.model_name = model_name.lower()
        self.tokenizer = tokenizer
        self.prompt_name = prompt_name.lower() if prompt_name is not None else None
        self.prompt_adapter = get_prompt_adapter(self.model_name, prompt_name=self.prompt_name)

        # asyncio.run() requires that no event loop is already running, so the
        # engine must be constructed outside of async code.
        model_config = asyncio.run(self.model.get_model_config())
        if "qwen" in self.model_name:
            # Guard against context_len being None before comparing it with 0.
            self.max_model_len = context_len if context_len and context_len > 0 else 8192
        else:
            self.max_model_len = model_config.max_model_len

    def apply_chat_template(
        self,
        messages: List[ChatCompletionMessageParam],
        max_tokens: Optional[int] = 256,
        functions: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
    ) -> Union[str, List[int]]:
        """
        Applies a chat template to the given messages and returns the processed output.

        Args:
            messages: A list of ChatCompletionMessageParam objects representing the chat messages.
            max_tokens: The maximum number of tokens in the output (optional, default=256).
            functions: A dictionary or list of dictionaries representing the functions to be applied (optional).
            tools: A list of dictionaries representing the tools to be used (optional).

        Returns:
            Union[str, List[int]]: The processed output as a string or a list of integers.
        """
        if self.prompt_adapter.function_call_available:
            messages = self.prompt_adapter.postprocess_messages(
                messages, functions, tools,
            )
            if functions or tools:
                logger.debug(f"==== Messages with tools ====\n{messages}")

        if "chatglm3" in self.model_name:
            query, role = messages[-1]["content"], messages[-1]["role"]
            return self.tokenizer.build_chat_input(
                query, history=messages[:-1], role=role
            )["input_ids"][0].tolist()
        elif "qwen" in self.model_name:
            return build_qwen_chat_input(
                self.tokenizer,
                messages,
                self.max_model_len,
                max_tokens,
                functions,
                tools,
            )
        else:
            return self.prompt_adapter.apply_chat_template(messages)
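
    # A minimal usage sketch (the variable names are illustrative):
    #
    #     messages = [{"role": "user", "content": "Hello"}]
    #     rendered = engine.apply_chat_template(messages, max_tokens=256)
    #
    # For ChatGLM3 and Qwen models, `rendered` is a list of token ids built by
    # the model's own chat builder; for all other models it is the prompt
    # string produced by the prompt adapter.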

    def convert_to_inputs(
        self,
        prompt: Optional[str] = None,
        token_ids: Optional[List[int]] = None,
        max_tokens: Optional[int] = 256,
    ) -> List[int]:
        """
        Tokenizes the prompt if token ids were not supplied, then truncates the
        input from the left so that prompt plus completion fits the context window.
        """
        # Keep at least one input token: if max_input_tokens were 0, the
        # negative slice below would silently return the full sequence.
        max_input_tokens = max(self.max_model_len - max_tokens, 1)
        input_ids = token_ids or self.tokenizer(prompt).input_ids
        return input_ids[-max_input_tokens:]

    def generate(self, params: Dict[str, Any], request_id: str) -> AsyncIterator:
        """
        Generates text based on the given parameters and request ID.

        Args:
            params (Dict[str, Any]): A dictionary of parameters for text generation.
            request_id (str): The ID of the request.

        Yields:
            Any: The generated text.
        """
        max_tokens = params.get("max_tokens", 256)
        prompt_or_messages = params.get("prompt_or_messages")
        # Chat requests arrive as a message list; render them with the
        # model-specific chat template first.
        if isinstance(prompt_or_messages, list):
            prompt_or_messages = self.apply_chat_template(
                prompt_or_messages,
                max_tokens,
                functions=params.get("functions"),
                tools=params.get("tools"),
            )

        # apply_chat_template may return token ids (list) or a prompt string.
        if isinstance(prompt_or_messages, list):
            prompt, token_ids = None, prompt_or_messages
        else:
            prompt, token_ids = prompt_or_messages, None

        token_ids = self.convert_to_inputs(prompt, token_ids, max_tokens=max_tokens)
        try:
            sampling_params = SamplingParams(
                n=params.get("n", 1),
                presence_penalty=params.get("presence_penalty", 0.),
                frequency_penalty=params.get("frequency_penalty", 0.),
                temperature=params.get("temperature", 0.9),
                top_p=params.get("top_p", 0.8),
                stop=params.get("stop", []),
                stop_token_ids=params.get("stop_token_ids", []),
                max_tokens=max_tokens,
                repetition_penalty=params.get("repetition_penalty", 1.03),
                min_p=params.get("min_p", 0.0),
                best_of=params.get("best_of", 1),
                ignore_eos=params.get("ignore_eos", False),
                use_beam_search=params.get("use_beam_search", False),
                skip_special_tokens=params.get("skip_special_tokens", True),
                spaces_between_special_tokens=params.get("spaces_between_special_tokens", True),
            )
            result_generator = self.model.generate(
                prompt_or_messages if isinstance(prompt_or_messages, str) else None,
                sampling_params,
                request_id,
                token_ids,
            )
        except ValueError as e:
            raise HTTPException(status_code=400, detail=str(e)) from e

        return result_generator
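
    # The returned object is the async generator produced by
    # AsyncLLMEngine.generate, so callers consume it with `async for`.
    # A minimal sketch (the request id is an illustrative value):
    #
    #     async for request_output in engine.generate(params, "req-1"):
    #         text = request_output.outputs[0].text
    #
    # Each RequestOutput carries the full text generated so far, so streaming
    # responses typically emit only the delta between successive outputs.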

    @property
    def stop(self):
        """
        Gets the stop property of the prompt adapter.

        Returns:
            The stop property of the prompt adapter, or None if it does not exist.
        """
        return self.prompt_adapter.stop if hasattr(self.prompt_adapter, "stop") else None
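

# ---------------------------------------------------------------------------
# Minimal end-to-end sketch (not used by the server code above). The model
# path and request parameters are illustrative assumptions; any vLLM-supported
# chat model with a matching prompt adapter would work the same way.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import uuid

    from transformers import AutoTokenizer
    from vllm.engine.arg_utils import AsyncEngineArgs

    model_path = "Qwen/Qwen-7B-Chat"  # assumed model; replace as needed
    async_engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model=model_path, trust_remote_code=True)
    )
    hf_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # VllmEngine.__init__ calls asyncio.run(), so build it outside any event loop.
    engine = VllmEngine(async_engine, hf_tokenizer, model_name="qwen")

    async def stream_once() -> None:
        params = {
            "prompt_or_messages": [{"role": "user", "content": "Hello!"}],
            "max_tokens": 128,
            "temperature": 0.7,
        }
        async for request_output in engine.generate(params, str(uuid.uuid4())):
            # Outputs are cumulative; print the latest full text.
            print(request_output.outputs[0].text)

    asyncio.run(stream_once())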