Spaces:
Sleeping
Sleeping
import logging | |
from typing import Any, Dict | |
import pandas as pd | |
import torch | |
import llm_studio.src.datasets.text_causal_language_modeling_ds as text_causal_language_modeling_ds # noqa: [F401] | |
from llm_studio.src.datasets.conversation_chain_handler import ConversationChainHandler | |
from llm_studio.src.utils.utils import PatchedAttribute | |
logger = logging.getLogger(__name__) | |
class CustomDataset(text_causal_language_modeling_ds.CustomDataset):
    """
    Dataset for DPO (Direct Preference Optimization) training.

    The data is assumed to be in the same format as for causal language modeling,
    but an additional column with rejected answers is required.
    For chained conversations, rejected answers are equal to the normal answers
    up to the last answer. The last answers are then different.
    """

    def __init__(self, df: pd.DataFrame, cfg: Any, mode: str = "train"):
        """
        Args:
            df: input DataFrame; must contain the rejected answer column
                (and optionally a rejected prompt column).
            cfg: experiment configuration.
            mode: dataset split, e.g. "train".
        """
        assert (
            cfg.dataset.limit_chained_samples
        ), "Need to enable limit_chained_samples for dpo training"
        super().__init__(df=df, cfg=cfg, mode=mode)

        # Build a second conversation chain handler that reads the rejected
        # answer (and, if configured, rejected prompt) column instead of the
        # accepted one. PatchedAttribute temporarily swaps the column names on
        # cfg.dataset only while the handler is being constructed.
        with PatchedAttribute(
            cfg.dataset, "answer_column", cfg.dataset.rejected_answer_column
        ):
            if cfg.dataset.rejected_prompt_column != "None":
                with PatchedAttribute(
                    cfg.dataset, "prompt_column", cfg.dataset.rejected_prompt_column
                ):
                    self.conversation_chain_handler_rejected = ConversationChainHandler(
                        self.df, cfg
                    )
            else:
                # No dedicated rejected prompt column: reuse the normal prompts
                # and only swap the answers.
                self.conversation_chain_handler_rejected = ConversationChainHandler(
                    self.df, cfg
                )

    def __getitem__(self, idx: int) -> Dict:
        """Reads a single text observation.

        Returns a dict with three groups of tensors:
        - "chosen_*": encodings/labels of the accepted conversation,
        - "rejected_*": encodings/labels of the rejected conversation,
        - "prompt_*": prompt-only encodings (shared by both conversations).
        """
        chosen_sample = super().__getitem__(idx)
        keys = ["input_ids", "attention_mask", "token_type_ids", "labels"]
        prompt_keys = [
            "prompt_input_ids",
            "prompt_attention_mask",
            "prompt_token_type_ids",
        ]
        prompt_sample = {k: v for k, v in chosen_sample.items() if k in prompt_keys}
        chosen_sample = {
            f"chosen_{k}": v for k, v in chosen_sample.items() if k in keys
        }
        # Temporarily point the parent's __getitem__ at the rejected
        # conversation chains so the same tokenization path produces the
        # rejected encodings.
        with PatchedAttribute(
            self, "conversation_chain_handler", self.conversation_chain_handler_rejected
        ):
            rejected_sample = {
                f"rejected_{k}": v
                for k, v in super().__getitem__(idx).items()
                if k in keys
            }
        sample = {**chosen_sample, **rejected_sample, **prompt_sample}
        return sample

    def get_labels(self, prompt_encodings, answer_encodings):
        """
        Mask all but the last answer.

        All prompt tokens and all intermediate answers are set to -100
        (ignored by the loss); only the final answer keeps its token ids
        as labels. The result is left-padded/truncated to
        cfg.tokenizer.max_length.
        """
        # Start with everything masked: concatenate -100 placeholders shaped
        # like each (prompt, answer) pair in the conversation chain.
        labels = torch.cat(
            [
                torch.cat(
                    [
                        torch.full_like(
                            prompt_encoding,
                            fill_value=-100,
                        ),
                        torch.full_like(
                            answer_encoding,
                            fill_value=-100,
                        ),
                    ]
                )
                for prompt_encoding, answer_encoding in zip(
                    prompt_encodings, answer_encodings
                )
            ]
        ).clone()

        if len(answer_encodings[-1]):
            # empty answers would create a RuntimeError
            # Unmask only the last answer: its token ids become the labels.
            labels[-len(answer_encodings[-1]) :] = answer_encodings[-1]

        if self.cfg.dataset.add_eos_token_to_answer:
            # eos_token may be equal to pad_token. Add the label back manually.
            labels[-1] = self.tokenizer.eos_token_id

        if self.cfg.tokenizer.max_length < len(labels):
            # Keep the most recent tokens (the tail contains the last answer).
            labels = labels[-self.cfg.tokenizer.max_length :]

        # Left-pad with -100 up to max_length so labels align with
        # left-padded input_ids.
        sample = dict(labels=torch.full((self.cfg.tokenizer.max_length,), -100))
        sample["labels"][-len(labels) :] = labels
        return sample

    @classmethod
    def sanity_check(cls, df: pd.DataFrame, cfg: Any, mode: str = "train"):
        """
        Quick check whether Dataframe and configurations are correctly set.
        """
        # FIX: the original was missing @classmethod; with a bare `cls`
        # parameter, calling CustomDataset.sanity_check(df, cfg) would bind
        # the DataFrame to `cls` and break. The zero-arg super() call and the
        # `cls` name show classmethod semantics were intended.
        super().sanity_check(df=df, cfg=cfg, mode=mode)
        assert cfg.dataset.rejected_answer_column in df.columns, (
            f"Answer column {cfg.dataset.rejected_answer_column} not found in the "
            f"{mode} DataFrame."
        )