In [1]:
from datasets import load_dataset

billsum = load_dataset("billsum", split="ca_test")
billsum = billsum.select(range(1000))
billsum = billsum.train_test_split(test_size=0.2)

In [2]:
from transformers import AutoTokenizer
checkpoint = "google-t5/t5-small"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
prefix = "summarize: "

def preprocess_function(examples):
    inputs = [prefix + doc for doc in examples["text"]]
    model_inputs = tokenizer(inputs, max_length=1024, truncation=True, padding="max_length") 

    labels = tokenizer(text_target=examples["summary"], max_length=128, truncation=True, padding="max_length")

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


In [3]:
tokenized_billsum = billsum.map(preprocess_function, batched=True)

Map:   0%|          | 0/800 [00:00<?, ? examples/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

In [4]:
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

In [5]:
import torch

# Check that MPS is available
if not torch.backends.mps.is_available():
    if not torch.backends.mps.is_built():
        print("MPS not available because the current PyTorch install was not "
              "built with MPS enabled.")
    else:
        print("MPS not available because the current MacOS version is not 12.3+ "
              "and/or you do not have an MPS-enabled device on this machine.")

else:
    mps_device = torch.device("mps")
    model.to(mps_device)
    print("Model moved to MPS device")

Model moved to MPS device


In [6]:
training_args = Seq2SeqTrainingArguments(
    output_dir="calendar_model",
    evaluation_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=1,
    predict_with_generate=True,
    use_mps_device=True,
    # fp16=True,
    # push_to_hub=True,
)



In [7]:
import numpy as np
import evaluate
metric = evaluate.load("accuracy")


In [8]:
def compute_metrics(eval_pred):
     logits, labels = eval_pred
     predictions = np.argmax(logits, axis=-1)
     return metric.compute(predictions=predictions, references=labels)

In [9]:
from transformers import TrainingArguments, Trainer
training_args = TrainingArguments(output_dir="test_trainer", evaluation_strategy="epoch")

In [10]:
trainer = Trainer(
     model=model,
     args=training_args,
     train_dataset=tokenized_billsum["train"],
     eval_dataset=tokenized_billsum["test"],
     compute_metrics=compute_metrics,
 )

In [11]:
trainer.train()

  0%|          | 0/300 [00:00<?, ?it/s]

KeyboardInterrupt: 