# prompt_builder.py
from typing import Protocol, List, Tuple

from transformers import AutoTokenizer


class PromptTemplate(Protocol):
    """Protocol for prompt templates."""

    def format(self, context: str, user_input: str, chat_history: List[Tuple[str, str]], **kwargs) -> str:
        ...


class LlamaPromptTemplate:
    """Builds the prompt string by hand using the Llama 3 instruct chat format."""

    def format(self, context: str, user_input: str, chat_history: List[Tuple[str, str]], max_history_turns: int = 1) -> str:
        system_message = f"Please assist based on the following context: {context}"
        # System turn opens the conversation.
        prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>"
        # Replay only the most recent turns to keep the prompt short.
        for user_msg, assistant_msg in chat_history[-max_history_turns:]:
            prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{user_msg}<|eot_id|>"
            prompt += f"<|start_header_id|>assistant<|end_header_id|>\n\n{assistant_msg}<|eot_id|>"
        # Current user message, then an open assistant header for the model to complete.
        prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{user_input}<|eot_id|>"
        prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
        return prompt
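
# A minimal sketch of the string this template produces (assuming an empty
# chat history and user_input="Hello"); the special tokens follow the
# Llama 3 instruct chat format:
#
#   <|begin_of_text|><|start_header_id|>system<|end_header_id|>
#
#   Please assist based on the following context: <context><|eot_id|><|start_header_id|>user<|end_header_id|>
#
#   Hello<|eot_id|><|start_header_id|>assistant<|end_header_id|>
#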
class TransformersPromptTemplate:
    """Delegates prompt construction to the tokenizer's built-in chat template."""

    def __init__(self, model_path: str):
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

    def format(self, context: str, user_input: str, chat_history: List[Tuple[str, str]], **kwargs) -> str:
        messages = [
            {
                "role": "system",
                "content": f"Please assist based on the following context: {context}",
            }
        ]
        # Replay the chat history as alternating user/assistant messages.
        for user_msg, assistant_msg in chat_history:
            messages.extend([
                {"role": "user", "content": user_msg},
                {"role": "assistant", "content": assistant_msg},
            ])
        messages.append({"role": "user", "content": user_input})
        # tokenize=False returns the formatted prompt as a string rather than token IDs.
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        return prompt
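

# Example usage (illustrative sketch, not part of the original module).
# A TransformersPromptTemplate would need a real model path whose tokenizer
# ships a chat template; the LlamaPromptTemplate demo below runs offline.
if __name__ == "__main__":
    history = [("What do you sell?", "We sell handmade ceramics.")]
    template = LlamaPromptTemplate()
    prompt = template.format(
        context="FAQ: shipping takes 3-5 business days.",
        user_input="How long does shipping take?",
        chat_history=history,
    )
    print(prompt)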