import codecs
import collections.abc
import logging
from typing import Any, Dict, List, Tuple, Union

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset

from llm_studio.src.datasets.conversation_chain_handler import ConversationChainHandler
from llm_studio.src.datasets.text_utils import get_tokenizer

logger = logging.getLogger(__name__)


class CustomDataset(Dataset):
    """Dataset for causal language modeling."""

    def __init__(self, df: pd.DataFrame, cfg: Any, mode: str = "train"):
        """
        Args:
            df: input DataFrame
            cfg: config with all the hyperparameters
            mode: dataset mode. One of {"train", "validation"}
        """
        self.cfg = cfg
        self.mode = mode
        self.df = df.copy()

        self.tokenizer = get_tokenizer(self.cfg)
        self.conversation_chain_handler = ConversationChainHandler(self.df, cfg)

    def __len__(self) -> int:
        return len(self.conversation_chain_handler)

    def __getitem__(self, idx: int) -> Dict:
        """Reads a single text observation."""
        input_text_dict = self.conversation_chain_handler[idx]
        input_text_dict["systems"] = [
            self.parse_system(self.cfg, system) for system in input_text_dict["systems"]
        ]
        input_text_dict["prompts"] = [
            self.parse_prompt(self.cfg, prompt) for prompt in input_text_dict["prompts"]
        ]

        sample = dict()
        system_encoding, prompt_encodings, answer_encodings = self.get_encodings(
            input_text_dict=input_text_dict
        )

        input_ids = torch.cat(
            [
                torch.cat([prompt_encoding, answer_encoding])
                for prompt_encoding, answer_encoding in zip(
                    prompt_encodings, answer_encodings
                )
            ]
        )

        sample.update(self.get_labels(prompt_encodings, answer_encodings))
        sample.update(
            self.pad_tokens(
                input_ids,
                attention_mask=torch.ones_like(input_ids),
                max_length=self.cfg.tokenizer.max_length,
                pad_token_id=self.tokenizer.pad_token_id,
            )
        )
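
        # get answer encodings, right-padded/truncated to max_length_answer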
        sample.update(
            self.pad_tokens(
                answer_encodings[-1],
                attention_mask=torch.ones_like(answer_encodings[-1]),
                max_length=self.cfg.tokenizer.max_length_answer,
                pad_token_id=self.tokenizer.pad_token_id,
                direction="right",
                prefix="answer_",
            )
        )
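
        # remove the last answer from the encoding to create the prompt for inference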
        answer_encodings[-1] = torch.empty(0)
        prompt_input_ids = torch.cat(
            [
                torch.cat([prompt_encoding, answer_encoding])
                for prompt_encoding, answer_encoding in zip(
                    prompt_encodings, answer_encodings
                )
            ]
        )
        sample.update(
            self.pad_tokens(
                prompt_input_ids,
                attention_mask=torch.ones_like(prompt_input_ids),
                max_length=self.cfg.tokenizer.max_length,
                pad_token_id=self.tokenizer.pad_token_id,
                prefix="prompt_",
            )
        )
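
        # make sure the system encoding is always prepended if max_length is exceeded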
        if sample["input_ids"][0] != self.tokenizer.pad_token_id:
            sample["input_ids"][: len(system_encoding)] = system_encoding
            if self.cfg.dataset.mask_prompt_labels and "labels" in sample.keys():
                sample["labels"][: len(system_encoding)] = -100
        if sample["prompt_input_ids"][0] != self.tokenizer.pad_token_id:
            sample["prompt_input_ids"][: len(system_encoding)] = system_encoding

        return sample

    @staticmethod
    def parse_prompt(cfg: Any, prompt: str):
        prompt = (
            f"{codecs.decode(cfg.dataset.text_prompt_start, 'unicode_escape')}{prompt}"
        )
        if cfg.dataset.add_eos_token_to_prompt:
            prompt += cfg._tokenizer_eos_token
        prompt = (
            f"{prompt}"
            f"{codecs.decode(cfg.dataset.text_answer_separator, 'unicode_escape')}"
        )
        return prompt

    @staticmethod
    def parse_system(cfg: Any, system: str):
        if system == "":
            return system
        system = (
            f"{codecs.decode(cfg.dataset.text_system_start, 'unicode_escape')}{system}"
        )
        if cfg.dataset.add_eos_token_to_system:
            system += cfg._tokenizer_eos_token
        return system

    @staticmethod
    def batch_to_device(
        batch: Union[Dict, List, torch.Tensor], device: str
    ) -> Union[Dict, List, torch.Tensor, str]:
        """Sends the batch to the specified device

        Args:
            batch: input batch
            device: device to send the data to
        Returns:
            batch with the elements on the specified device
        """
        if isinstance(batch, torch.Tensor):
            return batch.to(device)
        elif isinstance(batch, (list, tuple)) and all(
            isinstance(item, str) for item in batch
        ):
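            # lists/tuples of strings (raw text) stay on the host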
            return batch
        elif isinstance(batch, collections.abc.Mapping):
            return {
                key: CustomDataset.batch_to_device(value, device)
                for key, value in batch.items()
            }
        elif isinstance(batch, collections.abc.Sequence):
            return [CustomDataset.batch_to_device(value, device) for value in batch]
        else:
            raise ValueError(f"Cannot move {type(batch)} to device.")

    @staticmethod
    def preprocess_dataframe(df: pd.DataFrame, cfg: Any, mode: str) -> pd.DataFrame:
        """
        Preprocesses the input dataframe

        Args:
            df: the full training dataframe
            cfg: config
            mode: the mode. One of {"train", "validation"}
        Returns:
            the processed dataframe
        """

        def personalize(text):
            # Replacements are applied in order: specific spellings and common
            # misspellings first, generic fallbacks ("Assistant", "LAION") last.
            name = cfg.dataset.chatbot_name
            author = cfg.dataset.chatbot_author
            replacements = [
                ("Open Assistant", name),
                ("Open-Assistant", name),
                ("open-assistant", name),
                ("OpenAssistant", name),
                ("open assistant", name),
                ("Open Assistand", name),
                ("Open Assitant", name),
                ("Open Assistent", name),
                ("Open Assisstant", name),
                ("Open Assitent", name),
                ("Open Assitiant", name),
                ("Open Assistiant", name),
                ("Open Assitan ", name + " "),
                ("Open Assistan ", name + " "),
                ("Open Asistant", name),
                ("Open Assiant", name),
                ("Assistant", name),
                ("LAION AI", author),
                ("LAION-AI", author),
                ("LAION,", author + ","),
                ("LAION.ai", author),
                ("LAION.", author + "."),
                ("LAION", author),
            ]
            for old, new in replacements:
                text = text.replace(old, new)
            return text

        if cfg.dataset.personalize:
            for prompt_col in cfg.dataset.prompt_column:
                df[prompt_col] = df[prompt_col].apply(personalize)
            df[cfg.dataset.answer_column] = df[cfg.dataset.answer_column].apply(
                personalize
            )

        return df

    def get_train_collate_fn(self):
        """
        Returns the train batch collate function for the PyTorch DataLoader.
        By default returns None, which falls back to the default PyTorch collate.
        """

        return None

    def get_validation_collate_fn(self):
        """
        Returns the validation batch collate function for the PyTorch DataLoader.
        By default returns None, which falls back to the default PyTorch collate.
        """

        return None

    def postprocess_batch_predictions(self, output: Dict) -> Dict:
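        """Decodes predicted token ids into text predictions."""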
        if "predicted_answer_ids" in output.keys():
            predicted_text = [
                self.tokenizer.decode(ids, skip_special_tokens=True).strip()
                for ids in output["predicted_answer_ids"]
            ]

            output["predicted_text"] = np.array(predicted_text)
            del output["predicted_answer_ids"]
        return output

    @staticmethod
    def clean_output(
        output: Dict,
        cfg: Any,
    ):
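        """Cuts each predicted text at the first occurrence of a stop word."""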
        output["predicted_text"] = output["predicted_text"].tolist()
        for j in range(len(output["predicted_text"])):
            curr_text = output["predicted_text"][j].strip()
            for stop_token in cfg.tokenizer._stop_words:
                if curr_text.find(stop_token) != -1:
                    curr_text = curr_text[: curr_text.find(stop_token)]
            output["predicted_text"][j] = curr_text.strip()

        return output

    def postprocess_output(self, cfg, df: pd.DataFrame, output: Dict) -> Dict:
        if cfg.prediction.metric != "Perplexity":
            output = self.clean_output(output, cfg)

        output["target_text"] = self.conversation_chain_handler.answers

        metric_func, _, _ = cfg.prediction.metric_class.get(cfg.prediction.metric)

        if "GPT" in cfg.prediction.metric:
            metrics, explanations = metric_func(
                cfg,
                output,
                df,
                raw_results=True,
            )
            output["explanations"] = explanations
        else:
            metrics = metric_func(
                cfg,
                output,
                df,
            )
        output["metrics"] = metrics

        return output

    def format_output(
        self, cfg, df: pd.DataFrame, output: Dict
    ) -> Tuple[Dict, pd.DataFrame]:
        output = {
            key: value
            for key, value in output.items()
            if key not in ["loss", "target", "losses"]
        }
        output.pop("target_text", None)

        end_conversation_ids = (
            self.conversation_chain_handler.get_conversation_end_ids()
        )

        if "predicted_text" in output.keys():
            output["predicted_text"] = np.array(output["predicted_text"])

        if "logits" in output.keys():
            output["logits"] = np.array(output["logits"].float())

        if isinstance(cfg.dataset.prompt_column, tuple):
            for col in cfg.dataset.prompt_column:
                output[col] = df.loc[end_conversation_ids, col].values
        else:
            output[cfg.dataset.prompt_column] = df.loc[
                end_conversation_ids, cfg.dataset.prompt_column
            ].values

        if "predicted_text" in output.keys():
            df[f"pred_{cfg.dataset.answer_column}"] = (
                "NO ANSWER GENERATED. "
                "ONLY LAST ANSWER OF A CONVERSATION IS PREDICTED."
            )
            df.loc[end_conversation_ids, f"pred_{cfg.dataset.answer_column}"] = output[
                "predicted_text"
            ]
        return output, df

    @classmethod
    def sanity_check(cls, df: pd.DataFrame, cfg: Any, mode: str = "train"):
        """
        Quick check whether the DataFrame and configuration are correctly set up.
        """
        if (
            cfg.dataset.parent_id_column is not None
            and cfg.dataset.parent_id_column in df.columns
            and "id" in df.columns
        ):
            assert (
                df[cfg.dataset.parent_id_column] != df["id"]
            ).all(), "Parent id column is the same as id column for some rows"
            assert (df[cfg.dataset.parent_id_column].fillna("") == "").sum() > 0, (
                "Did not find any conversation start. "
                "Please ensure that some parent ids are empty."
            )

        assert cfg.dataset.answer_column in df.columns, (
            f"Answer column {cfg.dataset.answer_column} not found in the "
            f"{mode} DataFrame."
        )
        assert df.shape[0] == df[[cfg.dataset.answer_column]].dropna().shape[0], (
            f"The {mode} DataFrame"
            f" column {cfg.dataset.answer_column}"
            " contains missing values."
        )
        if cfg.dataset.parent_id_column != "None":
            assert (
                "id" in df.columns
            ), "When using a parent column, the DataFrame requires an 'id' column."

    def get_labels(self, prompt_encodings, answer_encodings):
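        """Builds the label tensor; prompt tokens are optionally masked with -100
        and the result is left-padded with -100 to max_length."""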
        labels = torch.cat(
            [
                torch.cat([prompt_encoding, answer_encoding])
                for prompt_encoding, answer_encoding in zip(
                    prompt_encodings, answer_encodings
                )
            ]
        ).clone()

        if self.cfg.dataset.mask_prompt_labels:
            prompt_mask = torch.cat(
                [
                    torch.cat(
                        [
                            torch.ones_like(prompt_encoding),
                            torch.zeros_like(answer_encoding),
                        ]
                    )
                    for prompt_encoding, answer_encoding in zip(
                        prompt_encodings, answer_encodings
                    )
                ]
            ).to(torch.bool)
            labels.masked_fill_(prompt_mask, -100)
        if self.cfg.dataset.add_eos_token_to_answer:
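            # ensure the last label is the eos token (it may coincide with pad)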
            labels[-1] = self.tokenizer.eos_token_id
        if self.cfg.tokenizer.max_length < len(labels):
            labels = labels[-self.cfg.tokenizer.max_length :]

        sample = dict(labels=torch.full((self.cfg.tokenizer.max_length,), -100))
        sample["labels"][-len(labels) :] = labels
        return sample

    def get_encodings(self, input_text_dict: Dict[str, List[str]]):
        """
        Get encodings for a single conversation history.

        Args:
            input_text_dict: A dictionary containing the input text for a single
                sample. Contains the keys "systems", "prompts", "answers".
                System may be an empty string.
        """
        encodings = [
            self._get_sample_encoding(system, prompt, answer)
            for system, prompt, answer in zip(
                input_text_dict["systems"],
                input_text_dict["prompts"],
                input_text_dict["answers"],
            )
        ]

        if self.mode == "train":
            encodings = self.augment_data(encodings)

        system_encoding = encodings[0][0]
        prompt_encodings = [encoding[1] for encoding in encodings]
        answer_encodings = [encoding[2] for encoding in encodings]
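
        # prepend the system encoding to the first prompt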
        prompt_encodings[0] = torch.cat([system_encoding, prompt_encodings[0]])
        return (
            system_encoding,
            prompt_encodings,
            answer_encodings,
        )

    def augment_data(self, encodings):
        parent_encodings = encodings[:-1]
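        # randomly skip parent rounds during training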
        parent_encodings = [
            encoding
            for encoding in parent_encodings
            if np.random.random() > self.cfg.augmentation.skip_parent_probability
        ]

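        # randomly replace the first parent round with one from a random conversation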
        if np.random.random() < self.cfg.augmentation.random_parent_probability:
            idx = np.random.randint(len(self.conversation_chain_handler.prompts))
            parent_encodings = [
                self._get_sample_encoding(
                    self.parse_system(
                        self.cfg, self.conversation_chain_handler.systems[idx]
                    ),
                    self.parse_prompt(
                        self.cfg, self.conversation_chain_handler.prompts[idx]
                    ),
                    self.conversation_chain_handler.answers[idx],
                )
            ] + parent_encodings[1:]
        encodings = parent_encodings + [encodings[-1]]
        return encodings

    def _get_sample_encoding(self, system: str, prompt: str, answer: str) -> List:
        if len(system) > 0:
            system_encoding = self.encode(
                self.tokenizer, system, self.cfg.tokenizer.max_length_prompt, "right"
            )["input_ids"]
        else:
            system_encoding = torch.empty(0)
        prompt_encoding = self.encode(
            self.tokenizer, prompt, self.cfg.tokenizer.max_length_prompt, "left"
        )["input_ids"]
        max_length_answer = self.cfg.tokenizer.max_length_answer - int(
            self.cfg.dataset.add_eos_token_to_answer
        )
        answer_encoding = self.encode(
            self.tokenizer, answer, max_length_answer, "right"
        )["input_ids"]
        if self.cfg.dataset.add_eos_token_to_answer:
            answer_encoding = torch.cat(
                [
                    answer_encoding,
                    torch.Tensor([self.tokenizer.eos_token_id]),
                ],
                dim=0,
            )

        return [system_encoding, prompt_encoding, answer_encoding]

    @staticmethod
    def pad_tokens(
        input_ids,
        attention_mask,
        max_length,
        pad_token_id,
        direction="left",
        prefix="",
    ):
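        """Truncates input_ids and attention_mask to the last max_length tokens
        and pads shorter sequences with pad_token_id, on the left by default or
        on the right when direction="right". Output keys carry the given prefix.
        """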
        sample = {}

        if max_length < len(input_ids):
            input_ids = input_ids[-max_length:]
            attention_mask = attention_mask[-max_length:]

        if len(input_ids) > 0:
            if direction == "left":
                sample[f"{prefix}input_ids"] = torch.full((max_length,), pad_token_id)
                sample[f"{prefix}input_ids"][-len(input_ids) :] = input_ids
                sample[f"{prefix}attention_mask"] = torch.zeros(max_length)
                sample[f"{prefix}attention_mask"][-len(input_ids) :] = attention_mask
            else:
                sample[f"{prefix}input_ids"] = torch.full((max_length,), pad_token_id)
                sample[f"{prefix}input_ids"][: len(input_ids)] = input_ids
                sample[f"{prefix}attention_mask"] = torch.zeros(max_length)
                sample[f"{prefix}attention_mask"][: len(input_ids)] = attention_mask
        else:
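            # empty input: return fully padded tensors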
            sample[f"{prefix}input_ids"] = torch.full((max_length,), pad_token_id)
            sample[f"{prefix}attention_mask"] = torch.zeros(max_length)

        return sample

    @staticmethod
    def encode(tokenizer, text: str, max_length: int, truncation_side: str) -> Dict:
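        """Tokenizes text without special tokens and truncates to max_length,
        dropping tokens on the side given by truncation_side ("left" or "right").
        """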
        encodings = tokenizer(text, return_tensors="pt", add_special_tokens=False)
        encodings["input_ids"] = encodings["input_ids"][0]
        encodings["attention_mask"] = encodings["attention_mask"][0]
        if truncation_side == "right":
            encodings["input_ids"] = encodings["input_ids"][:max_length]
            encodings["attention_mask"] = encodings["attention_mask"][:max_length]
        else:
            encodings["input_ids"] = encodings["input_ids"][-max_length:]
            encodings["attention_mask"] = encodings["attention_mask"][-max_length:]
        return encodings