import torch
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup
from tqdm import tqdm
import logging
from pathlib import Path
import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score
import json
from datetime import datetime
from torch.cuda.amp import autocast, GradScaler


class NarrativeTrainer:
    """Enhanced trainer with detailed metrics and optimizations."""

    def __init__(self, model, train_dataset, val_dataset, config):
        # Setup logging
        self.setup_logging()
        self.logger = logging.getLogger(__name__)

        # Store config first
        self.config = config

        # Setup device
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.logger.info(f"Using device: {self.device}")

        # Clear GPU cache if using CUDA
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Initialize model and datasets
        self.model = model.to(self.device)
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset

        # Initialize training state
        self.current_epoch = 0
        self.global_step = 0
        self.best_val_f1 = 0.0

        # Mixed precision: only create a GradScaler when fp16 is enabled
        self.scaler = GradScaler() if self.config.fp16 else None

        # Setup training components
        self.setup_training()

        # Setup output directory
        self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.output_dir = Path(config.output_dir) / self.timestamp
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Save config and initialize history
        self.save_config()
        self.history = {
            'train_loss': [],
            'val_loss': [],
            'metrics': [],
            'thresholds': []
        }

    def setup_logging(self):
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )

    def calculate_class_weights(self):
        """Calculate positive-class weights for imbalanced labels."""
        pos_counts = self.train_dataset.labels.sum(dim=0)
        neg_counts = len(self.train_dataset) - pos_counts
        pos_weight = (neg_counts / pos_counts) * self.config.pos_weight_multiplier
        return torch.clamp(pos_weight, min=1.0, max=50.0).to(self.device)

    def setup_training(self):
        """Initialize dataloaders, loss, optimizer, scheduler, and thresholds."""
        # Create dataloaders
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.config.batch_size,
            shuffle=True,
            num_workers=4,
            pin_memory=True
        )
        self.val_loader = DataLoader(
            self.val_dataset,
            batch_size=self.config.batch_size,
            num_workers=4,
            pin_memory=True
        )

        # Weighted BCE loss for multi-label classification
        pos_weight = self.calculate_class_weights()
        self.criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)

        # Optimizer
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay
        )

        # Linear warmup/decay schedule over optimizer updates (not raw batches)
        num_update_steps_per_epoch = len(self.train_loader) // self.config.gradient_accumulation_steps
        num_training_steps = num_update_steps_per_epoch * self.config.num_epochs
        num_warmup_steps = int(num_training_steps * self.config.warmup_ratio)
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps
        )

        # Start every label at a 0.5 decision threshold
        self.label_thresholds = torch.ones(self.train_dataset.get_num_labels()).to(self.device) * 0.5

    def save_config(self):
        config_dict = {k: str(v) for k, v in vars(self.config).items()}
        config_path = self.output_dir / 'config.json'
        with open(config_path, 'w') as f:
            json.dump(config_dict, f, indent=4)

    def find_optimal_thresholds(self, val_outputs, val_labels):
        """Find the optimal decision threshold for each label."""
        outputs = torch.sigmoid(val_outputs).cpu().numpy()
        labels = val_labels.cpu().numpy()
        thresholds = []
        for i in range(labels.shape[1]):
            best_f1 = 0
            best_threshold = 0.5
            if labels[:, i].sum() > 0:  # Only search if the label has positive samples
                for threshold in np.arange(0.1, 0.9, 0.05):
                    preds = (outputs[:, i] > threshold).astype(int)
                    f1 = f1_score(labels[:, i], preds)
                    if f1 > best_f1:
                        best_f1 = f1
                        best_threshold = threshold
            thresholds.append(best_threshold)
        return torch.tensor(thresholds).to(self.device)

    def calculate_detailed_metrics(self, all_labels, all_preds, all_probs=None):
        """Calculate detailed metrics for model evaluation."""
        metrics = {}

        # Aggregate metrics
        metrics['micro'] = {
            'precision': precision_score(all_labels, all_preds, average='micro'),
            'recall': recall_score(all_labels, all_preds, average='micro'),
            'f1': f1_score(all_labels, all_preds, average='micro')
        }
        metrics['macro'] = {
            'precision': precision_score(all_labels, all_preds, average='macro'),
            'recall': recall_score(all_labels, all_preds, average='macro'),
            'f1': f1_score(all_labels, all_preds, average='macro')
        }
        metrics['weighted'] = {
            'precision': precision_score(all_labels, all_preds, average='weighted'),
            'recall': recall_score(all_labels, all_preds, average='weighted'),
            'f1': f1_score(all_labels, all_preds, average='weighted')
        }

        # Per-class metrics
        per_class_metrics = {}
        precisions = precision_score(all_labels, all_preds, average=None)
        recalls = recall_score(all_labels, all_preds, average=None)
        f1s = f1_score(all_labels, all_preds, average=None)
        for i in range(len(f1s)):
            per_class_metrics[f'class_{i}'] = {
                'precision': float(precisions[i]),
                'recall': float(recalls[i]),
                'f1': float(f1s[i]),
                'support': int(all_labels[:, i].sum())
            }
        metrics['per_class'] = per_class_metrics

        return metrics

    def train_epoch(self):
        """Train for one epoch with gradient accumulation and optional AMP."""
        self.model.train()
        total_loss = 0
        self.optimizer.zero_grad()

        pbar = tqdm(
            enumerate(self.train_loader),
            total=len(self.train_loader),
            desc=f'Epoch {self.current_epoch + 1}/{self.config.num_epochs}'
        )

        for step, batch in pbar:
            batch = {k: v.to(self.device, non_blocking=True) for k, v in batch.items()}

            # Forward pass under autocast when fp16 is enabled
            with autocast(enabled=self.config.fp16):
                outputs = self.model(
                    input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask'],
                    features=batch['features']
                )
                loss = self.criterion(outputs, batch['labels'])
                loss = loss / self.config.gradient_accumulation_steps

            # Backward pass, scaled when fp16 is enabled
            if self.config.fp16:
                self.scaler.scale(loss).backward()
            else:
                loss.backward()

            if (step + 1) % self.config.gradient_accumulation_steps == 0:
                # Unscale before clipping so the threshold applies to true gradients
                if self.config.fp16:
                    self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(),
                    self.config.max_grad_norm
                )

                if self.config.fp16:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    self.optimizer.step()

                self.scheduler.step()
                self.optimizer.zero_grad()

            total_loss += loss.item() * self.config.gradient_accumulation_steps
            avg_loss = total_loss / (step + 1)
            pbar.set_postfix({'loss': f'{avg_loss:.4f}'})

            self.global_step += 1
            if self.global_step % self.config.eval_steps == 0:
                self.evaluate()
                self.model.train()  # evaluate() puts the model in eval mode; restore training mode

            if step % 10 == 0:
                torch.cuda.empty_cache()

            del outputs
            del loss

        return total_loss / len(self.train_loader)

    @torch.no_grad()
    def evaluate(self):
        """Evaluate the model on the validation set with detailed metrics."""
        self.model.eval()
        total_loss = 0
        all_outputs, all_labels = [], []

        for batch in tqdm(self.val_loader, desc="Evaluating"):
            batch = {k: v.to(self.device, non_blocking=True) for k, v in batch.items()}
            with autocast(enabled=self.config.fp16):
                outputs = self.model(
                    input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask'],
                    features=batch['features']
                )
                loss = self.criterion(outputs, batch['labels'])

            total_loss += loss.item()
            all_outputs.append(outputs.cpu())
            all_labels.append(batch['labels'].cpu())

            del outputs
            del loss

        torch.cuda.empty_cache()

        all_outputs = torch.cat(all_outputs, dim=0)
        all_labels = torch.cat(all_labels, dim=0)

        # Periodically re-tune per-label thresholds on the validation outputs
        if self.global_step % (self.config.eval_steps * 2) == 0:
            self.label_thresholds = self.find_optimal_thresholds(all_outputs, all_labels)

        all_probs = torch.sigmoid(all_outputs).numpy()
        all_preds = (all_probs > self.label_thresholds.cpu().unsqueeze(0).numpy())
        all_labels = all_labels.numpy()

        metrics = self.calculate_detailed_metrics(all_labels, all_preds, all_probs)
        metrics['loss'] = total_loss / len(self.val_loader)

        self.logger.info(f"Step {self.global_step} - Validation metrics:")
        self.logger.info(f"Loss: {metrics['loss']:.4f}")
        self.logger.info(f"Micro F1: {metrics['micro']['f1']:.4f}")
        self.logger.info(f"Macro F1: {metrics['macro']['f1']:.4f}")

        if metrics['micro']['f1'] > self.best_val_f1:
            self.best_val_f1 = metrics['micro']['f1']
            self.save_model('best_model.pt', metrics)

        return metrics

    def save_model(self, filename: str, metrics: dict = None):
        save_path = self.output_dir / filename
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            # The scaler only exists when fp16 is enabled
            'scaler_state_dict': self.scaler.state_dict() if self.scaler is not None else None,
            'epoch': self.current_epoch,
            'global_step': self.global_step,
            'best_val_f1': self.best_val_f1,
            'metrics': metrics,
            'thresholds': self.label_thresholds
        }, save_path)
        self.logger.info(f"Model saved to {save_path}")

    def train(self):
        """Run the complete training loop."""
        self.logger.info("Starting training...")
        try:
            for epoch in range(self.config.num_epochs):
                self.current_epoch = epoch
                self.logger.info(f"Starting epoch {epoch + 1}/{self.config.num_epochs}")

                train_loss = self.train_epoch()
                self.history['train_loss'].append(train_loss)

                val_metrics = self.evaluate()
                self.history['val_loss'].append(val_metrics['loss'])
                self.history['metrics'].append(val_metrics)
                self.history['thresholds'].append(self.label_thresholds.cpu().tolist())

                self.save_model(f'checkpoint_epoch_{epoch+1}.pt', val_metrics)

                history_path = self.output_dir / 'history.json'
                with open(history_path, 'w') as f:
                    json.dump(self.history, f, indent=4)

                self.logger.info(f"Epoch {epoch + 1} completed. Train loss: {train_loss:.4f}")

            self.logger.info("Training completed successfully!")
            return self.history

        except Exception as e:
            self.logger.error(f"Training failed with error: {str(e)}")
            raise
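

# A minimal usage sketch, not part of the trainer itself: it assumes a simple
# namespace-style config exposing the fields NarrativeTrainer reads above, with
# illustrative values only. The model and dataset objects must come from the
# project's own classes: the trainer expects the datasets to expose `.labels`
# and `get_num_labels()`, and the model to accept `input_ids`, `attention_mask`,
# and `features`.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        output_dir="outputs",
        batch_size=16,
        learning_rate=2e-5,
        weight_decay=0.01,
        num_epochs=3,
        warmup_ratio=0.1,
        gradient_accumulation_steps=2,
        max_grad_norm=1.0,
        eval_steps=200,
        pos_weight_multiplier=1.0,
        fp16=torch.cuda.is_available(),
    )

    # Supply the project's own model and datasets here, then run training:
    # model, train_dataset, val_dataset = ...
    # trainer = NarrativeTrainer(model, train_dataset, val_dataset, config)
    # history = trainer.train()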