"""Defines the utilities used during the training/infernece of diffusion language models."""
import os
from typing import Callable, Iterable, List, Optional
from collections import defaultdict

import torch
import torch.nn.functional as F
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import logging

logger = logging.get_logger(__name__)


def join_texts(prefixes, sentences):
    """Joins prefixes to setences."""
    return [f"{prefix}{sentence}" for prefix, sentence in zip(prefixes, sentences)]


def convert_to_simplex(token_ids, simplex_value, vocab_size):
    """Maps token ids to vocab-sized simplex logits: +simplex_value at each token's index
    and -simplex_value everywhere else."""
    return 2 * simplex_value * F.one_hot(token_ids, vocab_size) - simplex_value
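

# Illustrative example (not part of the original module): with simplex_value=5.0 and
# vocab_size=4, token id 2 is mapped to a vocab-sized vector with +simplex_value at its
# index and -simplex_value elsewhere:
#   convert_to_simplex(torch.tensor([2]), 5.0, 4)  # tensor([[-5., -5., 5., -5.]])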


def scale(inputs, scale_value):
    return inputs / scale_value


def tokenwise_timestep(position, timestep, max_length, max_timesteps):
    """Maps a global timestep to a per-token timestep that grows with the token position,
    so that at a given global step earlier positions carry less noise than later ones."""
    n_e, t_e = 2 * max_length, max_timesteps
    n_s = min(max(max_length - timestep, 0), max_length)
    t_s = min(max(timestep - max_length, 0), max_timesteps)
    # linearly interpolate between (n_s, t_s) and (n_e, t_e), then clamp to [0, max_timesteps]
    token_timestep = ((t_e - t_s) / (n_e - n_s)) * (position - n_s) + t_s
    return round(min(max(0, token_timestep), max_timesteps))
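

# Illustrative example (not part of the original module): with max_length=4,
# max_timesteps=8 and global timestep 2, early positions get small (clean) timesteps
# while late positions stay close to the maximum noise level:
#   tokenwise_timestep(0, 2, 4, 8)  # -> 0
#   tokenwise_timestep(7, 2, 4, 8)  # -> 7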


def self_condition_preds(self_condition, logits, logits_projection=None):
    if self_condition in [
        "logits",
        "logits_addition",
        "logits_mean",
        "logits_max",
        "logits_multiply",
    ]:
        previous_pred = logits.detach()
    elif self_condition in [
        "logits_with_projection",
        "logits_with_projection_addition",
    ]:
        previous_pred = logits_projection(logits.detach())
    else:
        raise NotImplementedError(f"{self_condition} is not implemented.")
    return previous_pred


def mix_values_based_on_self_condition(self_condition_type, value_1, value_2):
    if self_condition_type in ["logits_with_projection_addition", "logits_addition"]:
        mixed_values = value_1 + value_2
    elif self_condition_type == "logits_mean":
        mixed_values = (value_1 + value_2) / 2.0
    elif self_condition_type == "logits_max":
        mixed_values = torch.max(value_1, value_2)
    elif self_condition_type == "logits_multiply":
        mixed_values = value_1 * value_2
    else:
        raise NotImplementedError(f"{self_condition_type} is not implemented.")
    return mixed_values
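

# Illustrative sketch of how the two helpers above relate (assumed usage, not from the
# original file; `logits` and `current_values` are placeholder names): `self_condition_preds`
# turns the previous forward pass's logits into a detached self-conditioning signal, and
# `mix_values_based_on_self_condition` combines that signal with the current values under
# the same self-condition setting, e.g.
#   prev = self_condition_preds("logits_mean", logits)
#   mixed = mix_values_based_on_self_condition("logits_mean", current_values, prev)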


def lmap(f: Callable, x: Iterable) -> List:
    """list(map(f, x))"""
    return list(map(f, x))


def pad_data(data_list, tokenizer):
    return tokenizer.pad({"input_ids": data_list}, padding=True)["input_ids"]


# from the open-instruct codebase.
# NOTE: this is only used for eval and autoregressive (AR) training
def encode_with_messages_format_v1(
    example, tokenizer, max_seq_length, return_string=False, add_generation_prompt=False
):
    """
    Here we assume each example has a 'messages' field, where each message is a dict with 'role' and 'content' fields.
    We concatenate all messages with role tags as delimiters and tokenize them together.
    """
    # filter (open orca)
    messages = [
        message
        for message in example["messages"]
        if message["role"] in {"user", "assistant"}
    ]
    # we only take the first two messages, since multi-turn is a little more complex
    messages = messages[:2]

    if len(messages) == 0:
        raise ValueError("messages field is empty.")
    # quick sanity checks
    assert messages[0]["role"] == "user"

    def _concat_messages(messages):
        message_text = ""
        for message in messages:
            if message["role"] == "user":
                message_text += "<|user|>\n" + message["content"].strip() + "\n"
            elif message["role"] == "assistant":
                message_text += (
                    "<|assistant|>\n"
                    + message["content"].strip()
                    + tokenizer.eos_token
                    + "\n"
                )
            else:
                raise ValueError("Invalid role: {}".format(message["role"]))
        return message_text

    example_text = tokenizer.bos_token + _concat_messages(messages).strip()
    if add_generation_prompt:
        example_text += "\n<|assistant|>\n"
    if return_string:
        return example_text
    tokenized_example = tokenizer(
        example_text,
        add_special_tokens=False,
        return_tensors="pt",
        max_length=max_seq_length,
        truncation=True,
    )
    input_ids = tokenized_example.input_ids
    labels = input_ids.clone()

    # mask the non-assistant part so it does not contribute to the loss
    for message_idx, message in enumerate(messages):
        if message["role"] != "assistant":
            if message_idx == 0:
                message_start_idx = 0
            else:
                message_start_idx = tokenizer(
                    _concat_messages(messages[:message_idx]),
                    return_tensors="pt",
                    max_length=max_seq_length,
                    truncation=True,
                ).input_ids.shape[1]
            if (
                message_idx < len(messages) - 1
                and messages[message_idx + 1]["role"] == "assistant"
            ):
                # also mask the "<|assistant|>\n" role tag that precedes the response
                messages_so_far = (
                    _concat_messages(messages[: message_idx + 1]) + "<|assistant|>\n"
                )
            else:
                messages_so_far = _concat_messages(messages[: message_idx + 1])
            message_end_idx = tokenizer(
                messages_so_far,
                return_tensors="pt",
                max_length=max_seq_length,
                truncation=True,
                add_special_tokens=False,
            ).input_ids.shape[1]
            # set these positions to -100 so they are ignored by the loss
            labels[:, message_start_idx:message_end_idx] = -100

            if message_end_idx >= max_seq_length:
                break

    attention_mask = torch.ones_like(input_ids)
    return {
        "input_ids": input_ids.flatten(),
        "labels": labels.flatten(),
        "attention_mask": attention_mask.flatten(),
    }
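

# Illustrative example of the v1 chat format (assumed, not from the original file): for
# messages = [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello"}],
# the concatenated text is roughly
#   <bos><|user|>\nHi\n<|assistant|>\nHello<eos>
# and only the assistant span keeps its labels; everything else is set to -100.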


# fixes some newline issues in v1
# NOTE: this is only used for training
def encode_with_messages_format_v2(
    messages,
    tokenizer,
    max_seq_length: int,
):
    """
    Like `encode_with_messages_format_v1`, but with a prefix-accumulating multiturn format:
    one training example per assistant turn, e.g. context: (user_1, asst_1, user_2, asst_2, user_3)
    with the loss taken only on the final assistant turn (asst_3).
    """
    # quick sanity checks
    if len(messages) == 0:
        raise ValueError("messages field is empty.")
    assert messages[0]["role"] == "user"
    assert messages[1]["role"] == "assistant"

    # double check tokenizer config
    assert tokenizer.add_bos_token
    assert not tokenizer.add_eos_token
    assert tokenizer.padding_side == "right"

    message_text = tokenizer.bos_token
    result = defaultdict(list)
    for message in messages:
        if message["role"] == "user":
            message_text += "<|user|>\n" + message["content"].strip() + "\n"
        elif message["role"] == "assistant":
            # tokenize message so far as context
            # add generation prompt to mask out from loss
            tokenized_context = tokenizer(
                message_text + "<|assistant|>\n",
                truncation=False,
                padding=False,
                add_special_tokens=False,
            )
            context_length = len(tokenized_context["input_ids"])

            if context_length >= max_seq_length:
                break

            # append label
            message_text += "<|assistant|>\n" + message["content"].strip()

            # tokenize full message text
            # add eos and pad
            tokenized_example = tokenizer(
                (message_text + tokenizer.eos_token).strip(),
                truncation=True,
                padding="max_length",
                max_length=max_seq_length,
                return_tensors="pt",
                add_special_tokens=False,
            )
            input_ids = tokenized_example["input_ids"].squeeze()
            labels = input_ids.clone()
            labels[:context_length] = -100
            result["input_ids"].append(input_ids)
            result["labels"].append(labels)

            # add newline for next turn
            message_text += "\n"

    if not result:
        return result
    result["input_ids"] = torch.stack(result["input_ids"])
    result["labels"] = torch.stack(result["labels"])
    return result
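

# Illustrative example (assumed, not from the original file): for a conversation whose two
# assistant turns both fit within max_seq_length, encode_with_messages_format_v2 returns
#   result["input_ids"]: shape (2, max_seq_length)
#   result["labels"]:    shape (2, max_seq_length), -100 everywhere except the target turn
# i.e. one padded training example per assistant turn, conditioned on the full preceding
# dialogue.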


# batched version of encode_with_messages_format_v2
def encode_with_messages_format_v2_batch(
    batch,
    tokenizer,
    max_seq_length: int,
    is_tulu_pair: bool = False,
    is_tulu_multiturn: bool = False,
    is_tulu_sliding_window_multiturn: bool = False,
):
    result = {"input_ids": [], "labels": []}

    def _helper(messages):
        encoded = encode_with_messages_format_v2(
            messages=messages,
            tokenizer=tokenizer,
            max_seq_length=max_seq_length,
        )
        for key, value in encoded.items():
            result[key].append(value)

    for messages in batch["messages"]:
        # filter (open orca)
        messages = [
            message for message in messages if message["role"] in {"user", "assistant"}
        ]
        if is_tulu_multiturn:
            _helper(messages)
        elif is_tulu_sliding_window_multiturn:
            for i in range(0, len(messages) - 1, 2):
                _helper(messages[i:])
        else:
            max_message_idx = len(messages) - 1 if is_tulu_pair else 2
            for i in range(0, max_message_idx, 2):
                _helper(messages[i : i + 2])
    if result["input_ids"]:
        result["input_ids"] = torch.cat(result["input_ids"], dim=0)
        result["labels"] = torch.cat(result["labels"], dim=0)
    return result
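

# Illustrative usage sketch (assumed, not from the original file; `raw_dataset` and
# `tokenizer` are placeholder names): the batched encoder is shaped for
# `datasets.Dataset.map`, e.g.
#   encoded = raw_dataset.map(
#       encode_with_messages_format_v2_batch,
#       batched=True,
#       remove_columns=raw_dataset.column_names,
#       fn_kwargs={"tokenizer": tokenizer, "max_seq_length": 512},
#   )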



def get_last_checkpoint_with_beaker_preemption(training_args) -> Optional[str]:
    """Returns the last checkpoint in `output_dir`, tolerating a non-empty output directory
    without checkpoints when running on Beaker, where preempted jobs restart into the same
    output directory."""
    last_checkpoint = None
    if (
        os.path.isdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if (
            last_checkpoint is None
            and len(os.listdir(training_args.output_dir)) > 0
            and not training_args.beaker
        ):
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif (
            last_checkpoint is not None and training_args.resume_from_checkpoint is None
        ):
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )
    return last_checkpoint


def resolve_last_checkpoint_vs_resume_from_checkpoint(
    last_checkpoint, resume_from_checkpoint
):
    """
    Prioritizes last_checkpoint over resume_from_checkpoint.
    When a job configured with `resume_from_checkpoint` is preempted and restarted,
    it needs to start from the last checkpoint in the beaker dataset, not the checkpoint
    specified via `resume_from_checkpoint`; otherwise we lose all progress made in the previous job.
    """
    checkpoint = None
    if last_checkpoint is not None:
        checkpoint = last_checkpoint
    elif resume_from_checkpoint is not None:
        checkpoint = resume_from_checkpoint
    return checkpoint
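

# Illustrative usage sketch in a training script (assumed, not from the original file;
# `trainer` and `training_args` are placeholder names):
#   last_checkpoint = get_last_checkpoint_with_beaker_preemption(training_args)
#   checkpoint = resolve_last_checkpoint_vs_resume_from_checkpoint(
#       last_checkpoint, training_args.resume_from_checkpoint
#   )
#   trainer.train(resume_from_checkpoint=checkpoint)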


def is_weka_available() -> bool:
    # assume mount path is /data/input
    # jupiter
    return os.path.isdir("/data/input")


def is_nfs_available() -> bool:
    # allennlp, a100, pluto
    return os.path.isdir("/net/nfs.cirrascale")


def set_hf_home() -> None:
    if is_weka_available():
        os.environ["HF_HOME"] = "/data/input/jaket/.hf"
    elif is_nfs_available():
        os.environ["HF_HOME"] = "/net/nfs.cirrascale/allennlp/jaket/.hf"


def set_pretraining_dataset(data_args) -> None:
    if is_weka_available():
        data_args.dataset_name = "sdlm/data/dolma/dolma_dataset.py"
    else:
        data_args.dataset_name = "emozilla/dolma-v1_7-305B"