File size: 6,301 Bytes
2ea70eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7803f09
2ea70eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
"""chatml prompt tokenization strategy for ORPO"""
from typing import Any, Dict, Generator, List, Optional, Tuple

from pydantic import BaseModel

from axolotl.prompt_tokenizers import IGNORE_INDEX, PromptTokenizingStrategy
from axolotl.prompters import Prompter
from axolotl.utils.chat_templates import chat_templates


class Message(BaseModel):
    """A single conversation turn.

    ``role`` is matched against "system", "user" and "assistant" by
    ORPOPrompter to decide how each turn is rendered. ``label`` is populated by
    ORPODatasetParsingStrategy (True for the assistant response, False for
    system/user turns). The visible code never reads ``label`` back; it is
    only carried along in the ``model_dump()`` dicts handed to the chat template.
    """

    role: str  # e.g. "system", "user", "assistant"
    content: str  # raw text of the turn
    label: Optional[bool] = None  # trainable-turn marker; see class docstring


class MessageList(BaseModel):
    """A whole conversation: the ordered turns rendered by ORPOPrompter."""

    # Turns in conversation order; the prompter renders them cumulatively.
    messages: List[Message]


def load(
    tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, **kwargs
):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    Build the ORPO tokenizing strategy for rows with system/prompt/chosen/rejected.

    The "chatml" template is used unless ``ds_cfg`` carries a ``chat_template``
    entry. That entry is first resolved as a registered template name via
    ``chat_templates``; if the lookup raises ValueError, the configured value is
    used verbatim as the template string. The resolved template is also stored
    on ``tokenizer.chat_template``.
    """

    template = chat_templates("chatml")
    if ds_cfg and "chat_template" in ds_cfg:
        configured = ds_cfg["chat_template"]
        try:
            template = chat_templates(configured)
        except ValueError:
            # Not a registered name: treat the configured value as raw Jinja.
            template = configured
    tokenizer.chat_template = template

    prompter = ORPOPrompter(template, tokenizer)
    return ORPOTokenizingStrategy(
        prompter,
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
        dataset_parser=ORPODatasetParsingStrategy(),
    )


class ORPODatasetParsingStrategy:
    """Strategy to parse a chosen/rejected preference row into MessageLists.

    Each row is read as: optional ``system`` text, a ``prompt`` string, and
    ``chosen`` / ``rejected`` entries from which the response is taken as
    ``row[key][1]["content"]``. NOTE(review): the ``[1]`` index presumably
    means each entry is a two-message [user, assistant] list; confirm
    against the dataset schema.
    """

    def get_chosen_conversation_thread(self, prompt) -> MessageList:
        """Return the system/user/assistant thread ending in the chosen response."""
        return self._get_conversation_thread(prompt, "chosen")

    def get_rejected_conversation_thread(self, prompt) -> MessageList:
        """Return the system/user/assistant thread ending in the rejected response."""
        return self._get_conversation_thread(prompt, "rejected")

    @staticmethod
    def _get_conversation_thread(prompt, response_key: str) -> MessageList:
        """Build a thread whose assistant turn is ``prompt[response_key][1]``.

        Shared by the chosen and rejected builders so both threads always have
        an identical system/user prefix and differ only in the response.
        """
        messages: List[Message] = []
        # An empty or missing system prompt is skipped entirely.
        if system := prompt.get("system", None):
            messages.append(Message(role="system", content=system, label=False))
        messages.append(Message(role="user", content=prompt["prompt"], label=False))
        messages.append(
            Message(
                role="assistant",
                content=prompt[response_key][1]["content"],
                label=True,
            )
        )
        return MessageList(messages=messages)


class ORPOTokenizingStrategy(PromptTokenizingStrategy):
    """
    Tokenize a preference row into paired chosen/rejected sequences for ORPO.

    ``tokenize_prompt`` returns a dict with keys:
    rejected_input_ids
    input_ids
    rejected_attention_mask
    attention_mask
    rejected_labels
    labels
    prompt_attention_mask
    """

    def __init__(
        self,
        *args,
        dataset_parser=None,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # Converts a raw row into chosen/rejected MessageLists.
        self.dataset_parser = dataset_parser

    def _tokenize_message_list(self, message_list) -> Tuple[List[int], List[int], int]:
        """Tokenize one conversation thread.

        Returns ``(input_ids, labels, prompt_len)``:
        - ``labels`` copies ``input_ids`` for parts the prompter marks as
          trainable and uses IGNORE_INDEX everywhere else.
        - ``prompt_len`` is the token count up to the end of the last
          non-trainable part.
        """
        input_ids: List[int] = []
        labels: List[int] = []
        prompt_len = 0
        for part, label in self.prompter.build_prompt(message_list):
            if not part:
                continue
            # The prompter yields the conversation rendered cumulatively, so
            # the whole text is re-encoded and only the tokens past what is
            # already held are appended. NOTE(review): this assumes the
            # tokenizer encodes the previous rendering as a prefix of the new one.
            _input_ids = self.tokenizer.encode(part, add_special_tokens=False)
            prev_idx = len(input_ids)
            input_ids += _input_ids[prev_idx:]
            if label:
                labels += input_ids[prev_idx:]
            else:
                labels += [IGNORE_INDEX] * (len(input_ids) - prev_idx)
                prompt_len = len(input_ids)
        return input_ids, labels, prompt_len

    def tokenize_prompt(self, prompt):
        """Tokenize the rejected and chosen threads of ``prompt``."""
        rejected_input_ids, rejected_labels, _ = self._tokenize_message_list(
            self.dataset_parser.get_rejected_conversation_thread(prompt)
        )
        # prompt_len comes from the chosen thread, so prompt_attention_mask is
        # always aligned with (and no longer than) the chosen ``labels``.
        input_ids, labels, prompt_len = self._tokenize_message_list(
            self.dataset_parser.get_chosen_conversation_thread(prompt)
        )

        return {
            "rejected_input_ids": rejected_input_ids,
            "rejected_labels": rejected_labels,
            "rejected_attention_mask": [1] * len(rejected_labels),
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": [1] * len(labels),
            "prompt_attention_mask": [1] * prompt_len
            + [0] * (len(labels) - prompt_len),
        }


class ORPOPrompter(Prompter):
    """Single-turn prompter for ORPO.

    Renders a MessageList cumulatively with ``tokenizer.apply_chat_template``.
    After each recognised turn it yields ``(rendered_text, is_trainable)``,
    where the text covers every turn seen so far.
    """

    def __init__(self, chat_template, tokenizer):
        self.chat_template = chat_template
        self.tokenizer = tokenizer

    def build_prompt(
        self,
        message_list: MessageList,
    ) -> Generator[Tuple[str, bool], None, None]:
        # role -> (add_generation_prompt, is_trainable); other roles are still
        # kept in the history but produce no output of their own.
        render_options = {
            "system": (False, False),
            "user": (True, False),
            "assistant": (False, True),
        }
        history = []
        for msg in message_list.messages:
            history.append(msg.model_dump())
            options = render_options.get(msg.role)
            if options is None:
                continue
            add_generation_prompt, is_trainable = options
            rendered = self.tokenizer.apply_chat_template(
                history,
                add_generation_prompt=add_generation_prompt,
                chat_template=self.chat_template,
                tokenize=False,
            )
            yield rendered, is_trainable