"""Defines the utilities used during the training/infernece of diffusion language models.""" | |
import os | |
from typing import Callable, Iterable, List | |
from collections import defaultdict | |
import torch | |
import torch.nn.functional as F | |
from transformers.trainer_utils import get_last_checkpoint | |
from transformers.utils import logging | |
logger = logging.get_logger(__name__) | |


def join_texts(prefixes, sentences):
    """Joins prefixes to sentences."""
    return [f"{prefix}{sentence}" for prefix, sentence in zip(prefixes, sentences)]


def convert_to_simplex(token_ids, simplex_value, vocab_size):
    return 2 * simplex_value * F.one_hot(token_ids, vocab_size) - simplex_value
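
# Illustrative usage of convert_to_simplex (the values below are made up, not the
# training defaults): each token id becomes a vocab-sized row that is
# +simplex_value at the token's index and -simplex_value everywhere else.
#
#   token_ids = torch.tensor([[2, 0]])                              # (batch=1, seq=2)
#   simplex = convert_to_simplex(token_ids, simplex_value=5.0, vocab_size=4)
#   # simplex[0, 0] == tensor([-5., -5., 5., -5.])
#   # simplex.shape == (1, 2, 4)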


def scale(inputs, scale_value):
    return inputs / scale_value


def tokenwise_timestep(position, timestep, max_length, max_timesteps):
    """Linearly interpolates a per-token timestep from the token position, clamped to [0, max_timesteps]."""
    n_e, t_e = 2 * max_length, max_timesteps
    n_s = min(max(max_length - timestep, 0), max_length)
    t_s = min(max(timestep - max_length, 0), max_timesteps)
    token_timestep = ((t_e - t_s) / (n_e - n_s)) * (position - n_s) + t_s
    return round(min(max(0, token_timestep), max_timesteps))
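
# Worked example (illustrative numbers, not the training configuration):
#
#   tokenwise_timestep(position=3, timestep=6, max_length=4, max_timesteps=100)
#   # n_s = 0, t_s = 2, n_e = 8, t_e = 100
#   # -> round((100 - 2) / (8 - 0) * 3 + 2) = round(38.75) = 39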


def self_condition_preds(self_condition, logits, logits_projection=None):
    if self_condition in [
        "logits",
        "logits_addition",
        "logits_mean",
        "logits_max",
        "logits_multiply",
    ]:
        previous_pred = logits.detach()
    elif self_condition in [
        "logits_with_projection",
        "logits_with_projection_addition",
    ]:
        previous_pred = logits_projection(logits.detach())
    else:
        raise NotImplementedError(f"{self_condition} is not implemented.")
    return previous_pred


def mix_values_based_on_self_condition(self_condition_type, value_1, value_2):
    if self_condition_type in ["logits_with_projection_addition", "logits_addition"]:
        mixed_values = value_1 + value_2
    elif self_condition_type == "logits_mean":
        mixed_values = (value_1 + value_2) / 2.0
    elif self_condition_type == "logits_max":
        mixed_values = torch.max(value_1, value_2)
    elif self_condition_type == "logits_multiply":
        mixed_values = value_1 * value_2
    else:
        raise NotImplementedError(f"{self_condition_type} is not implemented.")
    return mixed_values
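
# Hedged sketch of how the two helpers above might compose during a
# self-conditioning step (the tensor names are illustrative, not the actual
# training loop):
#
#   previous_pred = self_condition_preds("logits_mean", logits)  # detach current logits
#   # ... the next forward pass produces new_logits ...
#   conditioned = mix_values_based_on_self_condition("logits_mean", new_logits, previous_pred)
#   # conditioned is the elementwise mean of the two logit tensors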


def lmap(f: Callable, x: Iterable) -> List:
    """list(map(f, x))"""
    return list(map(f, x))


def pad_data(data_list, tokenizer):
    return tokenizer.pad({"input_ids": data_list}, padding=True)["input_ids"]


# from the open-instruct codebase.
# NOTE: this is only used for eval and ar training
def encode_with_messages_format_v1(
    example, tokenizer, max_seq_length, return_string=False, add_generation_prompt=False
):
    """
    Here we assume each example has a 'messages' field. Each message is a dict with 'role' and 'content' fields.
    We concatenate all messages with the roles as delimiters and tokenize them together.
    """
    # filter (open orca)
    messages = [
        message
        for message in example["messages"]
        if message["role"] in {"user", "assistant"}
    ]
    # we only take the first two messages, since multi-turn is a little more complex
    messages = messages[:2]
    if len(messages) == 0:
        raise ValueError("messages field is empty.")
    # quick sanity checks
    assert messages[0]["role"] == "user"

    def _concat_messages(messages):
        message_text = ""
        for message in messages:
            if message["role"] == "user":
                message_text += "<|user|>\n" + message["content"].strip() + "\n"
            elif message["role"] == "assistant":
                message_text += (
                    "<|assistant|>\n"
                    + message["content"].strip()
                    + tokenizer.eos_token
                    + "\n"
                )
            else:
                raise ValueError("Invalid role: {}".format(message["role"]))
        return message_text

    example_text = tokenizer.bos_token + _concat_messages(messages).strip()
    if add_generation_prompt:
        example_text += "\n<|assistant|>\n"
    if return_string:
        return example_text

    tokenized_example = tokenizer(
        example_text,
        add_special_tokens=False,
        return_tensors="pt",
        max_length=max_seq_length,
        truncation=True,
    )
    input_ids = tokenized_example.input_ids
    labels = input_ids.clone()

    # mask the non-assistant part for avoiding loss
    for message_idx, message in enumerate(messages):
        if message["role"] != "assistant":
            if message_idx == 0:
                message_start_idx = 0
            else:
                message_start_idx = tokenizer(
                    _concat_messages(messages[:message_idx]),
                    return_tensors="pt",
                    max_length=max_seq_length,
                    truncation=True,
                ).input_ids.shape[1]
            if (
                message_idx < len(messages) - 1
                and messages[message_idx + 1]["role"] == "assistant"
            ):
                # here we also mask out the assistant role prefix
                messages_so_far = (
                    _concat_messages(messages[: message_idx + 1]) + "<|assistant|>\n"
                )
            else:
                messages_so_far = _concat_messages(messages[: message_idx + 1])
            message_end_idx = tokenizer(
                messages_so_far,
                return_tensors="pt",
                max_length=max_seq_length,
                truncation=True,
                add_special_tokens=False,
            ).input_ids.shape[1]
            # mask these positions with -100 (the ignore index) so they do not contribute to the loss
            labels[:, message_start_idx:message_end_idx] = -100

            if message_end_idx >= max_seq_length:
                break

    attention_mask = torch.ones_like(input_ids)
    return {
        "input_ids": input_ids.flatten(),
        "labels": labels.flatten(),
        "attention_mask": attention_mask.flatten(),
    }
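
# Example of the text layout produced by _concat_messages for a single
# user/assistant pair (the eos token is shown as </s>; the exact special tokens
# depend on the tokenizer in use):
#
#   <bos><|user|>
#   {user content}
#   <|assistant|>
#   {assistant content}</s>
#
# Tokens belonging to the user turn (and the "<|assistant|>\n" role prefix) are
# masked to -100 in `labels`, so only the assistant response contributes to the loss.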


# fixes some newline issues in v1
# NOTE: this is only used for training
def encode_with_messages_format_v2(
    messages,
    tokenizer,
    max_seq_length: int,
):
    """
    `encode_with_messages_format`, but with prefix-accumulating multiturn format
    ex) input_ids: (a1, b1, a2, b2, a3), labels: (b3)
    """
    # quick sanity checks
    if len(messages) == 0:
        raise ValueError("messages field is empty.")
    assert messages[0]["role"] == "user"
    assert messages[1]["role"] == "assistant"
    # double check tokenizer config
    assert tokenizer.add_bos_token
    assert not tokenizer.add_eos_token
    assert tokenizer.padding_side == "right"

    message_text = tokenizer.bos_token
    result = defaultdict(list)
    for message in messages:
        if message["role"] == "user":
            message_text += "<|user|>\n" + message["content"].strip() + "\n"
        elif message["role"] == "assistant":
            # tokenize message so far as context
            # add generation prompt to mask out from loss
            tokenized_context = tokenizer(
                message_text + "<|assistant|>\n",
                truncation=False,
                padding=False,
                add_special_tokens=False,
            )
            context_length = len(tokenized_context["input_ids"])
            if context_length >= max_seq_length:
                break
            # append label
            message_text += "<|assistant|>\n" + message["content"].strip()
            # tokenize full message text
            # add eos and pad
            tokenized_example = tokenizer(
                (message_text + tokenizer.eos_token).strip(),
                truncation=True,
                padding="max_length",
                max_length=max_seq_length,
                return_tensors="pt",
                add_special_tokens=False,
            )
            input_ids = tokenized_example["input_ids"].squeeze()
            labels = input_ids.clone()
            labels[:context_length] = -100
            result["input_ids"].append(input_ids)
            result["labels"].append(labels)
            # add newline for next turn
            message_text += "\n"

    if not result:
        return result
    result["input_ids"] = torch.stack(result["input_ids"])
    result["labels"] = torch.stack(result["labels"])
    return result
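
# Illustrative output of encode_with_messages_format_v2 for a two-turn
# conversation [u1, a1, u2, a2] (shapes assume both turns fit in max_seq_length):
# one row per assistant turn, each padded/truncated to max_seq_length.
#
#   out = encode_with_messages_format_v2(messages, tokenizer, max_seq_length=512)
#   # out["input_ids"].shape == (2, 512)
#   # row 0: u1 + a1, with everything up to and including "<|assistant|>\n" masked to -100 in labels
#   # row 1: u1 + a1 + u2 + a2, with everything before a2 masked to -100 in labels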


# batched version of encode_with_messages_format_v2
def encode_with_messages_format_v2_batch(
    batch,
    tokenizer,
    max_seq_length: int,
    is_tulu_pair: bool = False,
    is_tulu_multiturn: bool = False,
    is_tulu_sliding_window_multiturn: bool = False,
):
    result = {"input_ids": [], "labels": []}

    def _helper(messages):
        encoded = encode_with_messages_format_v2(
            messages=messages,
            tokenizer=tokenizer,
            max_seq_length=max_seq_length,
        )
        for key, value in encoded.items():
            result[key].append(value)

    for messages in batch["messages"]:
        # filter (open orca)
        messages = [
            message for message in messages if message["role"] in {"user", "assistant"}
        ]
        if is_tulu_multiturn:
            _helper(messages)
        elif is_tulu_sliding_window_multiturn:
            for i in range(0, len(messages) - 1, 2):
                _helper(messages[i:])
        else:
            max_message_idx = len(messages) - 1 if is_tulu_pair else 2
            for i in range(0, max_message_idx, 2):
                _helper(messages[i : i + 2])

    if result["input_ids"]:
        result["input_ids"] = torch.cat(result["input_ids"], dim=0)
        result["labels"] = torch.cat(result["labels"], dim=0)
    return result
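
# Hypothetical wiring with a Hugging Face `datasets` dataset that has a
# "messages" column (the dataset/variable names are assumptions, not from this repo):
#
#   encoded = raw_dataset.map(
#       lambda batch: encode_with_messages_format_v2_batch(batch, tokenizer, 512),
#       batched=True,
#       remove_columns=raw_dataset.column_names,
#   )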


def get_last_checkpoint_with_beaker_preemption(training_args) -> str:
    last_checkpoint = None
    if (
        os.path.isdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if (
            last_checkpoint is None
            and len(os.listdir(training_args.output_dir)) > 0
            and not training_args.beaker
        ):
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif (
            last_checkpoint is not None and training_args.resume_from_checkpoint is None
        ):
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )
    return last_checkpoint


def resolve_last_checkpoint_vs_resume_from_checkpoint(
    last_checkpoint, resume_from_checkpoint
):
    """
    Prioritizes last_checkpoint over resume_from_checkpoint.

    When a job configured with `resume_from_checkpoint` is preempted and restarted,
    it needs to start from the last checkpoint in the beaker dataset, not the checkpoint
    specified via `resume_from_checkpoint`; otherwise we lose all progress made in the previous job.
    """
    checkpoint = None
    if last_checkpoint is not None:
        checkpoint = last_checkpoint
    elif resume_from_checkpoint is not None:
        checkpoint = resume_from_checkpoint
    return checkpoint
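
# Hedged usage sketch combining the two checkpoint helpers (the surrounding
# `trainer` object is hypothetical; `Trainer.train(resume_from_checkpoint=...)`
# is the standard transformers API):
#
#   last_checkpoint = get_last_checkpoint_with_beaker_preemption(training_args)
#   checkpoint = resolve_last_checkpoint_vs_resume_from_checkpoint(
#       last_checkpoint, training_args.resume_from_checkpoint
#   )
#   trainer.train(resume_from_checkpoint=checkpoint)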


def is_weka_available() -> bool:
    # assume mount path is /data/input
    # jupiter
    return os.path.isdir("/data/input")


def is_nfs_available() -> bool:
    # allennlp, a100, pluto
    return os.path.isdir("/net/nfs.cirrascale")


def set_hf_home() -> None:
    if is_weka_available():
        os.environ["HF_HOME"] = "/data/input/jaket/.hf"
    elif is_nfs_available():
        os.environ["HF_HOME"] = "/net/nfs.cirrascale/allennlp/jaket/.hf"


def set_pretraining_dataset(data_args) -> None:
    if is_weka_available():
        data_args.dataset_name = "sdlm/data/dolma/dolma_dataset.py"
    else:
        data_args.dataset_name = "emozilla/dolma-v1_7-305B"