# run_clm.py
import logging
import os
import sys

import datasets
import transformers
from transformers import (
    Trainer,
    default_data_collator,
    is_torch_tpu_available,
    set_seed,
)
from transformers.trainer_callback import TrainerState
from transformers.utils import check_min_version
from transformers.utils.versions import require_version

from sdlm.run_pretrain import filter_by_length

from .arguments import get_args
from .data.data_utils import load_data, tokenize_data_new
from .models import load_model
from .utils import (
    get_last_checkpoint_with_beaker_preemption,
    is_nfs_available,
    is_weka_available,
    resolve_last_checkpoint_vs_resume_from_checkpoint,
    set_hf_home,
    set_pretraining_dataset,
)

# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
check_min_version("4.25.0")
require_version(
    "datasets>=2.0.0",
    "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt",
)

logger = logging.getLogger(__name__)

# Set environment variables.
set_hf_home()
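# Disabling tokenizer parallelism silences the huggingface/tokenizers fork warning
# ("The current process just got forked...") when DataLoader workers fork after the
# tokenizer has already been used.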
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def main():
    # Parse args.
    model_args, data_args, training_args, diffusion_args = get_args()
    set_pretraining_dataset(data_args)

    # Setup logging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log a short summary on each process:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bit training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Detect the last checkpoint.
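    # (get_last_checkpoint_with_beaker_preemption is assumed, per its name, to wrap
    # transformers' get_last_checkpoint with handling for Beaker preemption/restarts.)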
    last_checkpoint = get_last_checkpoint_with_beaker_preemption(training_args)

    # Load the tokenizer and model.
    tokenizer, model = load_model(
        model_args, data_args, training_args, diffusion_args, logger
    )
    assert model.config.pad_token_id is not None

    # Set the seed for reproducible shuffling and training.
    set_seed(training_args.seed)

    if training_args.do_train:
        raw_datasets = load_data(data_args, model_args)
        train_dataset = tokenize_data_new(
            data_args, tokenizer, raw_datasets, training_args
        )["train"]
        if data_args.max_train_samples is not None:
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))
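        # filter_by_length (imported from sdlm.run_pretrain) is assumed to return a
        # predicate that keeps only examples whose non-padding length (measured
        # against pad_token_id) meets the given minimum.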
        if data_args.min_train_seq_length != 0:
            train_dataset = train_dataset.filter(
                filter_by_length(
                    data_args.min_train_seq_length, model.config.pad_token_id
                )
            )
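        # Streamed datasets are IterableDatasets, so shuffling draws from a fixed-size
        # buffer (10,000 examples here) rather than computing a full global permutation.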
        if data_args.shuffle and data_args.streaming:
            train_dataset = train_dataset.shuffle(
                seed=training_args.seed, buffer_size=10_000
            )
        elif data_args.shuffle:
            train_dataset = train_dataset.shuffle(seed=training_args.seed)
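        # Causal LM models in transformers shift labels internally when computing the
        # loss, so the labels column can be a verbatim copy of input_ids, e.g.
        #   {"input_ids": [5, 6, 7]} -> {"input_ids": [5, 6, 7], "labels": [5, 6, 7]}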
        # NOTE: modifications for clm
        train_dataset = train_dataset.map(
            lambda x: {**x, "labels": x["input_ids"]},
            remove_columns=["special_tokens_mask"],
        )

    if training_args.do_eval:
        # Evaluation defaults to a fixed C4 validation shard; pick whichever copy is
        # visible on this host (Weka, NFS, or the local Yale path).
        if is_weka_available():
            data_file_path = "/data/input/jaket/c4_subset"
        elif is_nfs_available():
            data_file_path = (
                "/net/nfs.cirrascale/allennlp/jaket/simplex-diffusion/c4_subset"
            )
        else:
            data_file_path = "/home/jt856/documents/simplex-diffusion/raw/c4_subset"
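        # load_dataset("json", ...) exposes file-based data under a "train" split by
        # default, so the shard is re-keyed as "validation" here.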
        c4_raw_dataset = datasets.IterableDatasetDict(
            {
                "validation": datasets.load_dataset(
                    "json",
                    data_files=os.path.join(
                        data_file_path, "c4-validation.00000-of-00008.json"
                    ),
                )["train"]
            }
        )
        c4_tokenized_datasets = tokenize_data_new(
            data_args, tokenizer, c4_raw_dataset, training_args
        )
        eval_dataset = c4_tokenized_datasets["validation"]
        if data_args.max_eval_samples is not None:
            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))
        if data_args.min_eval_seq_length != 0:
            eval_dataset = eval_dataset.filter(
                filter_by_length(
                    data_args.min_eval_seq_length, model.config.pad_token_id
                ),
                num_proc=data_args.preprocessing_num_workers,
            )
        # NOTE: modifications for clm
        eval_dataset = eval_dataset.map(
            lambda x: {**x, "labels": x["input_ids"]},
            remove_columns=["special_tokens_mask"],
        )

    def preprocess_logits_for_metrics(logits, labels):
        if isinstance(logits, tuple):
            # Depending on the model and config, logits may contain extra tensors,
            # like past_key_values, but logits always come first.
            logits = logits[0]
        return logits.argmax(dim=-1)
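    # Reducing logits to argmax token ids bounds evaluation memory: without this, the
    # Trainer would accumulate full (batch, seq_len, vocab_size) logit tensors across
    # the whole eval set.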

    # Initialize our Trainer. The data collator would default to
    # DataCollatorWithPadding, so we override it with default_data_collator.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=default_data_collator,
        compute_metrics=None,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics
        if training_args.do_eval and not is_torch_tpu_available()
        else None,
    )
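    # compute_metrics is left as None, so evaluation reports only the loss (plus
    # runtime stats); perplexity can be recovered as math.exp(metrics["eval_loss"]).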

    # Training
    if training_args.do_train:
        checkpoint = resolve_last_checkpoint_vs_resume_from_checkpoint(
            last_checkpoint,
            training_args.resume_from_checkpoint,
        )
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()  # Also saves the tokenizer for easy upload.

        metrics = train_result.metrics
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    if training_args.do_eval:
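        # Optionally restore model weights, trainer state, and RNG state from
        # model_name_or_path so evaluation reproduces the saved training-time state.
        # (_load_from_checkpoint and _load_rng_state are private Trainer helpers.)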
        if training_args.load_states_in_eval_from_model_path:
            trainer._load_from_checkpoint(model_args.model_name_or_path)
            trainer.state = TrainerState.load_from_json(
                os.path.join(model_args.model_name_or_path, "trainer_state.json")
            )
            trainer._load_rng_state(model_args.model_name_or_path)

        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate()

        max_eval_samples = (
            data_args.max_eval_samples
            if data_args.max_eval_samples is not None
            else len(eval_dataset)
        )
        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)


if __name__ == "__main__":
    main()
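
# Example invocation (illustrative only: the real flag set comes from get_args() in
# .arguments plus HF TrainingArguments, and the module path assumes this file lives
# in the sdlm package):
#   python -m sdlm.run_clm \
#       --model_name_or_path gpt2 \
#       --output_dir out/clm \
#       --do_train --do_eval \
#       --max_steps 10000 \
#       --per_device_train_batch_size 8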