# codechat-playground / dialogues.py
# coding=utf-8
import json
import os
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Type, TypeVar, Union
from huggingface_hub import ModelHubMixin, hf_hub_download
# Generic variable that is either ModelHubMixin or a subclass thereof
T = TypeVar("T", bound="ModelHubMixin")
# Name of the JSON file used to persist a template to disk / the Hub
# (read back by `DialogueTemplate._from_pretrained`).
TEMPLATE_FILENAME = "dialogue_template.json"
# NOTE(review): -100 matches PyTorch CrossEntropyLoss's default ignore_index;
# it is not referenced anywhere in this file, so it is presumably consumed by
# training code elsewhere — confirm before removing.
IGNORE_INDEX = -100
@dataclass
class DialogueTemplate(ModelHubMixin):
    """Converts all turns of a dialogue between a user and assistant to a standardized format.

    A dialogue is a ``system`` prompt plus an ordered list of ``messages``;
    each message is a dict with ``"role"`` and ``"content"`` keys. Any role
    other than ``"user"`` is rendered as an assistant turn. The special
    tokens delimit the segments of the rendered prompts.
    """

    system: str
    # Optional: callers (e.g. `from_dict`) may pass None explicitly;
    # `__post_init__` normalizes None to an empty list.
    # Fix: the original annotated this `List[...] = None`, which is not a
    # valid value for that type — `Optional` makes the contract explicit.
    messages: Optional[List[Dict[str, str]]] = None
    system_token: str = "<|system|>"
    user_token: str = "<|user|>"
    assistant_token: str = "<|assistant|>"
    end_token: str = "<|end|>"

    def __post_init__(self) -> None:
        """Ensure that messages is never None."""
        if self.messages is None:
            self.messages = []

    def _system_prefix(self) -> str:
        """Render the system header segment shared by both prompt styles."""
        return self.system_token + "\n" + self.system + self.end_token + "\n"

    def _render_messages(self) -> str:
        """Render every turn as ``<role_token>\\ncontent<end_token>\\n``.

        Shared by `get_training_prompt` and `get_inference_prompt`, which
        previously duplicated this loop.
        """
        segments = []
        for message in self.messages:
            # Anything that is not a user turn is rendered as an assistant turn,
            # matching the original branch structure.
            token = self.user_token if message["role"] == "user" else self.assistant_token
            segments.append(token + "\n" + message["content"] + self.end_token + "\n")
        return "".join(segments)

    def get_training_prompt(self) -> str:
        """Return the fully rendered dialogue for supervised training.

        Raises:
            ValueError: If the template contains no messages.
        """
        if not self.messages:
            raise ValueError("Dialogue template must have at least one message.")
        return self._system_prefix() + self._render_messages()

    def get_inference_prompt(self) -> str:
        """Return the rendered dialogue ending with an open assistant token,
        so a model can generate the next assistant turn.

        Raises:
            ValueError: If the template contains no messages.
        """
        if not self.messages:
            raise ValueError("Dialogue template must have at least one message.")
        return self._system_prefix() + self._render_messages() + self.assistant_token + "\n"

    def get_dialogue(self) -> str:
        """Return the dialogue in plain ``Human:``/``Assistant:`` form,
        without special tokens and without the system prompt.

        Raises:
            ValueError: If the template contains no messages.
        """
        if not self.messages:
            raise ValueError("Dialogue template must have at least one message.")
        prompt = ""
        for message in self.messages:
            if message["role"] == "user":
                prompt += "\n\nHuman: " + message["content"]
            else:
                prompt += "\n\nAssistant: " + message["content"]
        return prompt

    def get_special_tokens(self) -> List[str]:
        """Return the special tokens used in rendered prompts (e.g. to add to a tokenizer)."""
        return [self.system_token, self.user_token, self.assistant_token, self.end_token]

    def copy(self) -> "DialogueTemplate":
        """Return a copy of this template with an independent ``messages`` list.

        Fix: the original copy *shared* its ``messages`` list with the source,
        so mutating a copy's list (rather than reassigning it) also mutated
        the module-level templates handed out by ``get_dialogue_template``.
        """
        return DialogueTemplate(
            system=self.system,
            messages=list(self.messages),
            system_token=self.system_token,
            user_token=self.user_token,
            assistant_token=self.assistant_token,
            end_token=self.end_token,
        )

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict representation of all fields."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "DialogueTemplate":
        """Build a template from a dict, using the field defaults for missing keys.

        Fix: construct via ``cls`` (the original hard-coded ``DialogueTemplate``),
        so subclasses round-trip through ``from_dict`` / ``_from_pretrained``.
        """
        return cls(
            system=data.get("system", ""),
            messages=data.get("messages", None),
            system_token=data.get("system_token", "<|system|>"),
            user_token=data.get("user_token", "<|user|>"),
            assistant_token=data.get("assistant_token", "<|assistant|>"),
            end_token=data.get("end_token", "<|end|>"),
        )

    def _save_pretrained(self, save_directory: Union[str, Path]) -> None:
        """Serialize the template as JSON into ``save_directory``.

        Fix: write to the shared ``TEMPLATE_FILENAME`` constant (the original
        hard-coded the filename, risking drift from ``_from_pretrained``),
        create missing parent directories, and pin the encoding to UTF-8.
        """
        save_directory = Path(save_directory)
        save_directory.mkdir(parents=True, exist_ok=True)
        with open(save_directory / TEMPLATE_FILENAME, "w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, indent=2)

    @classmethod
    def _from_pretrained(
        cls: Type[T],
        *,
        model_id: str,
        revision: Optional[str],
        cache_dir: Optional[Union[str, Path]],
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: bool,
        local_files_only: bool,
        token: Optional[Union[str, bool]],
        **model_kwargs,
    ) -> T:
        """Loads the dialogue template from a local directory or the Huggingface Hub.

        Args:
            model_id: Local directory containing ``TEMPLATE_FILENAME``, or a
                Hub repo id to download the file from.
            revision: Git revision to download; defaults to ``"main"``.
            (The remaining keyword arguments are forwarded to ``hf_hub_download``.)

        Returns:
            A template instance built from the JSON file via ``from_dict``.
        """
        if os.path.isdir(model_id):
            print("Loading dialogue template from local directory")
            template_file = os.path.join(model_id, TEMPLATE_FILENAME)
        else:
            template_file = hf_hub_download(
                repo_id=model_id,
                filename=TEMPLATE_FILENAME,
                revision=revision or "main",
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
            )
        with open(template_file, "r", encoding="utf-8") as f:
            data = json.load(f)
        return cls.from_dict(data=data)
# ---------------------------------------------------------------------------
# Built-in dialogue templates, addressable by name via `get_dialogue_template`.
# ---------------------------------------------------------------------------

# Default: generic helpful-assistant system prompt with the standard tokens.
default_template = DialogueTemplate(
    system=(
        "Below is a dialogue between a human user and an AI assistant. "
        "The assistant is happy to help with almost anything, and will do its "
        "best to understand exactly what is needed."
    ),
)

# Variant with an empty system prompt.
no_system_template = DialogueTemplate(system="")

# Alpaca-style instruction/response markers instead of chat role tokens.
alpaca_template = DialogueTemplate(
    system="Below is an instruction that describes a task. Write a response that appropriately completes the request.",
    user_token="### Instruction:",
    assistant_token="### Response:",
)

# Registry of every template this module knows how to hand out.
SUPPORTED_DIALOGUE_TEMPLATES = {
    "default": default_template,
    "no_system": no_system_template,
    "alpaca": alpaca_template,
}
def get_dialogue_template(template: str) -> DialogueTemplate:
    """Look up a named template and return an independent copy of it.

    Args:
        template: A key of ``SUPPORTED_DIALOGUE_TEMPLATES``.

    Raises:
        ValueError: If ``template`` is not a supported name.
    """
    selected = SUPPORTED_DIALOGUE_TEMPLATES.get(template)
    if selected is None:
        raise ValueError(f"Template {template} is not supported!")
    # Copy so the caller can mutate its template without touching the registry.
    return selected.copy()
def prepare_dialogue(example, dialogue_template, is_train=True):
    """Populate ``dialogue_template`` from ``example`` and attach the rendered
    prompt to the example under the ``"text"`` key.

    Accepted example shapes, checked in this order:
      * ``{"messages": [...]}`` (non-None) — used verbatim;
      * ``{"prompt": ..., "completion": ...}`` — one user + one assistant turn;
      * ``{"prompt": ...}`` — a single user turn.

    Args:
        example: A dict-like record; mutated in place (``"text"`` is added).
        dialogue_template: Template whose ``messages`` are overwritten.
        is_train: Render the training prompt if True, else the inference prompt.

    Returns:
        The same ``example`` object, with ``"text"`` set.

    Raises:
        ValueError: If none of the accepted key combinations is present.
    """
    if "messages" in example and example["messages"] is not None:
        dialogue_template.messages = example["messages"]
    elif "prompt" in example and "completion" in example:
        user_turn = {"role": "user", "content": example["prompt"]}
        assistant_turn = {"role": "assistant", "content": example["completion"]}
        dialogue_template.messages = [user_turn, assistant_turn]
    elif "prompt" in example:
        user_turn = {"role": "user", "content": example["prompt"]}
        dialogue_template.messages = [user_turn]
    else:
        raise ValueError(
            f"Could not format example as dialogue! Require either `messages` or `[prompt, completion]` or `[prompt]` keys but found {list(example.keys())}"
        )
    render = dialogue_template.get_training_prompt if is_train else dialogue_template.get_inference_prompt
    example["text"] = render()
    return example