File size: 3,825 Bytes
c1a7b3d cc11c6b c1a7b3d cc11c6b c1a7b3d cc11c6b b752080 cc11c6b c1a7b3d cc11c6b c1a7b3d cc11c6b c1a7b3d cc11c6b c1a7b3d cc11c6b c1a7b3d cc11c6b c1a7b3d cc11c6b c1a7b3d cc11c6b c1a7b3d cc11c6b c1a7b3d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
"""
HF Chat Templates prompt strategy
"""
import logging
from typing import Any, Dict, List, Optional
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import Prompter
from axolotl.utils.chat_templates import chat_templates
LOG = logging.getLogger("axolotl")
class ChatTemplatePrompter(Prompter):
    """Prompter that renders conversations through a HF tokenizer chat template."""

    def __init__(
        self,
        tokenizer,
        chat_template=None,
        max_length=2048,
        message_field_role: str = "from",
        message_field_content: str = "value",
        roles: Optional[Dict[str, List[str]]] = None,
    ):
        if roles:
            # Invert the {target_role: [source_labels]} mapping into a flat
            # {source_label: target_role} lookup table.
            self.roles = {
                source: target
                for target, sources in roles.items()
                for source in sources
            }
        else:
            # Default mapping from common dataset role labels to the
            # canonical chat-template roles.
            self.roles = {
                "human": "user",
                "user": "user",
                "assistant": "assistant",
                "gpt": "assistant",
                "system": "system",
            }
        self.message_field_role = message_field_role
        self.message_field_content = message_field_content
        self.tokenizer = tokenizer
        self.chat_template = chat_template
        self.max_length = max_length

    def build_prompt(self, conversation, add_generation_prompt=False):
        """Map dataset turns to role/content dicts and apply the chat template.

        Returns whatever ``tokenizer.apply_chat_template`` produces
        (token ids, truncated to ``self.max_length``).
        """
        chat = []
        for message in conversation:
            chat.append(
                {
                    "role": self.roles[message[self.message_field_role]],
                    "content": message[self.message_field_content],
                }
            )
        return self.tokenizer.apply_chat_template(
            chat,
            truncation=True,
            max_length=self.max_length,
            add_generation_prompt=add_generation_prompt,
            chat_template=self.chat_template,
        )
class ChatTemplateStrategy(PromptTokenizingStrategy):
    """
    Tokenizing strategy for instruction-based prompts.
    """

    # Name of the dataset column holding the conversation turns;
    # overridable via the ``messages`` property (see ``load``).
    _messages = "conversations"

    @property
    def messages(self):
        return self._messages

    @messages.setter
    def messages(self, messages):
        self._messages = messages

    def tokenize_prompt(self, prompt):
        """Tokenize one example; mask the prompt portion unless training on inputs."""
        turns = self.get_conversation_thread(prompt)
        # Everything before the final turn, plus the generation prompt,
        # is the "input" portion of the example.
        prompt_ids = self.prompter.build_prompt(turns[:-1], add_generation_prompt=True)
        input_ids = self.prompter.build_prompt(turns)
        if self.train_on_inputs:
            labels = input_ids
        else:
            # Replace the prompt tokens with -100 so the loss is computed
            # only over the final (assistant) turn.
            masked_len = len(prompt_ids)
            labels = [-100] * masked_len + input_ids[masked_len:]
        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": [1] * len(input_ids),
        }

    def get_conversation_thread(self, prompt):
        # Pull the list of turns out of the configured dataset field.
        return prompt[self.messages]
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
    """
    Build a ``ChatTemplateStrategy`` from the trainer config.

    :param tokenizer: HF tokenizer whose ``apply_chat_template`` renders prompts.
    :param cfg: global axolotl config; reads ``train_on_inputs`` and ``sequence_len``.
    :param ds_cfg: optional per-dataset overrides (``chat_template``,
        ``message_field_role``, ``message_field_content``, ``roles``,
        ``field_messages``).
    :return: a configured ``ChatTemplateStrategy``.
    """
    ds_cfg = ds_cfg or {}
    chat_template = ds_cfg.get("chat_template", "chatml")
    message_field_role = ds_cfg.get("message_field_role", "from")
    message_field_content = ds_cfg.get("message_field_content", "value")
    roles = ds_cfg.get("roles")
    prompter = ChatTemplatePrompter(
        tokenizer,
        chat_templates(chat_template),
        # BUGFIX: previously max_length was left at the prompter's 2048
        # default, so truncation inside apply_chat_template ignored
        # cfg.sequence_len. Keep both in sync.
        max_length=cfg.sequence_len,
        message_field_role=message_field_role,
        message_field_content=message_field_content,
        roles=roles,
    )
    strategy = ChatTemplateStrategy(
        prompter,
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )
    if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
        # Allow the dataset config to rename the conversations column.
        strategy.messages = ds_cfg["field_messages"]
    return strategy
|