opdx / helpers /trainer_embedder.py
lyangas
missed files
6931ba0
raw
history blame
2.01 kB
import numpy as np
import torch
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score
from transformers import TrainingArguments, Trainer
from transformers import EarlyStoppingCallback
import pickle as pkl
from datetime import datetime
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset wrapping tokenizer encodings and optional labels.

    Args:
        encodings: Mapping from feature name to an indexable collection of
            per-example values. It must contain an ``"input_ids"`` entry,
            whose length defines the size of the dataset.
        labels: Optional indexable collection of per-example labels. Leave
            it as ``None`` for unlabeled data.
    """

    def __init__(self, encodings, labels=None):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict with one tensor per feature.

        A ``"labels"`` entry is included only when labels were supplied.
        """
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        # ``labels`` defaults to None; indexing None would raise a TypeError,
        # so only attach a label when the dataset actually has them.
        if self.labels is not None:
            item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        """Return the number of examples (length of ``encodings["input_ids"]``)."""
        return len(self.encodings["input_ids"])
def compute_metrics(p):
    """Compute accuracy and macro-averaged precision/recall/F1 for an evaluation pass.

    Args:
        p: Two-item iterable that unpacks into ``(predictions, labels)``.
            Predictions are reduced with an argmax over axis 1, so they are
            presumably per-class scores laid out as (n_samples, n_classes)
            -- TODO confirm against the model's output.

    Returns:
        Dict with the keys ``eval_accuracy``, ``eval_precision``,
        ``eval_recall`` and ``eval_f1``.
    """
    scores, labels = p
    predicted = np.argmax(scores, axis=1)

    # The three macro-averaged metrics share the same keyword arguments;
    # zero_division=0 is forwarded to scikit-learn unchanged.
    macro_kwargs = {
        "y_true": labels,
        "y_pred": predicted,
        "average": "macro",
        "zero_division": 0,
    }
    return {
        "eval_accuracy": accuracy_score(y_true=labels, y_pred=predicted),
        "eval_precision": precision_score(**macro_kwargs),
        "eval_recall": recall_score(**macro_kwargs),
        "eval_f1": f1_score(**macro_kwargs),
    }
def train(
    model,
    train_dataset,
    val_dataset,
    output_dir,
    save_steps,
    num_train_epochs=10,
    *,
    batch_size=16,
    early_stopping_patience=3,
    seed=0,
):
    """Fine-tune ``model`` with the Hugging Face ``Trainer`` and early stopping.

    Evaluation and checkpointing both happen every ``save_steps`` steps, and
    the ``Trainer`` is configured with ``load_best_model_at_end=True`` using
    ``eval_f1`` (as reported by ``compute_metrics``) as the selection metric.

    Args:
        model: Model instance handed to ``Trainer``.
        train_dataset: Dataset used for training.
        val_dataset: Dataset used for the periodic evaluations.
        output_dir: Directory for checkpoints; passed with
            ``overwrite_output_dir=True``.
        save_steps: Step interval used for both evaluation and checkpoint
            saving.
        num_train_epochs: Number of training epochs requested.
        batch_size: Per-device batch size for both training and evaluation.
        early_stopping_patience: Patience passed to ``EarlyStoppingCallback``.
        seed: Random seed passed to ``TrainingArguments``.

    Returns:
        The value returned by ``Trainer.train()``. Earlier versions of this
        function discarded it and returned ``None``.
    """
    args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        # NOTE(review): newer transformers releases appear to rename this
        # argument to ``eval_strategy`` -- confirm against the pinned version.
        evaluation_strategy="steps",
        eval_steps=save_steps,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=num_train_epochs,
        seed=seed,
        # Same interval as eval_steps so every checkpoint has a matching
        # evaluation to compare on.
        save_steps=save_steps,
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model="eval_f1",
    )
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
        callbacks=[
            EarlyStoppingCallback(early_stopping_patience=early_stopping_patience)
        ],
    )
    # Hand the training output back to the caller instead of dropping it.
    return trainer.train()