"""Defines the utilities used during the training/infernece of diffusion language models."""
import os
from typing import Callable, Iterable, List, Optional
from collections import defaultdict
import torch
import torch.nn.functional as F
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import logging
logger = logging.get_logger(__name__)
def join_texts(prefixes, sentences):
"""Joins prefixes to setences."""
return [f"{prefix}{sentence}" for prefix, sentence in zip(prefixes, sentences)]
def convert_to_simplex(token_ids, simplex_value, vocab_size):
return 2 * simplex_value * F.one_hot(token_ids, vocab_size) - simplex_value
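# Illustration of convert_to_simplex (toy values, assumed for the example):
# each token id maps to a vocab-sized vector equal to +simplex_value at its
# index and -simplex_value everywhere else, e.g.
#
#   >>> convert_to_simplex(torch.tensor([1, 3]), simplex_value=5.0, vocab_size=4)
#   tensor([[-5.,  5., -5., -5.],
#           [-5., -5., -5.,  5.]])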
def scale(inputs, scale_value):
return inputs / scale_value
def tokenwise_timestep(position, timestep, max_length, max_timesteps):
n_e, t_e = 2 * max_length, max_timesteps
n_s = min(max(max_length - timestep, 0), max_length)
t_s = min(max(timestep - max_length, 0), max_timesteps)
token_timestep = ((t_e - t_s) / (n_e - n_s)) * (position - n_s) + t_s
return round(min(max(0, token_timestep), max_timesteps))
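# Sketch of the token-wise schedule above (toy numbers, not the training
# config): timesteps are interpolated linearly across positions, so earlier
# positions are assigned smaller (less noisy) timesteps than later ones, e.g.
#
#   >>> [tokenwise_timestep(p, timestep=2, max_length=4, max_timesteps=100) for p in range(4)]
#   [0, 0, 0, 17]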
def self_condition_preds(self_condition, logits, logits_projection=None):
if self_condition in [
"logits",
"logits_addition",
"logits_mean",
"logits_max",
"logits_multiply",
]:
previous_pred = logits.detach()
elif self_condition in [
"logits_with_projection",
"logits_with_projection_addition",
]:
previous_pred = logits_projection(logits.detach())
else:
        raise NotImplementedError(f"{self_condition} is not implemented.")
return previous_pred
def mix_values_based_on_self_condition(self_condition_type, value_1, value_2):
if self_condition_type in ["logits_with_projection_addition", "logits_addition"]:
mixed_values = value_1 + value_2
elif self_condition_type == "logits_mean":
mixed_values = (value_1 + value_2) / 2.0
elif self_condition_type == "logits_max":
mixed_values = torch.max(value_1, value_2)
elif self_condition_type == "logits_multiply":
mixed_values = value_1 * value_2
else:
        raise NotImplementedError(f"{self_condition_type} is not implemented.")
return mixed_values
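# The two self-conditioning helpers above are meant to be used together; a
# minimal sketch of the call pattern (in practice previous_pred is cached from
# an earlier forward pass and fed back into the model):
#
#   previous_pred = self_condition_preds("logits_mean", logits)
#   mixed_logits = mix_values_based_on_self_condition("logits_mean", logits, previous_pred)
#   # with "logits_mean", mixed_logits is the average of the current logits and
#   # the detached previous prediction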
def lmap(f: Callable, x: Iterable) -> List:
"""list(map(f, x))"""
return list(map(f, x))
def pad_data(data_list, tokenizer):
return tokenizer.pad({"input_ids": data_list}, padding=True)["input_ids"]
# from the open-instruct codebase.
# NOTE: this is only used for eval and AR (autoregressive) training
def encode_with_messages_format_v1(
example, tokenizer, max_seq_length, return_string=False, add_generation_prompt=False
):
"""
    Here we assume each example has a 'messages' field. Each message is a dict with 'role' and 'content' fields.
We concatenate all messages with the roles as delimiters and tokenize them together.
"""
    # keep only user/assistant turns (e.g. drops system messages in OpenOrca)
messages = [
message
for message in example["messages"]
if message["role"] in {"user", "assistant"}
]
# we only take the first two messages, since multi-turn is a little more complex
messages = messages[:2]
if len(messages) == 0:
raise ValueError("messages field is empty.")
# quick sanity checks
assert messages[0]["role"] == "user"
def _concat_messages(messages):
message_text = ""
for message in messages:
if message["role"] == "user":
message_text += "<|user|>\n" + message["content"].strip() + "\n"
elif message["role"] == "assistant":
message_text += (
"<|assistant|>\n"
+ message["content"].strip()
+ tokenizer.eos_token
+ "\n"
)
else:
raise ValueError("Invalid role: {}".format(message["role"]))
return message_text
example_text = tokenizer.bos_token + _concat_messages(messages).strip()
if add_generation_prompt:
example_text += "\n<|assistant|>\n"
if return_string:
return example_text
tokenized_example = tokenizer(
example_text,
add_special_tokens=False,
return_tensors="pt",
max_length=max_seq_length,
truncation=True,
)
input_ids = tokenized_example.input_ids
labels = input_ids.clone()
    # mask the non-assistant parts so they are excluded from the loss
for message_idx, message in enumerate(messages):
if message["role"] != "assistant":
if message_idx == 0:
message_start_idx = 0
else:
message_start_idx = tokenizer(
_concat_messages(messages[:message_idx]),
return_tensors="pt",
max_length=max_seq_length,
truncation=True,
).input_ids.shape[1]
if (
message_idx < len(messages) - 1
and messages[message_idx + 1]["role"] == "assistant"
):
                # also mask the "<|assistant|>\n" header that precedes the response
messages_so_far = (
_concat_messages(messages[: message_idx + 1]) + "<|assistant|>\n"
)
else:
messages_so_far = _concat_messages(messages[: message_idx + 1])
message_end_idx = tokenizer(
messages_so_far,
return_tensors="pt",
max_length=max_seq_length,
truncation=True,
add_special_tokens=False,
).input_ids.shape[1]
            # mask these tokens with -100 so they are ignored by the loss
labels[:, message_start_idx:message_end_idx] = -100
if message_end_idx >= max_seq_length:
break
attention_mask = torch.ones_like(input_ids)
return {
"input_ids": input_ids.flatten(),
"labels": labels.flatten(),
"attention_mask": attention_mask.flatten(),
}
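# Illustrative usage of encode_with_messages_format_v1 (the example content and
# tokenizer are assumptions, not fixed by this module):
#
#   example = {"messages": [
#       {"role": "user", "content": "What is 2 + 2?"},
#       {"role": "assistant", "content": "4"},
#   ]}
#   encoded = encode_with_messages_format_v1(example, tokenizer, max_seq_length=512)
#   # encoded["labels"] is -100 everywhere except the assistant response tokens.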
# fixes some newline issues in v1
# NOTE: this is only used for training
def encode_with_messages_format_v2(
messages,
tokenizer,
max_seq_length: int,
):
"""
    Like `encode_with_messages_format_v1`, but with a prefix-accumulating multi-turn format,
    e.g. input_ids: (a1, b1, a2, b2, a3), labels: (b3).
"""
# quick sanity checks
if len(messages) == 0:
raise ValueError("messages field is empty.")
assert messages[0]["role"] == "user"
assert messages[1]["role"] == "assistant"
# double check tokenizer config
assert tokenizer.add_bos_token
assert not tokenizer.add_eos_token
assert tokenizer.padding_side == "right"
message_text = tokenizer.bos_token
result = defaultdict(list)
for message in messages:
if message["role"] == "user":
message_text += "<|user|>\n" + message["content"].strip() + "\n"
elif message["role"] == "assistant":
# tokenize message so far as context
# add generation prompt to mask out from loss
tokenized_context = tokenizer(
message_text + "<|assistant|>\n",
truncation=False,
padding=False,
add_special_tokens=False,
)
context_length = len(tokenized_context["input_ids"])
if context_length >= max_seq_length:
break
# append label
message_text += "<|assistant|>\n" + message["content"].strip()
# tokenize full message text
# add eos and pad
tokenized_example = tokenizer(
(message_text + tokenizer.eos_token).strip(),
truncation=True,
padding="max_length",
max_length=max_seq_length,
return_tensors="pt",
add_special_tokens=False,
)
input_ids = tokenized_example["input_ids"].squeeze()
labels = input_ids.clone()
labels[:context_length] = -100
result["input_ids"].append(input_ids)
result["labels"].append(labels)
# add newline for next turn
message_text += "\n"
if not result:
return result
result["input_ids"] = torch.stack(result["input_ids"])
result["labels"] = torch.stack(result["labels"])
return result
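# Illustrative usage of encode_with_messages_format_v2 (toy conversation,
# assumed for the example): each assistant turn yields one training row whose
# accumulated context, including the "<|assistant|>\n" header, is masked to -100.
#
#   messages = [
#       {"role": "user", "content": "Hi"},
#       {"role": "assistant", "content": "Hello!"},
#       {"role": "user", "content": "How are you?"},
#       {"role": "assistant", "content": "Doing well."},
#   ]
#   encoded = encode_with_messages_format_v2(messages, tokenizer, max_seq_length=512)
#   # encoded["input_ids"].shape == (2, 512); row 0 supervises "Hello!",
#   # row 1 supervises "Doing well." with the first turn kept as context.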
# batched version of encode_with_messages_format_v2
def encode_with_messages_format_v2_batch(
batch,
tokenizer,
max_seq_length: int,
is_tulu_pair: bool = False,
is_tulu_multiturn: bool = False,
is_tulu_sliding_window_multiturn: bool = False,
):
result = {"input_ids": [], "labels": []}
def _helper(messages):
encoded = encode_with_messages_format_v2(
messages=messages,
tokenizer=tokenizer,
max_seq_length=max_seq_length,
)
for key, value in encoded.items():
result[key].append(value)
for messages in batch["messages"]:
        # keep only user/assistant turns (e.g. drops system messages in OpenOrca)
messages = [
message for message in messages if message["role"] in {"user", "assistant"}
]
if is_tulu_multiturn:
_helper(messages)
elif is_tulu_sliding_window_multiturn:
for i in range(0, len(messages) - 1, 2):
_helper(messages[i:])
else:
max_message_idx = len(messages) - 1 if is_tulu_pair else 2
for i in range(0, max_message_idx, 2):
_helper(messages[i : i + 2])
if result["input_ids"]:
result["input_ids"] = torch.cat(result["input_ids"], dim=0)
result["labels"] = torch.cat(result["labels"], dim=0)
return result
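# The batched encoder is designed for `datasets.Dataset.map`; a hedged usage
# sketch (the dataset and settings are assumptions for illustration):
#
#   from functools import partial
#   encoded_dataset = raw_dataset.map(
#       partial(
#           encode_with_messages_format_v2_batch,
#           tokenizer=tokenizer,
#           max_seq_length=512,
#           is_tulu_multiturn=True,
#       ),
#       batched=True,
#       remove_columns=raw_dataset.column_names,
#   )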
def get_last_checkpoint_with_beaker_preemption(training_args) -> Optional[str]:
last_checkpoint = None
if (
os.path.isdir(training_args.output_dir)
and training_args.do_train
and not training_args.overwrite_output_dir
):
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if (
last_checkpoint is None
and len(os.listdir(training_args.output_dir)) > 0
and not training_args.beaker
):
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
"Use --overwrite_output_dir to overcome."
)
elif (
last_checkpoint is not None and training_args.resume_from_checkpoint is None
):
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
return last_checkpoint
def resolve_last_checkpoint_vs_resume_from_checkpoint(
last_checkpoint, resume_from_checkpoint
):
"""
Prioritizes last_checkpoint over resume_from_checkpoint.
When a job configured with `resume_from_checkpoint` is preempted and restarted,
it needs to start from the last checkpoint in the beaker dataset, not the checkpoint
specified via `resume_from_checkpoint`; otherwise we lose all progress made in the previous job.
"""
checkpoint = None
if last_checkpoint is not None:
checkpoint = last_checkpoint
elif resume_from_checkpoint is not None:
checkpoint = resume_from_checkpoint
return checkpoint
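# Typical resume logic combining the two helpers above (the trainer setup is an
# assumption for illustration):
#
#   last_checkpoint = get_last_checkpoint_with_beaker_preemption(training_args)
#   checkpoint = resolve_last_checkpoint_vs_resume_from_checkpoint(
#       last_checkpoint, training_args.resume_from_checkpoint
#   )
#   trainer.train(resume_from_checkpoint=checkpoint)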
def is_weka_available() -> bool:
# assume mount path is /data/input
# jupiter
return os.path.isdir("/data/input")
def is_nfs_available() -> bool:
# allennlp, a100, pluto
return os.path.isdir("/net/nfs.cirrascale")
def set_hf_home() -> None:
if is_weka_available():
os.environ["HF_HOME"] = "/data/input/jaket/.hf"
elif is_nfs_available():
os.environ["HF_HOME"] = "/net/nfs.cirrascale/allennlp/jaket/.hf"
def set_pretraining_dataset(data_args) -> None:
if is_weka_available():
data_args.dataset_name = "sdlm/data/dolma/dolma_dataset.py"
else:
data_args.dataset_name = "emozilla/dolma-v1_7-305B"