# This is the script that will be used in the training container
import argparse
import glob
import logging
import os
import shutil
import sys

import numpy as np
import nltk

# Make sure the punkt sentence tokenizer is available; only try to download
# it when the lookup fails.
try:
    nltk.data.find("tokenizers/punkt")
except LookupError as e:
    print(e)
    try:
        nltk.download("punkt")
    except FileExistsError as err:
        print(err)

from nltk import sent_tokenize
from datasets import load_metric, load_from_disk
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler(sys.stdout))

# Label value ignored by the cross-entropy loss of the seq2seq models.
LABEL_IGNORE_INDEX = -100


def _mask_padding(ids, pad_id):
    """Return *ids* with every occurrence of *pad_id* replaced by -100.

    Accepts either a flat list of token ids (single example) or a list of
    such lists (batched call).
    """
    if ids and isinstance(ids[0], (list, tuple)):
        return [
            [LABEL_IGNORE_INDEX if tok == pad_id else tok for tok in seq]
            for seq in ids
        ]
    return [LABEL_IGNORE_INDEX if tok == pad_id else tok for tok in ids]


def tokenize(batch, text_column, target_column, max_source, max_target):
    """Tokenize the source and target columns of *batch*.

    Uses the module-level ``tokenizer`` that ``train`` assigns, so ``train``
    must have set it before this function is called.

    Args:
        batch: Mapping holding the raw text under ``text_column`` and
            ``target_column``.
        text_column: Name of the input text column.
        target_column: Name of the target text column.
        max_source: Length the inputs are padded/truncated to.
        max_target: Length the targets are padded/truncated to.

    Returns:
        The tokenized inputs with an extra ``"labels"`` entry holding the
        target token ids, where padding positions are set to -100.
    """
    tokenized_input = tokenizer(
        batch[text_column],
        padding="max_length",
        truncation=True,
        max_length=max_source,
    )
    tokenized_target = tokenizer(
        batch[target_column],
        padding="max_length",
        truncation=True,
        max_length=max_target,
    )
    # Targets are padded to a fixed length; mask the padding so it does not
    # contribute to the training/evaluation loss.
    tokenized_input["labels"] = _mask_padding(
        tokenized_target["input_ids"], tokenizer.pad_token_id
    )
    return tokenized_input


def load_and_tokenize_dataset(
    data_dir, split, text_column, target_column, max_source, max_target
):
    """Load the dataset stored at ``data_dir/split`` and tokenize it.

    The dataset is read with ``load_from_disk``, mapped through ``tokenize``
    in batches of 512, and formatted as numpy with only the ``input_ids``,
    ``attention_mask`` and ``labels`` columns exposed.

    Returns:
        The tokenized dataset.
    """
    dataset = load_from_disk(os.path.join(data_dir, split))
    tokenized_dataset = dataset.map(
        lambda x: tokenize(x, text_column, target_column, max_source, max_target),
        batched=True,
        batch_size=512,
    )
    tokenized_dataset.set_format(
        "numpy", columns=["input_ids", "attention_mask", "labels"]
    )
    return tokenized_dataset


def compute_metrics(eval_pred):
    """Compute the GLUE/MRPC metric for a ``(predictions, references)`` pair.

    NOTE(review): this function is currently not passed to the trainer in
    ``train``; the ROUGE-based variant is kept below, disabled.
    """
    metric = load_metric('glue', 'mrpc')
    predictions, references = eval_pred
    return metric.compute(predictions=predictions, references=references)
    # Disabled ROUGE-based implementation:
    # metric = load_metric("rouge")
    # predictions, labels = eval_pred
    # decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # # Replace -100 in the labels as we can't decode them.
    # labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    # decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    # # Rouge expects a newline after each sentence
    # decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip()))
    #                  for pred in decoded_preds]
    # decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip()))
    #                   for label in decoded_labels]
    # # Compute ROUGE scores
    # logger.info("Decoded preds: %s" % decoded_preds)
    # logger.info("Decoded labels: %s" % decoded_labels)
    # result = metric.compute(predictions=decoded_preds, references=decoded_labels,
    #                         use_stemmer=True)
    # # Extract ROUGE f1 scores
    # result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
    # # Add mean generated length to metrics
    # prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id)
    #                    for pred in predictions]
    # result["gen_len"] = np.mean(prediction_lens)
    # return {k: round(v, 4) for k, v in result.items()}


def _remove_checkpoints(model_dir):
    """Delete every ``checkpoint-*`` directory directly under *model_dir*."""
    for path in glob.glob(os.path.join(model_dir, "checkpoint-*")):
        if os.path.isdir(path):
            shutil.rmtree(path, ignore_errors=True)


def train(args):
    """Fine-tune a seq2seq model and evaluate it on the test split.

    Loads the model and tokenizer named by ``args.model_name``, tokenizes the
    ``train``, ``validation`` and ``test`` splits found under
    ``args.train_data_dir``, trains with ``Seq2SeqTrainer``, saves the model
    to ``args.model_dir``, logs the test-set predictions and finally removes
    the intermediate checkpoints.

    Args:
        args: Parsed command-line arguments (see the ``__main__`` section).
    """
    from transformers import T5ForConditionalGeneration, T5Tokenizer

    logger.info("Loading tokenizer...\n")
    # ``tokenizer`` is read by ``tokenize``, hence module-level.
    global tokenizer
    global model_name
    model_name = args.model_name
    logger.info("Loading pretrained model\n")
    # Only Google checkpoints that are actually T5 variants use the explicit
    # T5 classes; anything else (including the default google/pegasus-xsum)
    # is resolved through the Auto classes.
    if "google" in model_name and "t5" in model_name.lower():
        model = T5ForConditionalGeneration.from_pretrained(model_name)
        tokenizer = T5Tokenizer.from_pretrained(model_name)
    else:
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
    logger.info("Pretrained model loaded\n")

    logger.info("Fetching and tokenizing data for training")
    train_dataset = load_and_tokenize_dataset(
        args.train_data_dir,
        "train",
        args.text_column,
        args.target_column,
        args.max_source,
        args.max_target,
    )
    logger.info("Tokenizing data for training loaded")
    eval_dataset = load_and_tokenize_dataset(
        args.train_data_dir,
        "validation",
        args.text_column,
        args.target_column,
        args.max_source,
        args.max_target,
    )
    test_dataset = load_and_tokenize_dataset(
        args.train_data_dir,
        "test",
        args.text_column,
        args.target_column,
        args.max_source,
        args.max_target,
    )

    logger.info("Defining training arguments\n")
    training_args = Seq2SeqTrainingArguments(
        output_dir=args.model_dir,
        num_train_epochs=args.epoch,
        per_device_train_batch_size=args.train_batch_size,
        per_device_eval_batch_size=args.eval_batch_size,
        learning_rate=args.lr,
        warmup_steps=args.warmup_steps,
        weight_decay=args.weight_decay,
        logging_dir=args.log_dir,
        logging_strategy=args.logging_strategy,
        load_best_model_at_end=True,
        adafactor=True,
        do_train=True,
        do_eval=True,
        do_predict=True,
        save_total_limit=3,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        predict_with_generate=True,
        metric_for_best_model="eval_loss",
        seed=7,
    )

    logger.info("Defining seq2seq Trainer")
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        # compute_metrics=compute_metrics,
    )

    logger.info("Starting Training")
    trainer.train()
    logger.info("Model trained successfully")
    trainer.save_model()
    logger.info("Model saved successfully")

    # Evaluation
    logger.info("*** Evaluate on test set***")
    logger.info(trainer.predict(test_dataset))

    logger.info("Removing unused checkpoints to save space in container")
    # Done in-process instead of through a shell so that unusual characters
    # in the model directory cannot be interpreted as shell syntax.
    _remove_checkpoints(args.model_dir)


def _env_default(name):
    """Return argparse kwargs that default an option to env var *name*.

    When the variable is unset the option becomes required, so the script can
    still be run with all paths given explicitly on the command line.
    """
    value = os.environ.get(name)
    return {"default": value, "required": value is None}


def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", type=str, default="google/pegasus-xsum")
    parser.add_argument(
        "--train-data-dir", type=str, **_env_default("SM_CHANNEL_TRAIN")
    )
    # parser.add_argument("--val-data-dir", type=str,
    #                     default=os.environ["SM_CHANNEL_VALIDATION"])
    # parser.add_argument("--test-data-dir", type=str,
    #                     default=os.environ["SM_CHANNEL_TEST"])
    parser.add_argument("--text-column", type=str, default="dialogue")
    parser.add_argument("--target-column", type=str, default="summary")
    parser.add_argument("--max-source", type=int, default=512)
    parser.add_argument("--max-target", type=int, default=80)
    parser.add_argument("--model-dir", type=str, **_env_default("SM_MODEL_DIR"))
    parser.add_argument("--epoch", type=int, default=5)
    parser.add_argument("--train-batch-size", type=int, default=2)
    parser.add_argument("--eval-batch-size", type=int, default=2)
    # NOTE(review): parsed as float although the default is a whole number of
    # steps; confirm which type the installed transformers version expects.
    parser.add_argument("--warmup-steps", type=float, default=500)
    parser.add_argument("--lr", type=float, default=2e-5)
    parser.add_argument("--weight-decay", type=float, default=0.0)
    parser.add_argument("--log-dir", type=str, **_env_default("SM_OUTPUT_DIR"))
    parser.add_argument("--logging-strategy", type=str, default="epoch")
    return parser.parse_args()


if __name__ == "__main__":
    train(_parse_args())