# Training script executed inside the SageMaker training container.
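#
# Example launch from a notebook (a sketch only; the entry_point name, instance
# type, IAM role lookup, library versions and S3 path below are assumptions,
# not taken from this repo):
#
#   import sagemaker
#   from sagemaker.huggingface import HuggingFace
#
#   estimator = HuggingFace(
#       entry_point="train.py",            # this script (name assumed)
#       instance_type="ml.p3.2xlarge",
#       instance_count=1,
#       role=sagemaker.get_execution_role(),
#       transformers_version="4.17",
#       pytorch_version="1.10",
#       py_version="py38",
#       hyperparameters={"model-name": "google/pegasus-xsum", "epoch": 5},
#   )
#   # The "train" channel becomes SM_CHANNEL_TRAIN inside the container.
#   estimator.fit({"train": "s3://<bucket>/<prefix>/processed-dataset"})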
import argparse
import logging
import os
import sys
import numpy as np
import nltk
try:
    nltk.data.find("tokenizers/punkt")
except LookupError:
    # punkt is needed by nltk.sent_tokenize for ROUGE post-processing.
    try:
        nltk.download("punkt")
    except FileExistsError as e:
        # Another process may have created the download directory already.
        print(e)
from nltk import sent_tokenize
from datasets import load_metric, load_from_disk
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler(sys.stdout))
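
# Tokenize one batch: the source column becomes the encoder inputs, and the
# target column, tokenized separately, is attached as "labels" for the decoder.
# Relies on the module-level `tokenizer` set in train().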
def tokenize(batch, text_column, target_column, max_source, max_target):
tokenized_input = tokenizer(
batch[text_column], padding="max_length", truncation=True, max_length=max_source
)
tokenized_target = tokenizer(
batch[target_column],
padding="max_length",
truncation=True,
max_length=max_target,
)
tokenized_input["labels"] = tokenized_target["input_ids"]
return tokenized_input
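
# Load one dataset split saved with datasets.save_to_disk, tokenize it in
# batches, and expose only input_ids, attention_mask and labels as numpy arrays.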
def load_and_tokenize_dataset(
data_dir, split, text_column, target_column, max_source, max_target
):
dataset = load_from_disk(os.path.join(data_dir, split))
tokenized_dataset = dataset.map(
lambda x: tokenize(x, text_column, target_column, max_source, max_target),
batched=True,
batch_size=512,
)
tokenized_dataset.set_format(
"numpy", columns=["input_ids", "attention_mask", "labels"]
)
return tokenized_dataset
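
# NOTE: compute_metrics is not wired into the Trainer below (the argument is
# commented out). The ROUGE-based version intended for summarization is kept as
# comments underneath.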
def compute_metrics(eval_pred):
metric = load_metric('glue', 'mrpc')
predictions, references = eval_pred
return metric.compute(predictions=predictions, references=references)
# metric = load_metric("rouge")
# predictions, labels = eval_pred
# decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
# # Replace -100 in the labels as we can't decode them.
# labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
# decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
# # Rouge expects a newline after each sentence
# decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip()))
# for pred in decoded_preds]
# decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip()))
# for label in decoded_labels]
# # Compute ROUGE scores
# logger.info("Decoded preds: %s" % decoded_preds)
# logger.info("Decoded labels: %s" % decoded_labels)
# result = metric.compute(predictions=decoded_preds, references=decoded_labels,
# use_stemmer=True)
# # Extract ROUGE f1 scores
# result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
# # Add mean generated length to metrics
# prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id)
# for pred in predictions]
# result["gen_len"] = np.mean(prediction_lens)
# return {k: round(v, 4) for k, v in result.items()}
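
# End-to-end training entry point: load the model and tokenizer, tokenize the
# train/validation/test splits, fine-tune with Seq2SeqTrainer, evaluate on the
# test split, and save the final model to args.model_dir.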
def train(args):
from transformers import T5ForConditionalGeneration, T5Tokenizer
logger.info("Loading tokenizer...\n")
global tokenizer
global model_name
model_name = args.model_name
logger.info("Loading pretrained model\n")
if "google" in model_name:
model = T5ForConditionalGeneration.from_pretrained(model_name)
tokenizer = T5Tokenizer.from_pretrained(model_name)
else:
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
logger.info("Pretrained model loaded\n")
logger.info("Fetching and tokenizing data for training")
train_dataset = load_and_tokenize_dataset(
args.train_data_dir,
"train",
args.text_column,
args.target_column,
args.max_source,
args.max_target,
)
logger.info("Tokenizing data for training loaded")
eval_dataset = load_and_tokenize_dataset(
args.train_data_dir,
"validation",
args.text_column,
args.target_column,
args.max_source,
args.max_target,
)
test_dataset = load_and_tokenize_dataset(
args.train_data_dir,
"test",
args.text_column,
args.target_column,
args.max_source,
args.max_target,
)
logger.info("Defining training arguments\n")
training_args = Seq2SeqTrainingArguments(
output_dir=args.model_dir,
num_train_epochs=args.epoch,
per_device_train_batch_size=args.train_batch_size,
per_device_eval_batch_size=args.eval_batch_size,
learning_rate=args.lr,
warmup_steps=args.warmup_steps,
weight_decay=args.weight_decay,
logging_dir=args.log_dir,
logging_strategy=args.logging_strategy,
load_best_model_at_end=True,
adafactor=True,
do_train=True,
do_eval=True,
do_predict=True,
save_total_limit=3,
evaluation_strategy="epoch",
save_strategy="epoch",
predict_with_generate=True,
metric_for_best_model="eval_loss",
seed=7,
)
logger.info("Defining seq2seq Trainer")
trainer = Seq2SeqTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        # compute_metrics=compute_metrics,
)
logger.info("Starting Training")
trainer.train()
logger.info("Model trained successfully")
trainer.save_model()
logger.info("Model saved successfully")
# Evaluation
logger.info("*** Evaluate on test set***")
logger.info(trainer.predict(test_dataset))
logger.info("Removing unused checkpoints to save space in container")
os.system(f"rm -rf {args.model_dir}/checkpoint-*/")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-name", type=str, default="google/pegasus-xsum")
parser.add_argument(
"--train-data-dir", type=str, default=os.environ["SM_CHANNEL_TRAIN"]
)
# parser.add_argument("--val-data-dir", type=str,
# default=os.environ["SM_CHANNEL_VALIDATION"])
# parser.add_argument("--test-data-dir", type=str,
# default=os.environ["SM_CHANNEL_TEST"])
parser.add_argument("--text-column", type=str, default="dialogue")
parser.add_argument("--target-column", type=str, default="summary")
parser.add_argument("--max-source", type=int, default=512)
parser.add_argument("--max-target", type=int, default=80)
parser.add_argument("--model-dir", type=str, default=os.environ["SM_MODEL_DIR"])
parser.add_argument("--epoch", type=int, default=5)
parser.add_argument("--train-batch-size", type=int, default=2)
parser.add_argument("--eval-batch-size", type=int, default=2)
parser.add_argument("--warmup-steps", type=float, default=500)
parser.add_argument("--lr", type=float, default=2e-5)
parser.add_argument("--weight-decay", type=float, default=0.0)
parser.add_argument("--log-dir", type=str, default=os.environ["SM_OUTPUT_DIR"])
parser.add_argument("--logging-strategy", type=str, default="epoch")
train(parser.parse_args())