""" | |
Prompt Strategy for finetuning Llama2 chat models | |
see also https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/generation.py#L213 for ma reference implementation. | |
This implementation is based on the Vicuna PR and the fastchat repo, see also: | |
https://github.com/lm-sys/FastChat/blob/cdd7730686cb1bf9ae2b768ee171bdf7d1ff04f3/fastchat/conversation.py#L847 | |
Use dataset type: "llama2_chat" in conig.yml to use this prompt style. | |
E.g. in the config.yml: | |
``` | |
datasets: | |
- path: llama_finetune_train.jsonl | |
type: llama2_chat | |
``` | |
The dataset itself should look like this: | |
``` | |
{'conversations':[{"from": "human", "value": "Who are you?"}, {"from": "gpt", "value": "I am Vicuna"},...]} | |
``` | |
in a jsonl file. The first message should be from the human, the second from gpt. | |
For a custom system message, the first "from" can be "system" (followed by alternating "human" and "gpt" turns). | |
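For example, a conversation with a custom system message could look like this
(an illustrative sample line, not taken from a real dataset):
```
{"conversations": [{"from": "system", "value": "You are a helpful pirate."}, {"from": "human", "value": "Who are you?"}, {"from": "gpt", "value": "Arr, I am Vicuna."}]}
```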
Important: Don't use "special_tokens:" in your config.yml if you are not sure what you are doing!
"""
import logging
from dataclasses import dataclass, field
from typing import Generator, List, Sequence

from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID, SHAREGPT_ASSERTION_FAILED_ROLE

@dataclass
class Llama2ChatConversation:
    """A class that manages prompt templates and keeps all conversation history.
    copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py"""

    name: str = "llama2"
    # The system prompt
    system: str = (
        "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
        "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
        "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
        "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. "
        "If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n"
    )
    roles: Sequence[str] = ("[INST]", "[/INST]")
    messages: List[List[str]] = field(default_factory=list)
    offset: int = 0
    sep: str = " "
    sep2: str = " </s><s>"
    stop_token_ids: Sequence[int] = (2,)

    def get_prompt(self) -> str:
        """Get the prompt for generation."""
        seps = [self.sep, self.sep2]
        ret = ""
        for i, (role, message) in enumerate(self.messages):
            if (i == len(self.messages) - 1) and (role == self.roles[0]):
                # last message is from user (due to length),
                # return prompt without it for training
                return ret
            if i == 0:
                # first turn: system prompt followed directly by the user message
                ret += self.system + message.strip()
            else:
                ret += role + " " + message.strip() + seps[i % 2]
        return ret
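
    # Illustration (derived from the logic above, not from upstream docs): for
    # messages == [["[INST]", "Who are you?"], ["[/INST]", "I am Vicuna"]],
    # get_prompt() returns roughly
    #   "[INST] <<SYS>>\n...system prompt...\n<</SYS>>\n\nWho are you?[/INST] I am Vicuna </s><s>"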

    def append_message(self, role: str, message: str):
        """Append a new message."""
        self.messages.append([role, message])

class LLama2ChatTokenizingStrategy(PromptTokenizingStrategy):
    """
    Tokenizing strategy for Llama2-chat prompts in ShareGPT conversation format.
    adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # note: this overrides whatever sequence_len was passed in via the config
        self.sequence_len = 4096
        self.tokenizer.add_special_tokens({"pad_token": "<pad>"})
        # https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/main/added_tokens.json

    def tokenize_prompt(self, prompt):
        conv = next(self.prompter.build_prompt(prompt))
        conversation_str = conv.get_prompt()

        # Tokenize conversations
        input_ids = self.tokenizer(
            conversation_str,
            return_tensors="pt",
            padding="max_length",
            max_length=self.sequence_len,
            truncation=True,
        ).input_ids[0]
        target = input_ids.clone()

        # Mask targets. Only compute loss on the assistant outputs.
        sep = conv.roles[1]

        total_len = int(target.ne(self.tokenizer.pad_token_id).sum())

        # each "turn" holds one user instruction plus the assistant reply,
        # delimited by sep2 (" </s><s>")
        turns = conversation_str.split(conv.sep2)
        cur_len = 1
        target[:cur_len] = IGNORE_TOKEN_ID
        for turn in turns:
            if turn == "":
                break
            turn_len = len(self.tokenizer(turn).input_ids)

            parts = turn.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
            # "-1" is hardcoded for the LLaMA tokenizer to make the offset correct.
            instruction_len = len(self.tokenizer(parts[0]).input_ids) - 1

            # Ignore the user instructions
            target[cur_len - 1 : cur_len + instruction_len] = IGNORE_TOKEN_ID
            cur_len += turn_len + 2  # due to length of role token
        target[cur_len:] = IGNORE_TOKEN_ID

        if cur_len < self.sequence_len:
            if cur_len != total_len:
                # if the masking bookkeeping does not line up with the tokenized
                # length, mask the whole sample rather than train on bad labels
                target[:] = IGNORE_TOKEN_ID
                logging.warning(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

        attention_mask = input_ids.ne(self.tokenizer.pad_token_id).tolist()
        input_ids = input_ids.tolist()
        target = target.tolist()

        # fix for the tokenizer, which tokenizes "[" differently when it follows
        # an eos token; this follows the original llama implementation
        for i in range(2, total_len - 2):
            if input_ids[i] == 29961:
                input_ids[i] = 518
            if target[i] == 29961:
                target[i] = 518

        return {
            "input_ids": input_ids,
            "labels": target,
            "attention_mask": attention_mask,
        }

class Llama2ChatPrompter:  # pylint: disable=too-few-public-methods
    """
    A prompter that generates prompts for Llama2 models.
    """

    system_prompt = (
        "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
        "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
        "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
        "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. "
        "If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n"
    )

    def build_prompt(self, source) -> Generator[Llama2ChatConversation, None, None]:
        # see https://github.com/lm-sys/FastChat/blob/da0641e567cf93756b0978ab5a6b092e96f06240/fastchat/train/train.py#L78
        source = source["conversations"]  # fix data structure for datasets

        # if a system prompt is provided, use it instead of the default
        if source[0]["from"] == "system":
            system = f"[INST] <<SYS>>\n{source[0]['value']}\n<</SYS>>\n\n"
            source = source[1:]
        else:
            system = self.system_prompt

        conv = Llama2ChatConversation(system=system)

        if len(source) < 2:
            # If there isn't a back-and-forth conversation, ignore it
            # (this also happens when data splitting leaves empty conversations)
            raise IndexError

        roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first message if it is not from a human
            source = source[1:]

        conv.messages = []  # pylint: disable=R0801
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
            if sentence["value"]:
                conv.append_message(role, sentence["value"])
        yield conv
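
    # Illustration (assumed, not from upstream docs): for the sample jsonl line in
    # the module docstring, build_prompt yields a Llama2ChatConversation whose
    # messages are [["[INST]", "Who are you?"], ["[/INST]", "I am Vicuna"]].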

def load(tokenizer, cfg) -> LLama2ChatTokenizingStrategy:
    return LLama2ChatTokenizingStrategy(
        Llama2ChatPrompter(),
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )
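
# Minimal usage sketch (illustrative only; assumes a Llama-2 tokenizer and an
# axolotl cfg object with `train_on_inputs` and `sequence_len` set):
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
#   strategy = load(tokenizer, cfg)
#   sample = {"conversations": [{"from": "human", "value": "Who are you?"},
#                               {"from": "gpt", "value": "I am Vicuna"}]}
#   encoded = strategy.tokenize_prompt(sample)  # dict with input_ids / labels / attention_mask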