import pickle as pkl
from datetime import datetime

import numpy as np
import torch
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments


class Dataset(torch.utils.data.Dataset):
    """Torch dataset over tokenizer output, with optional labels.

    `encodings` is a dict of equal-length sequences (e.g. the mapping a
    HuggingFace tokenizer returns: "input_ids", "attention_mask", ...).
    `labels` is an optional parallel sequence of class ids; when omitted the
    dataset yields unlabeled items suitable for prediction.
    """

    def __init__(self, encodings, labels=None):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        # NOTE: torch.tensor will warn if the values are already tensors;
        # assumes plain Python lists / numpy arrays as produced by a tokenizer.
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        # Bug fix: only attach labels when present — the labels=None default
        # previously made __getitem__ crash with TypeError on unlabeled data.
        if self.labels is not None:
            item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.encodings["input_ids"])


def compute_metrics(p):
    """Compute classification metrics for a HuggingFace Trainer eval step.

    `p` is an (logits, labels) pair; logits have shape (n_samples, n_classes).
    Returns macro-averaged precision/recall/F1 plus accuracy, keyed with the
    "eval_" prefix so the Trainer keeps the names as-is and
    `metric_for_best_model='eval_f1'` resolves correctly.
    """
    pred, labels = p
    pred = np.argmax(pred, axis=1)

    accuracy = accuracy_score(y_true=labels, y_pred=pred)
    # zero_division=0 avoids warnings/NaNs for classes absent from a batch.
    recall = recall_score(y_true=labels, y_pred=pred, average="macro", zero_division=0)
    precision = precision_score(y_true=labels, y_pred=pred, average="macro", zero_division=0)
    f1 = f1_score(y_true=labels, y_pred=pred, average="macro", zero_division=0)

    return {
        "eval_accuracy": accuracy,
        "eval_precision": precision,
        "eval_recall": recall,
        "eval_f1": f1,
    }


def train(model, train_dataset, val_dataset, output_dir, save_steps, num_train_epochs=10):
    """Fine-tune `model` with periodic evaluation, checkpointing, and early stopping.

    Evaluates and checkpoints every `save_steps` steps, keeps the 2 most recent
    checkpoints, reloads the best checkpoint (by macro F1) at the end, and stops
    early after 3 evaluations without improvement.

    Returns the `TrainOutput` produced by `Trainer.train()` (previously this
    result was silently discarded).
    """
    args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        evaluation_strategy="steps",
        eval_steps=save_steps,  # eval and save cadences must match for load_best_model_at_end
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=num_train_epochs,
        seed=0,
        save_steps=save_steps,
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model="eval_f1",
    )
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
    )
    res = trainer.train()
    return res