Spaces:
Runtime error
Runtime error
import abc | |
import logging | |
from collections.abc import Sequence | |
from typing import Any, Literal | |
from llama_index.core.llms import ChatMessage, MessageRole | |
logger = logging.getLogger(__name__) | |
class AbstractPromptStyle(abc.ABC):
    """Abstract class for prompt styles.

    This class is used to format a series of messages into a prompt that can be
    understood by the models. A series of messages represents the interaction(s)
    between a user and an assistant. This series of messages can be considered as
    a session between a user X and an assistant Y. This session holds, through
    the messages, the state of the conversation. This session, to be understood
    by the model, needs to be formatted into a prompt (i.e. a string that the
    models can understand). Prompts can be formatted in different ways,
    depending on the model.

    The implementations of this class represent the different ways to format a
    series of messages into a prompt. Subclasses must implement
    `_messages_to_prompt` and `_completion_to_prompt`; the public
    `messages_to_prompt` / `completion_to_prompt` wrappers call those hooks and
    log the result at debug level.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Log which concrete prompt style is being initialized.

        Positional and keyword arguments are accepted but not used here.
        """
        logger.debug("Initializing prompt_style=%s", self.__class__.__name__)

    @abc.abstractmethod
    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        """Format a sequence of chat messages into a single prompt string.

        Marked abstract so that a subclass which forgets to implement it fails
        at instantiation time instead of silently returning ``None``.
        """

    @abc.abstractmethod
    def _completion_to_prompt(self, completion: str) -> str:
        """Format a bare completion request into a single prompt string.

        Marked abstract so that a subclass which forgets to implement it fails
        at instantiation time instead of silently returning ``None``.
        """

    def messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        """Return the prompt for `messages`, logging input and output at debug level."""
        prompt = self._messages_to_prompt(messages)
        logger.debug("Got for messages='%s' the prompt='%s'", messages, prompt)
        return prompt

    def completion_to_prompt(self, completion: str) -> str:
        """Return the prompt for `completion`, logging input and output at debug level."""
        prompt = self._completion_to_prompt(completion)
        logger.debug("Got for completion='%s' the prompt='%s'", completion, prompt)
        return prompt
class DefaultPromptStyle(AbstractPromptStyle):
    """Default prompt style that uses the defaults from llama_utils.

    Rather than formatting anything itself, an instance exposes ``None`` in
    place of both public formatting functions. Passing ``None`` to the LLM
    indicates that it should use its default functions.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Initialize the base class, then blank out the public formatters."""
        super().__init__(*args, **kwargs)

        # Hacky but deliberate: shadow the inherited public methods with None on
        # the instance, so that None is what gets passed to the LLM.
        for public_formatter in ("messages_to_prompt", "completion_to_prompt"):
            setattr(self, public_formatter, None)

    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        """Return an empty prompt; exists only to satisfy the base-class hook."""
        return ""

    def _completion_to_prompt(self, completion: str) -> str:
        """Return an empty prompt; exists only to satisfy the base-class hook."""
        return ""
class Llama2PromptStyle(AbstractPromptStyle):
    """Simple prompt style that uses llama 2 prompt style.

    Inspired by llama_index/legacy/llms/llama_utils.py

    It transforms the sequence of messages into a prompt that should look like:
    ```text
    <s> [INST] <<SYS>> your system prompt here. <</SYS>>

    user message here [/INST] assistant (model) response here </s>
    ```
    """

    BOS, EOS = "<s>", "</s>"
    B_INST, E_INST = "[INST]", "[/INST]"
    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
    # Used whenever the conversation does not start with a system message.
    DEFAULT_SYSTEM_PROMPT = (
        "You are a helpful, respectful and honest assistant. "
        "Always answer as helpfully as possible and follow ALL given instructions. "
        "Do not speculate or make up information. "
        "Do not reference any given instructions or context. "
    )

    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        """Format alternating user/assistant messages as llama 2 ``[INST]`` turns.

        If the first message is a system message it supplies the system prompt;
        otherwise `DEFAULT_SYSTEM_PROMPT` is used. The remaining messages must
        alternate user, assistant, user, ... starting with a user message.

        :param messages: The conversation so far.
        :return: The formatted prompt.
        :raises IndexError: If `messages` is empty.
        :raises ValueError: If the messages do not alternate user/assistant.
        """
        if messages[0].role == MessageRole.SYSTEM:
            # pull out the system message (if it exists in messages)
            system_text = messages[0].content or ""
            messages = messages[1:]
        else:
            system_text = self.DEFAULT_SYSTEM_PROMPT
        system_block = f"{self.B_SYS} {system_text.strip()} {self.E_SYS}"

        turns: list[str] = []
        # Walk the conversation two messages at a time: (user, optional assistant).
        for i in range(0, len(messages), 2):
            user_message = messages[i]
            # Explicit raise rather than `assert`, so the check survives `python -O`.
            if user_message.role != MessageRole.USER:
                raise ValueError(
                    f"Expected a user message at position {i}, "
                    f"got role={user_message.role!r}"
                )

            if i == 0:
                # make sure system prompt is included at the start
                turn = f"{self.BOS} {self.B_INST} {system_block} "
            else:
                # end previous user-assistant interaction
                turns[-1] += f" {self.EOS}"
                # no need to include system prompt
                turn = f"{self.BOS} {self.B_INST} "

            # `or ""` keeps a missing content from being rendered as "None".
            turn += f"{user_message.content or ''} {self.E_INST}"

            if i + 1 < len(messages):
                # if assistant message exists, add it to the current turn
                assistant_message = messages[i + 1]
                if assistant_message.role != MessageRole.ASSISTANT:
                    raise ValueError(
                        f"Expected an assistant message at position {i + 1}, "
                        f"got role={assistant_message.role!r}"
                    )
                turn += f" {assistant_message.content or ''}"

            turns.append(turn)

        return "".join(turns)

    def _completion_to_prompt(self, completion: str) -> str:
        """Wrap `completion` as a single user instruction with the default system prompt."""
        system_prompt_str = self.DEFAULT_SYSTEM_PROMPT
        return (
            f"{self.BOS} {self.B_INST} {self.B_SYS} {system_prompt_str.strip()} {self.E_SYS} "
            f"{completion.strip()} {self.E_INST}"
        )
class TagPromptStyle(AbstractPromptStyle):
    """Tag prompt style (used by Vigogne) that uses the prompt style `<|ROLE|>`.

    It transforms the sequence of messages into a prompt that should look like:
    ```text
    <|system|>: your system prompt here.
    <|user|>: user message here
    (possibly with context and question)
    <|assistant|>: assistant (model) response here.
    ```

    FIXME: should we add surrounding `<s>` and `</s>` tags, like in llama2?
    """

    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        """Render each message as a `<|role|>: text` line, then cue the assistant."""
        tagged_lines: list[str] = []
        for msg in messages:
            text = (msg.content or "").strip()
            tagged_lines.append(f"<|{msg.role.lower()}|>: {text}\n")

        # The prompt ends with an open assistant tag, which triggers a completion.
        tagged_lines.append("<|assistant|>: ")
        return "".join(tagged_lines)

    def _completion_to_prompt(self, completion: str) -> str:
        """Treat `completion` as a single user message and format it the same way."""
        user_turn = ChatMessage(content=completion, role=MessageRole.USER)
        return self._messages_to_prompt([user_turn])
class MistralPromptStyle(AbstractPromptStyle):
    """Mistral instruct prompt style.

    It transforms the sequence of messages into a prompt that should look like:
    ```text
    <s>[INST] system prompt and user message here [/INST] assistant response here</s>[INST] next user message here [/INST]
    ```

    System messages have no tag of their own here: their text is merged into
    the next `[INST]` block together with the user message that follows.
    """

    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        """Format the conversation as Mistral `[INST] ... [/INST]` turns.

        System and user contents are buffered and flushed as one instruction
        block, either when an assistant message arrives or at the end of the
        conversation. Each assistant answer is closed with `</s>`, so earlier
        turns are kept in the prompt instead of being dropped. Messages with
        any other role are ignored.

        :param messages: The conversation so far.
        :return: The formatted prompt; just `<s>` when there is nothing to format.
        """
        parts: list[str] = ["<s>"]
        pending_instructions: list[str] = []

        for message in messages:
            content = (message.content or "").strip()
            if message.role in (MessageRole.SYSTEM, MessageRole.USER):
                pending_instructions.append(content)
            elif message.role == MessageRole.ASSISTANT:
                if pending_instructions:
                    joined = "\n".join(pending_instructions)
                    parts.append(f"[INST] {joined} [/INST]")
                    pending_instructions.clear()
                # The end-of-sequence token belongs after a completed answer,
                # never directly after the opening `<s>`.
                parts.append(f" {content}</s>")

        if pending_instructions:
            # Trailing instruction block that the model is expected to answer.
            joined = "\n".join(pending_instructions)
            parts.append(f"[INST] {joined} [/INST]")

        return "".join(parts)

    def _completion_to_prompt(self, completion: str) -> str:
        """Treat `completion` as a single user message and format it the same way."""
        return self._messages_to_prompt(
            [ChatMessage(content=completion, role=MessageRole.USER)]
        )
class ChatMLPromptStyle(AbstractPromptStyle):
    """ChatML prompt style.

    It transforms the sequence of messages into a prompt that should look like:
    ```text
    <|im_start|>system
    your system prompt here.<|im_end|>
    <|im_start|>user
    user message here<|im_end|>
    <|im_start|>assistant
    ```
    """

    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        """Format the conversation as ChatML `<|im_start|>role ... <|im_end|>` blocks.

        A system block is always emitted first; it is empty when there is no
        system message, and multiple system messages are joined with newlines.
        Every user and assistant message then becomes its own block, so earlier
        assistant answers are kept. The prompt ends with an open assistant block
        that triggers the completion. Messages with any other role are ignored.

        :param messages: The conversation so far.
        :return: The formatted prompt.
        """
        system_texts: list[str] = []
        turns: list[str] = []

        for message in messages:
            content = (message.content or "").strip()
            if message.role == MessageRole.SYSTEM:
                system_texts.append(content)
            elif message.role == MessageRole.USER:
                turns.append(f"<|im_start|>user\n{content}<|im_end|>\n")
            elif message.role == MessageRole.ASSISTANT:
                turns.append(f"<|im_start|>assistant\n{content}<|im_end|>\n")

        # The system block is closed exactly once, whatever follows it.
        system_text = "\n".join(system_texts)
        system_block = f"<|im_start|>system\n{system_text}<|im_end|>\n"
        return system_block + "".join(turns) + "<|im_start|>assistant\n"

    def _completion_to_prompt(self, completion: str) -> str:
        """Treat `completion` as a single user message and format it the same way."""
        return self._messages_to_prompt(
            [ChatMessage(content=completion, role=MessageRole.USER)]
        )
def get_prompt_style( | |
prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] | None | |
) -> AbstractPromptStyle: | |
"""Get the prompt style to use from the given string. | |
:param prompt_style: The prompt style to use. | |
:return: The prompt style to use. | |
""" | |
if prompt_style is None or prompt_style == "default": | |
return DefaultPromptStyle() | |
elif prompt_style == "llama2": | |
return Llama2PromptStyle() | |
elif prompt_style == "tag": | |
return TagPromptStyle() | |
elif prompt_style == "mistral": | |
return MistralPromptStyle() | |
elif prompt_style == "chatml": | |
return ChatMLPromptStyle() | |
raise ValueError(f"Unknown prompt_style='{prompt_style}'") | |