import logging
from typing import Any, Dict

import pandas as pd
import torch

import llm_studio.src.datasets.text_causal_language_modeling_ds as text_causal_language_modeling_ds  # noqa: F401
from llm_studio.src.datasets.conversation_chain_handler import ConversationChainHandler
from llm_studio.src.utils.utils import PatchedAttribute

logger = logging.getLogger(__name__)


class CustomDataset(text_causal_language_modeling_ds.CustomDataset):
    """
    Dataset for DPO optimization.

    The data is assumed to be in the same format as for causal language modeling,
    but an additional column with rejected answers is required.
    For chained conversations, the rejected answers are equal to the accepted
    answers up to the last answer; only the last answers differ.
    """

    def __init__(self, df: pd.DataFrame, cfg: Any, mode: str = "train"):
        assert (
            cfg.dataset.limit_chained_samples
        ), "Need to enable limit_chained_samples for DPO training"

        super().__init__(df=df, cfg=cfg, mode=mode)

        # Build a second conversation chain handler that reads the rejected
        # answer (and, if set, the rejected prompt) columns instead of the
        # accepted ones, by temporarily patching the column names on the cfg.
        with PatchedAttribute(
            cfg.dataset, "answer_column", cfg.dataset.rejected_answer_column
        ):
            if cfg.dataset.rejected_prompt_column != "None":
                with PatchedAttribute(
                    cfg.dataset, "prompt_column", cfg.dataset.rejected_prompt_column
                ):
                    self.conversation_chain_handler_rejected = ConversationChainHandler(
                        self.df, cfg
                    )
            else:
                self.conversation_chain_handler_rejected = ConversationChainHandler(
                    self.df, cfg
                )

    def __getitem__(self, idx: int) -> Dict:
        """Reads a single text observation."""
        chosen_sample = super().__getitem__(idx)
        keys = ["input_ids", "attention_mask", "token_type_ids", "labels"]
        prompt_keys = [
            "prompt_input_ids",
            "prompt_attention_mask",
            "prompt_token_type_ids",
        ]
        prompt_sample = {k: v for k, v in chosen_sample.items() if k in prompt_keys}
        chosen_sample = {
            f"chosen_{k}": v for k, v in chosen_sample.items() if k in keys
        }

        # Re-run the parent __getitem__ with the rejected conversation chain
        # handler swapped in to obtain the rejected counterpart of the sample.
        with PatchedAttribute(
            self, "conversation_chain_handler", self.conversation_chain_handler_rejected
        ):
            rejected_sample = {
                f"rejected_{k}": v
                for k, v in super().__getitem__(idx).items()
                if k in keys
            }

        sample = {**chosen_sample, **rejected_sample, **prompt_sample}
        return sample

    def get_labels(self, prompt_encodings, answer_encodings):
        """
        Mask all but the last answer.
        """
        labels = torch.cat(
            [
                torch.cat(
                    [
                        torch.full_like(prompt_encoding, fill_value=-100),
                        torch.full_like(answer_encoding, fill_value=-100),
                    ]
                )
                for prompt_encoding, answer_encoding in zip(
                    prompt_encodings, answer_encodings
                )
            ]
        ).clone()

        if len(answer_encodings[-1]):
            # empty answers would create a RuntimeError
            labels[-len(answer_encodings[-1]) :] = answer_encodings[-1]

        if self.cfg.dataset.add_eos_token_to_answer:
            # eos_token may be equal to pad_token. Add the label back manually.
            labels[-1] = self.tokenizer.eos_token_id

        if self.cfg.tokenizer.max_length < len(labels):
            labels = labels[-self.cfg.tokenizer.max_length :]

        sample = dict(labels=torch.full((self.cfg.tokenizer.max_length,), -100))
        sample["labels"][-len(labels) :] = labels
        return sample

    @classmethod
    def sanity_check(cls, df: pd.DataFrame, cfg: Any, mode: str = "train"):
        """
        Quick check whether DataFrame and configurations are correctly set.
        """
        super().sanity_check(df=df, cfg=cfg, mode=mode)
        assert cfg.dataset.rejected_answer_column in df.columns, (
            f"Answer column {cfg.dataset.rejected_answer_column} not found in the "
            f"{mode} DataFrame."
        )
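

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; `train_df` and `my_cfg` are
# hypothetical names, not part of this module). It assumes a config object
# shaped like the one accessed above, i.e. exposing
# cfg.dataset.limit_chained_samples, cfg.dataset.rejected_answer_column,
# cfg.dataset.rejected_prompt_column, and cfg.tokenizer.max_length, plus a
# DataFrame carrying the prompt, answer, and rejected-answer columns:
#
#     train_df = pd.read_csv("train.csv")
#     CustomDataset.sanity_check(df=train_df, cfg=my_cfg, mode="train")
#     dataset = CustomDataset(df=train_df, cfg=my_cfg, mode="train")
#     sample = dataset[0]
#     # `sample` holds chosen_* and rejected_* tensors (input_ids,
#     # attention_mask, token_type_ids, labels) alongside the shared
#     # prompt_* tensors, as assembled in __getitem__ above.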