# coding=utf-8
import json
import os
from dataclasses import asdict, dataclass, replace
from pathlib import Path
from typing import Any, Dict, List, Optional, Type, TypeVar, Union

from huggingface_hub import ModelHubMixin, hf_hub_download


# Generic variable that is either ModelHubMixin or a subclass thereof
T = TypeVar("T", bound="ModelHubMixin")

TEMPLATE_FILENAME = "dialogue_template.json"
IGNORE_INDEX = -100


@dataclass
class DialogueTemplate(ModelHubMixin):
    """Converts all turns of a dialogue between a user and assistant to a standardized format.

    A rendered prompt starts with a system turn built from ``system``, followed by one
    turn per entry in ``messages``. Every turn is the role token, a newline, the text,
    the end token and a final newline. Templates can be saved to and loaded from a
    local directory or the Hugging Face Hub through ``ModelHubMixin``.
    """

    system: str
    # Dicts read via their "role" and "content" keys; None is normalized to [] below.
    messages: Optional[List[Dict[str, str]]] = None
    system_token: str = "<|system|>"
    user_token: str = "<|user|>"
    assistant_token: str = "<|assistant|>"
    end_token: str = "<|end|>"

    def __post_init__(self):
        """Ensure that messages is never None."""
        if self.messages is None:
            self.messages = []

    def _require_messages(self) -> None:
        """Raise ValueError when there is nothing to render."""
        if not self.messages:
            raise ValueError("Dialogue template must have at least one message.")

    def _render_turns(self) -> str:
        """Render the system turn followed by one turn per message.

        Raises:
            ValueError: if ``messages`` is empty.
        """
        self._require_messages()
        parts = [self.system_token, "\n", self.system, self.end_token, "\n"]
        for message in self.messages:
            # Any role other than "user" (including e.g. "system") is rendered as an
            # assistant turn; this mirrors the historical behavior of this template.
            if message["role"] == "user":
                role_token = self.user_token
            else:
                role_token = self.assistant_token
            parts.extend((role_token, "\n", message["content"], self.end_token, "\n"))
        return "".join(parts)

    def get_training_prompt(self) -> str:
        """Return the full dialogue, every turn closed by the end token.

        Raises:
            ValueError: if ``messages`` is empty.
        """
        return self._render_turns()

    def get_inference_prompt(self) -> str:
        """Return the dialogue followed by an open assistant turn for generation.

        Raises:
            ValueError: if ``messages`` is empty.
        """
        return self._render_turns() + self.assistant_token + "\n"

    def get_dialogue(self) -> str:
        """Return the messages in a plain "Human:" / "Assistant:" transcript format.

        The system prompt and special tokens are not included.

        Raises:
            ValueError: if ``messages`` is empty.
        """
        self._require_messages()
        return "".join(
            ("\n\nHuman: " if message["role"] == "user" else "\n\nAssistant: ") + message["content"]
            for message in self.messages
        )

    def get_special_tokens(self) -> List[str]:
        """Return the four special tokens used by this template."""
        return [self.system_token, self.user_token, self.assistant_token, self.end_token]

    def copy(self) -> "DialogueTemplate":
        """Return an independent copy of this template.

        The message list and each message dict are copied. Mutating the copy's messages
        therefore never leaks back into this instance. This matters because
        ``get_dialogue_template`` hands out copies of shared module-level templates.
        ``dataclasses.replace`` also preserves the concrete (sub)class and all fields.
        """
        return replace(self, messages=[dict(message) for message in self.messages])

    def to_dict(self) -> Dict[str, Any]:
        """Return the template's fields as a plain (deep-copied) dictionary."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]):
        """Build a template from a dictionary, falling back to defaults for missing keys.

        Uses ``cls`` so that subclasses round-trip to their own type.
        """
        return cls(
            system=data.get("system", ""),
            messages=data.get("messages", None),
            system_token=data.get("system_token", "<|system|>"),
            user_token=data.get("user_token", "<|user|>"),
            assistant_token=data.get("assistant_token", "<|assistant|>"),
            end_token=data.get("end_token", "<|end|>"),
        )

    def _save_pretrained(self, save_directory: Union[str, Path]) -> None:
        """Write the template as JSON to ``save_directory / TEMPLATE_FILENAME``.

        Missing parent directories are created.
        """
        save_directory = Path(save_directory)
        save_directory.mkdir(parents=True, exist_ok=True)
        # Use the same filename constant that _from_pretrained reads back.
        with open(save_directory / TEMPLATE_FILENAME, "w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, indent=2)

    @classmethod
    def _from_pretrained(
        cls: Type[T],
        *,
        model_id: str,
        revision: Optional[str],
        cache_dir: Optional[Union[str, Path]],
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: bool,
        local_files_only: bool,
        token: Optional[Union[str, bool]],
        **model_kwargs,
    ) -> T:
        """Loads the dialogue template from a local directory or the Huggingface Hub.

        If ``model_id`` is an existing directory, ``TEMPLATE_FILENAME`` is read from it.
        Otherwise ``model_id`` is treated as a Hub repo id and the file is downloaded
        (``revision`` defaults to "main"). ``model_kwargs`` are accepted but unused.
        """
        if os.path.isdir(model_id):
            print("Loading dialogue template from local directory")
            template_file = os.path.join(model_id, TEMPLATE_FILENAME)
        else:
            template_file = hf_hub_download(
                repo_id=model_id,
                filename=TEMPLATE_FILENAME,
                revision=revision or "main",
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
            )
        # Explicit encoding: the locale default may not be UTF-8 (e.g. on Windows).
        with open(template_file, "r", encoding="utf-8") as f:
            data = json.load(f)
        return cls.from_dict(data=data)


# Default template
default_template = DialogueTemplate(
    system="Below is a dialogue between a human user and an AI assistant. The assistant is happy to help with almost anything, and will do its best to understand exactly what is needed.",
)

# Supporting other templates
no_system_template = DialogueTemplate(system="")
alpaca_template = DialogueTemplate(
    system="Below is an instruction that describes a task. Write a response that appropriately completes the request.",
    user_token="### Instruction:",
    assistant_token="### Response:",
)

SUPPORTED_DIALOGUE_TEMPLATES = {
    "default": default_template,
    "no_system": no_system_template,
    "alpaca": alpaca_template,
}


def get_dialogue_template(template: str) -> DialogueTemplate:
    """Return a fresh copy of the named built-in template.

    A copy is returned so callers can set ``messages`` without touching the shared
    module-level instance.

    Raises:
        ValueError: if ``template`` is not a key of SUPPORTED_DIALOGUE_TEMPLATES.
    """
    if template not in SUPPORTED_DIALOGUE_TEMPLATES:
        raise ValueError(f"Template {template} is not supported!")
    return SUPPORTED_DIALOGUE_TEMPLATES[template].copy()


def prepare_dialogue(
    example: Dict[str, Any],
    dialogue_template: DialogueTemplate,
    is_train: bool = True,
) -> Dict[str, Any]:
    """Render ``example`` with ``dialogue_template`` and store the result in ``example["text"]``.

    The messages are taken from the first source present in ``example``:
    - a non-None ``messages`` entry, used as-is;
    - ``prompt`` plus ``completion``, turned into a user turn followed by an assistant turn;
    - ``prompt`` alone, turned into a single user turn.

    Side effects: overwrites ``dialogue_template.messages`` and ``example["text"]``.

    Args:
        example: mapping holding the dialogue fields described above.
        dialogue_template: template used for rendering (mutated, see above).
        is_train: render the closed training prompt if True, else the inference prompt
            ending in an open assistant turn.

    Returns:
        The same ``example`` object, with ``"text"`` set.

    Raises:
        ValueError: if none of the supported keys are present, or the messages are empty.
    """
    if "messages" in example and example["messages"] is not None:
        dialogue_template.messages = example["messages"]
    elif "prompt" in example and "completion" in example:
        dialogue_template.messages = [
            {"role": "user", "content": example["prompt"]},
            {"role": "assistant", "content": example["completion"]},
        ]
    elif "prompt" in example:
        dialogue_template.messages = [{"role": "user", "content": example["prompt"]}]
    else:
        raise ValueError(
            f"Could not format example as dialogue! Require either `messages` or `[prompt, completion]` or `[prompt]` keys but found {list(example.keys())}"
        )
    if is_train:
        example["text"] = dialogue_template.get_training_prompt()
    else:
        example["text"] = dialogue_template.get_inference_prompt()
    return example