Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
from torch.utils.data import DataLoader | |
from transformers import get_linear_schedule_with_warmup | |
from sklearn.metrics import accuracy_score, precision_recall_fscore_support | |
from typing import Dict, List, Tuple | |
import numpy as np | |
from tqdm import tqdm | |
import logging | |
# Configure the root logger to emit records at INFO level and above, then
# create the logger named after this module; the trainer below logs through it.
logging.basicConfig(level=logging.INFO) | |
logger = logging.getLogger(__name__) | |
class ModelTrainer:
    """Train, evaluate and run inference with a classification model.

    The wrapped model is called as ``model(input_ids, attention_mask)`` and
    must return a mapping with a ``'logits'`` entry. Batches yielded by the
    data loaders must be mappings with ``'input_ids'`` and
    ``'attention_mask'`` entries, plus ``'labels'`` for training and
    evaluation.
    """

    def __init__(self,
                 model: nn.Module,
                 device: str = "cuda" if torch.cuda.is_available() else "cpu",
                 learning_rate: float = 2e-5,
                 num_epochs: int = 10,
                 early_stopping_patience: int = 3):
        """Move the model to ``device`` and build the loss and optimizer.

        Args:
            model: Network to train; it is moved to ``device``.
            device: Target device string. The default is resolved once, when
                the class is defined.
            learning_rate: Learning rate handed to AdamW.
            num_epochs: Maximum number of epochs run by :meth:`train`.
            early_stopping_patience: Number of consecutive epochs without a
                validation-loss improvement after which training stops.
        """
        self.model = model.to(device)
        self.device = device
        self.learning_rate = learning_rate
        self.num_epochs = num_epochs
        self.early_stopping_patience = early_stopping_patience
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=learning_rate,
        )

    def _inputs_to_device(self, batch) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the batch's ``input_ids`` and ``attention_mask`` on the device."""
        input_ids = batch['input_ids'].to(self.device)
        attention_mask = batch['attention_mask'].to(self.device)
        return input_ids, attention_mask

    @staticmethod
    def _mean_loss(total_loss: float, num_batches: int, loader_name: str) -> float:
        """Average a summed loss over the batches that were actually seen.

        Raises:
            ValueError: If no batch was processed, so no mean exists.
        """
        if num_batches == 0:
            raise ValueError(
                f"{loader_name} yielded no batches; cannot compute a mean loss"
            )
        return total_loss / num_batches

    def train_epoch(self, train_loader: DataLoader) -> float:
        """Run one optimization pass over ``train_loader``.

        Args:
            train_loader: Loader yielding batches with ``'input_ids'``,
                ``'attention_mask'`` and ``'labels'``.

        Returns:
            The mean per-batch training loss.

        Raises:
            ValueError: If the loader yields no batches.
        """
        self.model.train()
        total_loss = 0.0
        # Batches are counted while iterating rather than read from
        # len(loader), so loaders without a usable length also work.
        num_batches = 0
        for batch in tqdm(train_loader, desc="Training"):
            input_ids, attention_mask = self._inputs_to_device(batch)
            labels = batch['labels'].to(self.device)

            self.optimizer.zero_grad()
            outputs = self.model(input_ids, attention_mask)
            loss = self.criterion(outputs['logits'], labels)
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()
            num_batches += 1
        return self._mean_loss(total_loss, num_batches, "train_loader")

    def evaluate(self, eval_loader: DataLoader) -> Tuple[float, Dict[str, float]]:
        """Compute the loss and classification metrics on ``eval_loader``.

        The model is switched to eval mode and gradients are disabled.

        Args:
            eval_loader: Loader yielding batches with ``'input_ids'``,
                ``'attention_mask'`` and ``'labels'``.

        Returns:
            A ``(mean_loss, metrics)`` pair. ``metrics`` holds ``'accuracy'``,
            ``'precision'``, ``'recall'``, ``'f1'`` and ``'loss'`` (the same
            value as ``mean_loss``).

        Raises:
            ValueError: If the loader yields no batches.
        """
        self.model.eval()
        total_loss = 0.0
        num_batches = 0
        all_preds = []
        all_labels = []
        with torch.no_grad():
            for batch in tqdm(eval_loader, desc="Evaluating"):
                input_ids, attention_mask = self._inputs_to_device(batch)
                labels = batch['labels'].to(self.device)

                outputs = self.model(input_ids, attention_mask)
                logits = outputs['logits']
                total_loss += self.criterion(logits, labels).item()

                # The predicted class is the index of the largest logit.
                preds = torch.argmax(logits, dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
                num_batches += 1

        # Fail with a clear message before metrics are computed on no data.
        mean_loss = self._mean_loss(total_loss, num_batches, "eval_loader")
        metrics = self._calculate_metrics(all_labels, all_preds)
        metrics['loss'] = mean_loss
        return mean_loss, metrics

    def _calculate_metrics(self, labels: List[int], preds: List[int]) -> Dict[str, float]:
        """Return accuracy plus weighted-average precision, recall and F1.

        Args:
            labels: Ground-truth class ids.
            preds: Predicted class ids, aligned with ``labels``.

        Returns:
            Dict with ``'accuracy'``, ``'precision'``, ``'recall'`` and
            ``'f1'`` as built-in floats.
        """
        precision, recall, f1, _ = precision_recall_fscore_support(
            labels, preds, average='weighted'
        )
        accuracy = accuracy_score(labels, preds)
        # Convert to built-in floats so the values match the annotation and
        # print cleanly when the metrics dict is logged.
        return {
            'accuracy': float(accuracy),
            'precision': float(precision),
            'recall': float(recall),
            'f1': float(f1),
        }

    def train(self,
              train_loader: DataLoader,
              val_loader: DataLoader,
              num_training_steps: int,
              checkpoint_path: str = 'best_model.pt') -> Dict[str, List[float]]:
        """Train with per-epoch validation, checkpointing and early stopping.

        After every epoch the model is evaluated on ``val_loader``. Whenever
        the validation loss improves, the model's ``state_dict`` is written to
        ``checkpoint_path``. Training stops once the loss has failed to
        improve for ``early_stopping_patience`` consecutive epochs, or after
        ``num_epochs`` epochs.

        Args:
            train_loader: Loader for the training split.
            val_loader: Loader for the validation split.
            num_training_steps: Total step count handed to the linear
                learning-rate schedule (created with no warmup).
            checkpoint_path: Where the best weights are saved.

        Returns:
            History dict with ``'train_loss'`` and ``'val_loss'`` (one float
            per completed epoch) and ``'val_metrics'`` (one metrics dict per
            completed epoch).
        """
        scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=0,
            num_training_steps=num_training_steps,
        )
        best_val_loss = float('inf')
        patience_counter = 0
        history = {
            'train_loss': [],
            'val_loss': [],
            'val_metrics': [],
        }

        for epoch in range(self.num_epochs):
            logger.info(f"Epoch {epoch + 1}/{self.num_epochs}")

            # Training
            train_loss = self.train_epoch(train_loader)
            history['train_loss'].append(train_loss)

            # Validation
            val_loss, val_metrics = self.evaluate(val_loader)
            history['val_loss'].append(val_loss)
            history['val_metrics'].append(val_metrics)

            logger.info(f"Train Loss: {train_loss:.4f}")
            logger.info(f"Val Loss: {val_loss:.4f}")
            logger.info(f"Val Metrics: {val_metrics}")

            # Early stopping: only a strict improvement resets the counter.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                # Save best model
                torch.save(self.model.state_dict(), checkpoint_path)
            else:
                patience_counter += 1
                if patience_counter >= self.early_stopping_patience:
                    logger.info("Early stopping triggered")
                    break

            # NOTE(review): the schedule advances once per epoch, so
            # `num_training_steps` is consumed in epochs rather than in
            # optimizer steps - confirm that callers pass a matching count.
            scheduler.step()

        return history

    def predict(self, test_loader: DataLoader) -> Tuple[np.ndarray, np.ndarray]:
        """Run inference over ``test_loader``.

        Each batch is passed to ``self.model.predict(input_ids,
        attention_mask)``; the result is presumably a tensor of per-class
        probabilities (verify against the model implementation).

        Args:
            test_loader: Loader yielding batches with ``'input_ids'`` and
                ``'attention_mask'``.

        Returns:
            ``(preds, probs)``: the argmax over dimension 1 of each model
            output, and the model outputs themselves, gathered over all
            batches as NumPy arrays.
        """
        self.model.eval()
        all_preds = []
        all_probs = []
        with torch.no_grad():
            for batch in tqdm(test_loader, desc="Predicting"):
                input_ids, attention_mask = self._inputs_to_device(batch)
                probs = self.model.predict(input_ids, attention_mask)
                preds = torch.argmax(probs, dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_probs.extend(probs.cpu().numpy())
        return np.array(all_preds), np.array(all_probs)