import logging
from typing import Any, Dict

import pandas as pd
import torch

import llm_studio.src.datasets.text_causal_language_modeling_ds as text_causal_language_modeling_ds  # noqa: F401
from llm_studio.src.datasets.conversation_chain_handler import ConversationChainHandler
from llm_studio.src.utils.utils import PatchedAttribute

logger = logging.getLogger(__name__)


class CustomDataset(text_causal_language_modeling_ds.CustomDataset):
    """
    Dataset for DPO optimization.
    The data is assumed to be in the same format as for causal language modeling,
    but an additional column with rejected answers is required.
    For chained conversations, the rejected answers are equal to the normal
    answers up to the last answer; only the last answers differ.
    """

    def __init__(self, df: pd.DataFrame, cfg: Any, mode: str = "train"):
        assert (
            cfg.dataset.limit_chained_samples
        ), "Need to enable limit_chained_samples for dpo training"
        super().__init__(df=df, cfg=cfg, mode=mode)

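        # Temporarily point cfg.dataset.answer_column (and, if configured, the
        # prompt column) at the rejected columns so ConversationChainHandler
        # builds the rejected conversation chains from the same DataFrame.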
        with PatchedAttribute(
            cfg.dataset, "answer_column", cfg.dataset.rejected_answer_column
        ):
            if cfg.dataset.rejected_prompt_column != "None":
                with PatchedAttribute(
                    cfg.dataset, "prompt_column", cfg.dataset.rejected_prompt_column
                ):
                    self.conversation_chain_handler_rejected = ConversationChainHandler(
                        self.df, cfg
                    )
            else:
                self.conversation_chain_handler_rejected = ConversationChainHandler(
                    self.df, cfg
                )

    def __getitem__(self, idx: int) -> Dict:
        """Reads one observation: chosen, rejected and prompt encodings."""
        chosen_sample = super().__getitem__(idx)
        keys = ["input_ids", "attention_mask", "token_type_ids", "labels"]
        prompt_keys = [
            "prompt_input_ids",
            "prompt_attention_mask",
            "prompt_token_type_ids",
        ]
        prompt_sample = {k: v for k, v in chosen_sample.items() if k in prompt_keys}
        chosen_sample = {
            f"chosen_{k}": v for k, v in chosen_sample.items() if k in keys
        }

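        # Swap in the rejected conversation chains so that the parent
        # __getitem__ encodes the rejected variant of the same conversation.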
        with PatchedAttribute(
            self, "conversation_chain_handler", self.conversation_chain_handler_rejected
        ):
            rejected_sample = {
                f"rejected_{k}": v
                for k, v in super().__getitem__(idx).items()
                if k in keys
            }

        sample = {**chosen_sample, **rejected_sample, **prompt_sample}
        return sample
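
    # Sketch of one returned item (shapes depend on the tokenizer settings;
    # token_type_ids keys only appear when the tokenizer produces them):
    #   {
    #       "chosen_input_ids": ..., "chosen_attention_mask": ..., "chosen_labels": ...,
    #       "rejected_input_ids": ..., "rejected_attention_mask": ..., "rejected_labels": ...,
    #       "prompt_input_ids": ..., "prompt_attention_mask": ...,
    #   }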

    def get_labels(self, prompt_encodings, answer_encodings):
        """
        Mask all but the last answer: prompts and intermediate answers are set
        to -100 so that only the final answer contributes to the loss.
        """
        labels = torch.cat(
            [
                torch.cat(
                    [
                        torch.full_like(
                            prompt_encoding,
                            fill_value=-100,
                        ),
                        torch.full_like(
                            answer_encoding,
                            fill_value=-100,
                        ),
                    ]
                )
                for prompt_encoding, answer_encoding in zip(
                    prompt_encodings, answer_encodings
                )
            ]
        ).clone()

        if len(answer_encodings[-1]):
            # guard: an empty final answer would make the slice below raise
            # a RuntimeError
            labels[-len(answer_encodings[-1]) :] = answer_encodings[-1]

        if self.cfg.dataset.add_eos_token_to_answer:
            # eos_token may be equal to pad_token. Add the label back manually.
            labels[-1] = self.tokenizer.eos_token_id
        if self.cfg.tokenizer.max_length < len(labels):
            labels = labels[-self.cfg.tokenizer.max_length :]

        sample = dict(labels=torch.full((self.cfg.tokenizer.max_length,), -100))
        sample["labels"][-len(labels) :] = labels
        return sample
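
    # Worked sketch: for prompt encodings [p1, p2] and answer encodings
    # [a1, a2] of one chained conversation, the labels become
    #   [-100] * len(p1) + [-100] * len(a1) + [-100] * len(p2) + list(a2)
    # truncated and then left-padded with -100 to cfg.tokenizer.max_length,
    # so only the final answer contributes to the loss.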

    @classmethod
    def sanity_check(cls, df: pd.DataFrame, cfg: Any, mode: str = "train"):
        """
        Quick check that the DataFrame and configuration are set up correctly.
        """
        super().sanity_check(df=df, cfg=cfg, mode=mode)
        assert cfg.dataset.rejected_answer_column in df.columns, (
            f"Rejected answer column {cfg.dataset.rejected_answer_column} not found "
            f"in the {mode} DataFrame."
        )
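

# A minimal usage sketch (df and cfg are hypothetical here; in LLM Studio they
# come from the experiment configuration):
#
#   dataset = CustomDataset(df=df, cfg=cfg, mode="train")
#   sample = dataset[0]
#   assert "chosen_input_ids" in sample and "rejected_input_ids" in sample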