import logging
import os
import sys
from typing import Callable

import datasets
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainerCallback, set_seed
from transformers.trainer_callback import TrainerState
from transformers.utils import check_min_version
from transformers.utils.versions import require_version

from .arguments import get_args
from .data.data_collator import SpanInfillingDataCollator
from .data.data_utils import load_data, tokenize_data_new
from .inference.inference_utils import evaluate_generation
from .models import get_torch_dtype, load_model
from .schedulers import TokenWiseSimplexDDPMScheduler
from .trainers.trainer_diffusion import DiffusionTrainer
from .utils import (
    get_last_checkpoint_with_beaker_preemption,
    is_nfs_available,
    is_weka_available,
    resolve_last_checkpoint_vs_resume_from_checkpoint,
    set_hf_home,
    set_pretraining_dataset,
)

# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
check_min_version("4.25.0")
require_version(
    "datasets>=2.0.0",
    "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt",
)

logger = logging.getLogger(__name__)

# Set environment variables.
set_hf_home()
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def filter_by_length(min_len: int, pad_token_id: int) -> Callable[[dict], bool]:
    """Hashable filter for the HF datasets library: keep examples with at least `min_len` non-pad tokens."""

    def func(x):
        return min_len <= len([i for i in x["input_ids"] if i != pad_token_id])

    return func


def get_compute_metrics(data_args, training_args, model_args):
    # Autoregressive (causal) language model used to score generations during evaluation.
    causal_model = AutoModelForCausalLM.from_pretrained(
        model_args.autoregressive_eval_model,
        torch_dtype=get_torch_dtype(training_args),
        attn_implementation="flash_attention_2"
        if model_args.use_flash_attention2
        else "eager",
    ).to(training_args.device)
    causal_tokenizer = AutoTokenizer.from_pretrained(
        model_args.autoregressive_eval_model
    )
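    # prefix_lm_eval marks the conditional-generation modes that condition on a prefix during evaluation.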
    is_conditional_generation = data_args.conditional_generation is not None
    prefix_lm_eval = data_args.conditional_generation in [
        "prefix_lm",
        "ul2",
        "ul2_with_unconditional",
        "prefix_with_unconditional",
        "ul2_variable",
    ]
    compute_metrics = lambda results: evaluate_generation(  # noqa: E731
        results,
        data_args,
        causal_model,
        causal_tokenizer,
        is_conditional_generation,
        prefix_lm_eval=prefix_lm_eval,
        skip_special_tokens=data_args.skip_special_tokens,
        eval_for_all_metrics=training_args.eval_for_all_metrics,
    )
    return compute_metrics


# Evaluate on the first step as well; useful for checking that training is working.
class EvaluateFirstStepCallback(TrainerCallback):
    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step == 1:
            control.should_evaluate = True


def main():
    # parse args
    model_args, data_args, training_args, diffusion_args = get_args()
    set_pretraining_dataset(data_args)

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log a short summary on each process:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    # Log the full training/evaluation parameters:
    logger.info(f"Training/evaluation parameters {training_args}")

    # Detecting last checkpoint.
    last_checkpoint = get_last_checkpoint_with_beaker_preemption(training_args)

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # load model
    tokenizer, model = load_model(
        model_args, data_args, training_args, diffusion_args, logger
    )
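    # The length filters below count non-pad tokens, so a pad token id must be defined.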
    assert model.config.pad_token_id is not None

    if training_args.do_train:
        raw_datasets = load_data(data_args, model_args)
        train_dataset = tokenize_data_new(
            data_args, tokenizer, raw_datasets, training_args
        )["train"]
        if data_args.max_train_samples is not None:
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))
        if data_args.min_train_seq_length != 0:
            train_dataset = train_dataset.filter(
                filter_by_length(
                    data_args.min_train_seq_length, model.config.pad_token_id
                )
            )
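        # Streaming (iterable) datasets shuffle within a fixed-size buffer; map-style datasets shuffle globally.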
        if data_args.shuffle and data_args.streaming:
            train_dataset = train_dataset.shuffle(
                seed=training_args.seed, buffer_size=10_000
            )
        elif data_args.shuffle:
            train_dataset = train_dataset.shuffle(seed=training_args.seed)

    if training_args.do_eval:
        # default to c4
        if is_weka_available():
            data_file_path = "/data/input/jaket/c4_subset"
        elif is_nfs_available():
            data_file_path = (
                "/net/nfs.cirrascale/allennlp/jaket/simplex-diffusion/c4_subset"
            )
        else:
            # yale
            data_file_path = "/home/jt856/documents/simplex-diffusion/raw/c4_subset"
        c4_raw_dataset = datasets.IterableDatasetDict(
            {
                "validation": datasets.load_dataset(
                    "json",
                    data_files=os.path.join(
                        data_file_path, "c4-validation.00000-of-00008.json"
                    ),
                )["train"]
            }
        )
        c4_tokenized_datasets = tokenize_data_new(
            data_args, tokenizer, c4_raw_dataset, training_args
        )
        eval_dataset = c4_tokenized_datasets["validation"]
        if data_args.max_eval_samples is not None:
            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))
        if data_args.min_eval_seq_length != 0:
            eval_dataset = eval_dataset.filter(
                filter_by_length(
                    data_args.min_eval_seq_length, model.config.pad_token_id
                ),
                num_proc=data_args.preprocessing_num_workers,
            )
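
    # Keep only the argmax token ids instead of the full logits when computing metrics.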
    def preprocess_logits_for_metrics(logits):
        return logits.argmax(dim=-1)

    # Data collator
    # TODO: fix lambda max_seq_length, extra_padding_ratio:
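    # Pad to a multiple of 8 so fp16 kernels can use tensor cores (only when not already padding to max length).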
    pad_to_multiple_of_8 = (
        data_args.line_by_line
        and training_args.fp16
        and not data_args.pad_to_max_length
    )
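    # Expose the collator as a factory over `mode` so separate train/eval collators can be built.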
    data_collator = lambda mode: SpanInfillingDataCollator(  # noqa: E731
        mode=mode,
        data_args=data_args,
        tokenizer=tokenizer,
        padding="max_length" if data_args.pad_to_max_length else True,
        max_length=data_args.max_seq_length,
        seed=training_args.seed,
        pad_to_multiple_of=8 if pad_to_multiple_of_8 else None,
        eval_context_size=data_args.eval_context_size,
    )

    compute_metrics = None
    if training_args.do_eval and not training_args.without_compute_metrics:
        # call only when necessary
        compute_metrics = get_compute_metrics(data_args, training_args, model_args)

    # init schedulers
    noise_scheduler = TokenWiseSimplexDDPMScheduler(
        num_train_timesteps=diffusion_args.num_diffusion_steps,
        beta_schedule=diffusion_args.beta_schedule,
        simplex_value=diffusion_args.simplex_value,
        clip_sample=diffusion_args.clip_sample,
        device=training_args.device,
        multiply_factor=diffusion_args.multiply_factor,
    )
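    # Build one inference scheduler per requested number of inference diffusion steps.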
    inference_noise_schedulers = [
        TokenWiseSimplexDDPMScheduler(
            num_train_timesteps=timesteps,
            beta_schedule=diffusion_args.beta_schedule,
            simplex_value=diffusion_args.simplex_value,
            clip_sample=diffusion_args.clip_sample,
            device=training_args.device,
            multiply_factor=diffusion_args.multiply_factor,
        )
        for timesteps in diffusion_args.num_inference_diffusion_steps
    ]

    # Initialize our Trainer
    trainer = DiffusionTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics
        if training_args.do_eval
        else None,
        noise_scheduler=noise_scheduler,
        diffusion_args=diffusion_args,
        data_args=data_args,
        inference_noise_schedulers=inference_noise_schedulers,
    )
    trainer.add_callback(EvaluateFirstStepCallback())

    # Training
    if training_args.do_train:
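        # Decide whether to resume from the auto-detected last checkpoint or an explicit resume_from_checkpoint path.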
        checkpoint = resolve_last_checkpoint_vs_resume_from_checkpoint(
            last_checkpoint,
            training_args.resume_from_checkpoint,
        )
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()  # Saves the tokenizer too for easy upload
        metrics = train_result.metrics
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    if training_args.do_eval:
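        # Optionally restore model weights, trainer state, and RNG state from model_name_or_path before evaluating.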
        if training_args.load_states_in_eval_from_model_path:
            trainer._load_from_checkpoint(model_args.model_name_or_path)
            trainer.state = TrainerState.load_from_json(
                os.path.join(model_args.model_name_or_path, "trainer_state.json")
            )
            trainer._load_rng_state(model_args.model_name_or_path)
        # np.save("weights.npy", model.vocab_to_hidden_dim_embed.weight.data.numpy())
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate()
        max_eval_samples = (
            data_args.max_eval_samples
            if data_args.max_eval_samples is not None
            else len(eval_dataset)
        )
        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)


if __name__ == "__main__":
    main()