import logging
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import get_linear_schedule_with_warmup

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ModelTrainer:
    """Train and evaluate a classification model.

    The wrapped model's forward pass must accept ``(input_ids,
    attention_mask)`` and return a dict containing a ``'logits'`` tensor of
    shape (batch, num_classes) — that is the only contract this trainer
    relies on during training/evaluation.

    Provides per-step linear LR scheduling with warmup, early stopping on
    validation loss, and checkpointing of the best weights to
    ``'best_model.pt'``.
    """

    def __init__(self,
                 model: nn.Module,
                 device: str = "cuda" if torch.cuda.is_available() else "cpu",
                 learning_rate: float = 2e-5,
                 num_epochs: int = 10,
                 early_stopping_patience: int = 3):
        """Move the model to ``device`` and set up loss and AdamW optimizer.

        Args:
            model: classifier whose forward returns ``{'logits': Tensor}``.
            device: target device string; defaults to CUDA when available.
            learning_rate: AdamW learning rate.
            num_epochs: maximum number of training epochs.
            early_stopping_patience: epochs without val-loss improvement
                before training stops.
        """
        self.model = model.to(device)
        self.device = device
        self.learning_rate = learning_rate
        self.num_epochs = num_epochs
        self.early_stopping_patience = early_stopping_patience

        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=learning_rate
        )

    def train_epoch(self,
                    train_loader: DataLoader,
                    scheduler: Optional[torch.optim.lr_scheduler.LambdaLR] = None) -> float:
        """Train for one epoch and return the mean per-batch loss.

        Args:
            train_loader: yields dict batches with ``'input_ids'``,
                ``'attention_mask'`` and ``'labels'`` tensors.
            scheduler: optional LR scheduler advanced once per optimizer
                step. ``get_linear_schedule_with_warmup`` defines its
                schedule in optimizer steps (it is built from the total
                batch count), so it must be stepped per batch here — not
                per epoch — or the learning rate barely decays.

        Returns:
            Average training loss over the epoch's batches.
        """
        self.model.train()
        total_loss = 0.0
        for batch in tqdm(train_loader, desc="Training"):
            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            labels = batch['labels'].to(self.device)

            self.optimizer.zero_grad()
            outputs = self.model(input_ids, attention_mask)
            loss = self.criterion(outputs['logits'], labels)
            loss.backward()
            self.optimizer.step()
            # BUGFIX: the schedule is per-step; previously it was advanced
            # once per epoch in train(), leaving the LR nearly constant.
            if scheduler is not None:
                scheduler.step()

            total_loss += loss.item()

        return total_loss / len(train_loader)

    def evaluate(self, eval_loader: DataLoader) -> Tuple[float, Dict[str, float]]:
        """Evaluate the model on ``eval_loader``.

        Returns:
            ``(avg_loss, metrics)`` where ``metrics`` holds accuracy,
            weighted precision/recall/F1, and (redundantly, kept for
            backward compatibility) the same average loss under ``'loss'``.
        """
        self.model.eval()
        total_loss = 0.0
        all_preds: List[int] = []
        all_labels: List[int] = []

        with torch.no_grad():
            for batch in tqdm(eval_loader, desc="Evaluating"):
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['labels'].to(self.device)

                outputs = self.model(input_ids, attention_mask)
                loss = self.criterion(outputs['logits'], labels)
                total_loss += loss.item()

                preds = torch.argmax(outputs['logits'], dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        metrics = self._calculate_metrics(all_labels, all_preds)
        avg_loss = total_loss / len(eval_loader)
        metrics['loss'] = avg_loss
        return avg_loss, metrics

    def _calculate_metrics(self, labels: List[int], preds: List[int]) -> Dict[str, float]:
        """Compute accuracy and weighted precision/recall/F1.

        ``zero_division=0`` pins the value sklearn already returns for
        classes with no predicted samples, without the
        ``UndefinedMetricWarning`` noise.
        """
        precision, recall, f1, _ = precision_recall_fscore_support(
            labels, preds, average='weighted', zero_division=0
        )
        accuracy = accuracy_score(labels, preds)

        return {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1
        }

    def train(self,
              train_loader: DataLoader,
              val_loader: DataLoader,
              num_training_steps: int) -> Dict[str, List[float]]:
        """Run the full training loop with early stopping.

        Args:
            train_loader: training batches (see :meth:`train_epoch`).
            val_loader: validation batches for loss/metric tracking.
            num_training_steps: total number of optimizer steps
                (batches-per-epoch * epochs) used to build the linear
                warmup schedule.

        Returns:
            History dict with per-epoch ``'train_loss'``, ``'val_loss'``
            and ``'val_metrics'`` lists.
        """
        scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=0,
            num_training_steps=num_training_steps
        )

        best_val_loss = float('inf')
        patience_counter = 0
        history: Dict[str, List] = {
            'train_loss': [],
            'val_loss': [],
            'val_metrics': []
        }

        for epoch in range(self.num_epochs):
            logger.info(f"Epoch {epoch + 1}/{self.num_epochs}")

            # Scheduler is stepped per batch inside train_epoch (see BUGFIX
            # note there); the old per-epoch scheduler.step() here — which
            # was also skipped on the early-stopping epoch — is removed.
            train_loss = self.train_epoch(train_loader, scheduler)
            history['train_loss'].append(train_loss)

            val_loss, val_metrics = self.evaluate(val_loader)
            history['val_loss'].append(val_loss)
            history['val_metrics'].append(val_metrics)

            logger.info(f"Train Loss: {train_loss:.4f}")
            logger.info(f"Val Loss: {val_loss:.4f}")
            logger.info(f"Val Metrics: {val_metrics}")

            # Early stopping on validation loss; checkpoint only on
            # improvement so 'best_model.pt' always holds the best weights.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                torch.save(self.model.state_dict(), 'best_model.pt')
            else:
                patience_counter += 1
                if patience_counter >= self.early_stopping_patience:
                    logger.info("Early stopping triggered")
                    break

        return history

    def predict(self, test_loader: DataLoader) -> Tuple[np.ndarray, np.ndarray]:
        """Return ``(predictions, probabilities)`` arrays for test data.

        NOTE(review): relies on the model exposing a ``predict(input_ids,
        attention_mask)`` method returning a (batch, num_classes) tensor —
        presumably class probabilities; this contract is not visible here,
        confirm against the model class.
        """
        self.model.eval()
        all_preds: List[int] = []
        all_probs: List[np.ndarray] = []

        with torch.no_grad():
            for batch in tqdm(test_loader, desc="Predicting"):
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)

                probs = self.model.predict(input_ids, attention_mask)
                preds = torch.argmax(probs, dim=1)

                all_preds.extend(preds.cpu().numpy())
                all_probs.extend(probs.cpu().numpy())

        return np.array(all_preds), np.array(all_probs)