import numpy as np
import torch
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score
from transformers import TrainingArguments, Trainer, EarlyStoppingCallback


class Dataset(torch.utils.data.Dataset):
    """Torch dataset wrapping tokenizer encodings and optional labels."""

    def __init__(self, encodings, labels=None):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        # Labels are optional so the same class works for unlabeled inference data.
        if self.labels is not None:
            item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.encodings["input_ids"])
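
# Example usage (a sketch, not part of the original module; the checkpoint name
# and toy texts/labels below are placeholder assumptions):
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
#   encodings = tokenizer(["good film", "bad film"], truncation=True, padding=True)
#   ds = Dataset(encodings, labels=[1, 0])
#   len(ds)          # -> 2
#   ds[0]["labels"]  # -> tensor(1)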


def compute_metrics(p):
    """Compute macro-averaged classification metrics from a (logits, labels) EvalPrediction."""
    pred, labels = p
    pred = np.argmax(pred, axis=1)

    accuracy = accuracy_score(y_true=labels, y_pred=pred)
    recall = recall_score(y_true=labels, y_pred=pred, average="macro", zero_division=0)
    precision = precision_score(y_true=labels, y_pred=pred, average="macro", zero_division=0)
    f1 = f1_score(y_true=labels, y_pred=pred, average="macro", zero_division=0)

    # Keys already prefixed with "eval_" are left unchanged by the Trainer.
    return {"eval_accuracy": accuracy, "eval_precision": precision, "eval_recall": recall, "eval_f1": f1}


def train(model, train_dataset, val_dataset, output_dir, save_steps, num_train_epochs=10):
    """Fine-tune `model`, evaluating and checkpointing every `save_steps` steps."""
    args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        evaluation_strategy="steps",
        eval_steps=save_steps,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=num_train_epochs,
        seed=0,
        save_steps=save_steps,
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model="eval_f1",
    )
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
        # Stop early if eval_f1 has not improved for 3 consecutive evaluations.
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
    )

    trainer.train()
    # Return the trainer so callers can reach the best model reloaded at the end.
    return trainer
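

if __name__ == "__main__":
    # Minimal end-to-end sketch, not part of the original module: the checkpoint
    # name, texts, and labels are placeholder assumptions; substitute real data.
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AutoModelForSequenceClassification.from_pretrained(
        "bert-base-uncased", num_labels=2
    )

    texts = ["great movie", "terrible movie", "what a film", "utterly boring"]
    labels = [1, 0, 1, 0]
    encodings = tokenizer(texts, truncation=True, padding=True)

    # Reusing one tiny split for both train and validation is for the sketch only.
    ds = Dataset(encodings, labels)
    trainer = train(model, ds, ds, output_dir="./output", save_steps=10)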