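"""Pretraining entry point for the SDLM simplex diffusion language model.

Loads a tokenizer and diffusion model, tokenizes the pretraining data, builds
token-wise simplex DDPM noise schedulers, and runs training/evaluation through
DiffusionTrainer.

Example invocation (a sketch; the exact flags are defined in .arguments.get_args):
    python -m sdlm.run_pretrain --model_name_or_path <model> --do_train --do_eval --output_dir <dir>
"""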
import logging
import os
import sys
from typing import Callable
import datasets
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainerCallback, set_seed
from transformers.trainer_callback import TrainerState
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
from .arguments import get_args
from .data.data_collator import SpanInfillingDataCollator
from .data.data_utils import load_data, tokenize_data_new
from .inference.inference_utils import evaluate_generation
from .models import get_torch_dtype, load_model
from .schedulers import TokenWiseSimplexDDPMScheduler
from .trainers.trainer_diffusion import DiffusionTrainer
from .utils import (
get_last_checkpoint_with_beaker_preemption,
is_nfs_available,
is_weka_available,
resolve_last_checkpoint_vs_resume_from_checkpoint,
set_hf_home,
set_pretraining_dataset,
)
# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
check_min_version("4.25.0")
require_version(
"datasets>=2.0.0",
"To fix: pip install -r examples/pytorch/language-modeling/requirements.txt",
)
logger = logging.getLogger(__name__)
# set environment variables
set_hf_home()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def filter_by_length(min_len: int, pad_token_id: int) -> Callable[[dict], bool]:
    """Return a picklable filter function (so the HF datasets library can hash and
    cache it) that keeps examples with at least `min_len` non-padding tokens."""
    def func(x):
        return sum(i != pad_token_id for i in x["input_ids"]) >= min_len
    return func
def get_compute_metrics(data_args, training_args, model_args):
    # Causal language model used by evaluate_generation to score generated text.
causal_model = AutoModelForCausalLM.from_pretrained(
model_args.autoregressive_eval_model,
torch_dtype=get_torch_dtype(training_args),
attn_implementation="flash_attention_2"
if model_args.use_flash_attention2
else "eager",
).to(training_args.device)
causal_tokenizer = AutoTokenizer.from_pretrained(
model_args.autoregressive_eval_model
)
is_conditional_generation = data_args.conditional_generation is not None
prefix_lm_eval = data_args.conditional_generation in [
"prefix_lm",
"ul2",
"ul2_with_unconditional",
"prefix_with_unconditional",
"ul2_variable",
]
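    # Wrap evaluate_generation so the trainer can call compute_metrics(results) directly.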
compute_metrics = lambda results: evaluate_generation( # noqa: E731
results,
data_args,
causal_model,
causal_tokenizer,
is_conditional_generation,
prefix_lm_eval=prefix_lm_eval,
skip_special_tokens=data_args.skip_special_tokens,
eval_for_all_metrics=training_args.eval_for_all_metrics,
)
return compute_metrics
# Triggers an evaluation on the first training step; useful for checking that training is working.
class EvaluateFirstStepCallback(TrainerCallback):
def on_step_end(self, args, state, control, **kwargs):
if state.global_step == 1:
control.should_evaluate = True
def main():
# parse args
model_args, data_args, training_args, diffusion_args = get_args()
set_pretraining_dataset(data_args)
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
# Log on each process the small summary:
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
    # Log training/evaluation parameters (on the main process only):
logger.info(f"Training/evaluation parameters {training_args}")
# Detecting last checkpoint.
last_checkpoint = get_last_checkpoint_with_beaker_preemption(training_args)
# Set seed before initializing model.
set_seed(training_args.seed)
# load model
tokenizer, model = load_model(
model_args, data_args, training_args, diffusion_args, logger
)
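    # The pad token id is needed below for length filtering and padding in the data collator.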
assert model.config.pad_token_id is not None
if training_args.do_train:
raw_datasets = load_data(data_args, model_args)
train_dataset = tokenize_data_new(
data_args, tokenizer, raw_datasets, training_args
)["train"]
if data_args.max_train_samples is not None:
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
train_dataset = train_dataset.select(range(max_train_samples))
if data_args.min_train_seq_length != 0:
train_dataset = train_dataset.filter(
filter_by_length(
data_args.min_train_seq_length, model.config.pad_token_id
)
)
if data_args.shuffle and data_args.streaming:
train_dataset = train_dataset.shuffle(
seed=training_args.seed, buffer_size=10_000
)
elif data_args.shuffle:
train_dataset = train_dataset.shuffle(seed=training_args.seed)
if training_args.do_eval:
        # Default eval data: a C4 validation subset, located according to the available filesystem.
if is_weka_available():
data_file_path = "/data/input/jaket/c4_subset"
elif is_nfs_available():
data_file_path = (
"/net/nfs.cirrascale/allennlp/jaket/simplex-diffusion/c4_subset"
)
else:
# yale
data_file_path = "/home/jt856/documents/simplex-diffusion/raw/c4_subset"
c4_raw_dataset = datasets.IterableDatasetDict(
{
"validation": datasets.load_dataset(
"json",
data_files=os.path.join(
data_file_path, "c4-validation.00000-of-00008.json"
),
)["train"]
}
)
c4_tokenized_datasets = tokenize_data_new(
data_args, tokenizer, c4_raw_dataset, training_args
)
eval_dataset = c4_tokenized_datasets["validation"]
if data_args.max_eval_samples is not None:
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
eval_dataset = eval_dataset.select(range(max_eval_samples))
if data_args.min_eval_seq_length != 0:
eval_dataset = eval_dataset.filter(
filter_by_length(
data_args.min_eval_seq_length, model.config.pad_token_id
),
num_proc=data_args.preprocessing_num_workers,
)
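    # Reduce eval memory: keep only the argmax token ids instead of the full logits.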
def preprocess_logits_for_metrics(logits):
return logits.argmax(dim=-1)
# Data collator
# TODO: fix lambda max_seq_length, extra_padding_ratio:
pad_to_multiple_of_8 = (
data_args.line_by_line
and training_args.fp16
and not data_args.pad_to_max_length
)
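    # The collator is built lazily per mode, so train and eval can use different infilling/context settings.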
data_collator = lambda mode: SpanInfillingDataCollator( # noqa: E731
mode=mode,
data_args=data_args,
tokenizer=tokenizer,
padding="max_length" if data_args.pad_to_max_length else True,
max_length=data_args.max_seq_length,
seed=training_args.seed,
pad_to_multiple_of=8 if pad_to_multiple_of_8 else None,
eval_context_size=data_args.eval_context_size,
)
compute_metrics = None
if training_args.do_eval and not training_args.without_compute_metrics:
# call only when necessary
compute_metrics = get_compute_metrics(data_args, training_args, model_args)
    # Init schedulers: one noise scheduler for training, plus one per inference diffusion step count.
noise_scheduler = TokenWiseSimplexDDPMScheduler(
num_train_timesteps=diffusion_args.num_diffusion_steps,
beta_schedule=diffusion_args.beta_schedule,
simplex_value=diffusion_args.simplex_value,
clip_sample=diffusion_args.clip_sample,
device=training_args.device,
multiply_factor=diffusion_args.multiply_factor,
)
inference_noise_schedulers = [
TokenWiseSimplexDDPMScheduler(
num_train_timesteps=timesteps,
beta_schedule=diffusion_args.beta_schedule,
simplex_value=diffusion_args.simplex_value,
clip_sample=diffusion_args.clip_sample,
device=training_args.device,
multiply_factor=diffusion_args.multiply_factor,
)
for timesteps in diffusion_args.num_inference_diffusion_steps
]
# Initialize our Trainer
trainer = DiffusionTrainer(
model=model,
args=training_args,
train_dataset=train_dataset if training_args.do_train else None,
eval_dataset=eval_dataset if training_args.do_eval else None,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics,
preprocess_logits_for_metrics=preprocess_logits_for_metrics
if training_args.do_eval
else None,
noise_scheduler=noise_scheduler,
diffusion_args=diffusion_args,
data_args=data_args,
inference_noise_schedulers=inference_noise_schedulers,
)
trainer.add_callback(EvaluateFirstStepCallback())
# Training
if training_args.do_train:
checkpoint = resolve_last_checkpoint_vs_resume_from_checkpoint(
last_checkpoint,
training_args.resume_from_checkpoint,
)
train_result = trainer.train(resume_from_checkpoint=checkpoint)
trainer.save_model() # Saves the tokenizer too for easy upload
metrics = train_result.metrics
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()
# Evaluation
if training_args.do_eval:
if training_args.load_states_in_eval_from_model_path:
trainer._load_from_checkpoint(model_args.model_name_or_path)
trainer.state = TrainerState.load_from_json(
os.path.join(model_args.model_name_or_path, "trainer_state.json")
)
trainer._load_rng_state(model_args.model_name_or_path)
# np.save("weights.npy", model.vocab_to_hidden_dim_embed.weight.data.numpy())
logger.info("*** Evaluate ***")
metrics = trainer.evaluate()
max_eval_samples = (
data_args.max_eval_samples
if data_args.max_eval_samples is not None
else len(eval_dataset)
)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
if __name__ == "__main__":
main()