|
""" |
|
HF Chat Templates prompt strategy |
|
""" |
|
|
|
import logging |
|
from typing import Any, Dict, List, Optional |
|
|
|
from axolotl.prompt_tokenizers import PromptTokenizingStrategy |
|
from axolotl.prompters import Prompter |
|
from axolotl.utils.chat_templates import chat_templates |
|
|
|
LOG = logging.getLogger("axolotl") |
|
|
|
|
|
class ChatTemplatePrompter(Prompter): |
|
"""prompter for HF chat templates""" |
|
|
|
def __init__( |
|
self, |
|
tokenizer, |
|
chat_template=None, |
|
max_length=2048, |
|
message_field_role: str = "from", |
|
message_field_content: str = "value", |
|
roles: Optional[Dict[str, List[str]]] = None, |
|
): |
|
if roles: |
|
self.roles = {s: t for t, sources in roles.items() for s in sources} |
|
else: |
|
self.roles = { |
|
"human": "user", |
|
"user": "user", |
|
"assistant": "assistant", |
|
"gpt": "assistant", |
|
"system": "system", |
|
} |
|
self.message_field_role = message_field_role |
|
self.message_field_content = message_field_content |
|
self.tokenizer = tokenizer |
|
self.chat_template = chat_template |
|
self.max_length = max_length |
|
|
|
def build_prompt(self, conversation, add_generation_prompt=False): |
|
turns = [ |
|
{ |
|
"role": self.roles[t[self.message_field_role]], |
|
"content": t[self.message_field_content], |
|
} |
|
for t in conversation |
|
] |
|
|
|
return self.tokenizer.apply_chat_template( |
|
turns, |
|
truncation=True, |
|
max_length=self.max_length, |
|
add_generation_prompt=add_generation_prompt, |
|
chat_template=self.chat_template, |
|
) |
|
|
|
|
|
class ChatTemplateStrategy(PromptTokenizingStrategy): |
|
""" |
|
Tokenizing strategy for instruction-based prompts. |
|
""" |
|
|
|
_messages = "conversations" |
|
|
|
@property |
|
def messages(self): |
|
return self._messages |
|
|
|
@messages.setter |
|
def messages(self, messages): |
|
self._messages = messages |
|
|
|
def tokenize_prompt(self, prompt): |
|
turns = self.get_conversation_thread(prompt) |
|
prompt_ids = self.prompter.build_prompt(turns[:-1], add_generation_prompt=True) |
|
input_ids = self.prompter.build_prompt(turns) |
|
|
|
if not self.train_on_inputs: |
|
user_prompt_len = len(prompt_ids) |
|
labels = [-100] * user_prompt_len + input_ids[user_prompt_len:] |
|
else: |
|
labels = input_ids |
|
|
|
tokenized_prompt = { |
|
"input_ids": input_ids, |
|
"labels": labels, |
|
"attention_mask": [1] * len(input_ids), |
|
} |
|
|
|
return tokenized_prompt |
|
|
|
def get_conversation_thread(self, prompt): |
|
return prompt[self.messages] |
|
|
|
|
|
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): |
|
chat_template = ( |
|
ds_cfg["chat_template"] if ds_cfg and "chat_template" in ds_cfg else "chatml" |
|
) |
|
message_field_role = ( |
|
ds_cfg["message_field_role"] |
|
if ds_cfg and "message_field_role" in ds_cfg |
|
else "from" |
|
) |
|
message_field_content = ( |
|
ds_cfg["message_field_content"] |
|
if ds_cfg and "message_field_content" in ds_cfg |
|
else "value" |
|
) |
|
roles = ds_cfg["roles"] if ds_cfg and "roles" in ds_cfg else None |
|
|
|
strategy = ChatTemplateStrategy( |
|
ChatTemplatePrompter( |
|
tokenizer, |
|
chat_templates(chat_template), |
|
message_field_role=message_field_role, |
|
message_field_content=message_field_content, |
|
roles=roles, |
|
), |
|
tokenizer, |
|
cfg.train_on_inputs, |
|
cfg.sequence_len, |
|
) |
|
if ds_cfg and "field_messages" in ds_cfg and hasattr(strategy, "messages"): |
|
strategy.messages = ds_cfg["field_messages"] |
|
return strategy |
|
|