Spaces:
Runtime error
Runtime error
import numpy as np | |
import torch | |
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score | |
from transformers import TrainingArguments, Trainer | |
from transformers import EarlyStoppingCallback | |
import pickle as pkl | |
from datetime import datetime | |
class Dataset(torch.utils.data.Dataset):
    """Map-style torch dataset over pre-tokenized encodings.

    Args:
        encodings: Mapping from feature name to a per-example sequence; must
            contain an ``"input_ids"`` entry (used for ``len``). Presumably the
            output of a Hugging Face tokenizer called on a list of texts —
            confirm against callers.
        labels: Optional per-example labels aligned with ``encodings``. When
            ``None`` the dataset yields model inputs only (e.g. for prediction).
    """

    def __init__(self, encodings, labels=None):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        """Return a dict of tensors for example ``idx``.

        Every encoding feature is included; ``"labels"`` is added only when
        the dataset was built with labels.
        """
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        # labels is optional (default None); indexing None would raise a
        # TypeError, so only attach a label tensor when labels were provided.
        if self.labels is not None:
            item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        """Return the number of examples, taken from the ``input_ids`` feature."""
        return len(self.encodings["input_ids"])
def compute_metrics(p):
    """Compute classification metrics from a ``(predictions, label_ids)`` pair.

    Intended as the ``compute_metrics`` callback of ``transformers.Trainer``,
    which passes an ``EvalPrediction`` that unpacks into predictions and
    labels.

    Args:
        p: Two-item iterable ``(predictions, labels)``. ``predictions`` holds
            per-class scores with classes along axis 1 — presumably logits of
            shape ``(n_examples, n_classes)``; TODO confirm. If it is a tuple
            (models that return extra outputs), its first element is assumed
            to be the logits.

    Returns:
        dict with ``eval_accuracy`` and macro-averaged ``eval_precision``,
        ``eval_recall`` and ``eval_f1``. Classes that would cause an
        undefined ratio contribute 0 (``zero_division=0``).
    """
    pred, labels = p
    # When a model returns more than logits (hidden states, attentions, ...)
    # the predictions arrive as a tuple of arrays; argmax over that tuple
    # would fail or be meaningless, so keep only the logits.
    if isinstance(pred, tuple):
        pred = pred[0]
    pred = np.argmax(pred, axis=1)
    accuracy = accuracy_score(y_true=labels, y_pred=pred)
    recall = recall_score(y_true=labels, y_pred=pred, average="macro", zero_division=0)
    precision = precision_score(y_true=labels, y_pred=pred, average="macro", zero_division=0)
    f1 = f1_score(y_true=labels, y_pred=pred, average="macro", zero_division=0)
    # Keys already carry the "eval_" prefix, matching metric_for_best_model="eval_f1".
    return {"eval_accuracy": accuracy, "eval_precision": precision, "eval_recall": recall, "eval_f1": f1}
def train(model, train_dataset, val_dataset, output_dir, save_steps, num_train_epochs=10): | |
args = TrainingArguments( | |
output_dir=output_dir, | |
overwrite_output_dir=True, | |
evaluation_strategy="steps", | |
eval_steps=save_steps, | |
per_device_train_batch_size=16, | |
per_device_eval_batch_size=16, | |
num_train_epochs=num_train_epochs, | |
seed=0, | |
save_steps=save_steps, | |
save_total_limit=2, | |
load_best_model_at_end=True, | |
metric_for_best_model='eval_f1' | |
) | |
trainer = Trainer( | |
model=model, | |
args=args, | |
train_dataset=train_dataset, | |
eval_dataset=val_dataset, | |
compute_metrics=compute_metrics, | |
callbacks = [EarlyStoppingCallback(early_stopping_patience=3)] | |
) | |
res = trainer.train() | |