import torch
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup
from tqdm import tqdm
import logging
from pathlib import Path
import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score
import json
from datetime import datetime


class NarrativeTrainer:
    """
    Comprehensive trainer for narrative classification with GPU support.

    The model is called as
    ``model(input_ids=..., attention_mask=..., features=...)`` and its output
    is scored against ``labels`` with ``BCEWithLogitsLoss``.

    Args:
        model: ``torch.nn.Module`` to train; moved to CUDA when available.
        train_dataset: Dataset whose batches are dicts of tensors containing
            ``input_ids``, ``attention_mask``, ``features`` and ``labels``.
        val_dataset: Validation dataset with the same batch layout.
        config: Object exposing ``output_dir``, ``batch_size``,
            ``learning_rate``, ``weight_decay``, ``num_epochs``,
            ``warmup_ratio``, ``max_grad_norm`` and ``eval_steps``. An
            optional ``num_workers`` attribute (default 4) sets the number of
            DataLoader worker processes. A falsy ``eval_steps`` disables
            mid-epoch evaluation.
    """

    def __init__(
        self,
        model,
        train_dataset,
        val_dataset,
        config,
    ):
        self.setup_logging()
        self.logger = logging.getLogger(__name__)

        # Set device
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.logger.info(f"Using device: {self.device}")

        # Initialize model and components
        self.model = model.to(self.device)
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.config = config

        # Progress state; stored in every checkpoint by save_model().
        self.current_epoch = 0
        self.global_step = 0
        self.best_val_f1 = 0.0

        self.setup_training()

        # Each run gets its own timestamped sub-directory of config.output_dir.
        self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.output_dir = Path(config.output_dir) / self.timestamp
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.save_config()

        # Per-epoch metrics, appended by train() and written to history.json.
        self.history = {
            'train_loss': [],
            'val_loss': [],
            'val_f1': [],
            'val_precision': [],
            'val_recall': [],
        }

    def setup_logging(self):
        """Configure root logging (INFO level, timestamped format)."""
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S',
        )

    def setup_training(self):
        """Initialize dataloaders, optimizer, and scheduler."""
        # Optional config knob; 4 preserves the previous hard-coded value.
        num_workers = getattr(self.config, 'num_workers', 4)
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.config.batch_size,
            shuffle=True,
            num_workers=num_workers,
        )
        self.val_loader = DataLoader(
            self.val_dataset,
            batch_size=self.config.batch_size,
            num_workers=num_workers,
        )

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay,
        )

        # The schedule covers every optimizer step of the run (one per batch);
        # the first warmup_ratio fraction of those steps is warmup.
        num_training_steps = len(self.train_loader) * self.config.num_epochs
        num_warmup_steps = int(num_training_steps * self.config.warmup_ratio)
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
        )

        self.criterion = torch.nn.BCEWithLogitsLoss()

    def save_config(self):
        """Save training configuration (all values stringified) to config.json."""
        config_dict = {k: str(v) for k, v in vars(self.config).items()}
        config_path = self.output_dir / 'config.json'
        with open(config_path, 'w') as f:
            json.dump(config_dict, f, indent=4)

    def train_epoch(self):
        """Train model for one epoch.

        Runs a mid-epoch ``evaluate()`` every ``config.eval_steps`` global
        steps (skipped when ``eval_steps`` is falsy).

        Returns:
            Mean training loss over the epoch's batches.
        """
        self.model.train()
        total_loss = 0.0
        eval_steps = self.config.eval_steps

        pbar = tqdm(self.train_loader, desc=f'Epoch {self.current_epoch + 1}/{self.config.num_epochs}')
        for step, batch in enumerate(pbar, start=1):
            batch = {k: v.to(self.device) for k, v in batch.items()}

            self.optimizer.zero_grad()
            outputs = self.model(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                features=batch['features'],
            )
            loss = self.criterion(outputs, batch['labels'])
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
            self.optimizer.step()
            self.scheduler.step()

            total_loss += loss.item()
            # Running mean over the batches seen so far. Count steps ourselves:
            # tqdm only syncs ``pbar.n`` when it redraws, so it lags behind.
            pbar.set_postfix({'loss': total_loss / step})

            self.global_step += 1
            if eval_steps and self.global_step % eval_steps == 0:
                self.evaluate()
                # evaluate() switches the model to eval mode; switch back so
                # the remainder of the epoch trains with dropout etc. active.
                self.model.train()

        return total_loss / len(self.train_loader)

    @torch.no_grad()
    def evaluate(self):
        """Evaluate model performance on the validation set.

        Leaves the model in eval mode. Saves ``best_model.pt`` whenever the
        micro-F1 strictly improves on ``self.best_val_f1``.

        Returns:
            Dict with ``loss`` (mean of per-batch losses) and micro-averaged
            ``f1``, ``precision`` and ``recall``; predictions are
            ``sigmoid(outputs) > 0.5``.
        """
        self.model.eval()
        total_loss = 0.0
        all_preds, all_labels = [], []

        for batch in tqdm(self.val_loader, desc="Evaluating"):
            batch = {k: v.to(self.device) for k, v in batch.items()}
            outputs = self.model(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                features=batch['features'],
            )
            loss = self.criterion(outputs, batch['labels'])
            total_loss += loss.item()

            # Each output is thresholded independently at probability 0.5.
            preds = torch.sigmoid(outputs) > 0.5
            all_preds.append(preds.cpu().numpy())
            all_labels.append(batch['labels'].cpu().numpy())

        all_preds = np.concatenate(all_preds, axis=0)
        all_labels = np.concatenate(all_labels, axis=0)

        # zero_division=0 yields the same value as sklearn's default ("warn"
        # behaves as 0) without emitting a warning on every evaluation.
        metrics = {
            'loss': total_loss / len(self.val_loader),
            'f1': f1_score(all_labels, all_preds, average='micro', zero_division=0),
            'precision': precision_score(all_labels, all_preds, average='micro', zero_division=0),
            'recall': recall_score(all_labels, all_preds, average='micro', zero_division=0),
        }

        self.logger.info(f"Step {self.global_step} - Validation metrics: {metrics}")

        if metrics['f1'] > self.best_val_f1:
            self.best_val_f1 = metrics['f1']
            self.save_model('best_model.pt', metrics)

        return metrics

    def save_model(self, filename: str, metrics: dict = None):
        """Save model/optimizer/scheduler state plus progress to output_dir/filename.

        Args:
            filename: File name inside ``self.output_dir``.
            metrics: Optional metrics dict stored alongside the weights.
        """
        save_path = self.output_dir / filename
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'epoch': self.current_epoch,
            'global_step': self.global_step,
            'best_val_f1': self.best_val_f1,
            'metrics': metrics,
        }, save_path)
        self.logger.info(f"Model saved to {save_path}")

    def _save_history(self):
        """Write ``self.history`` to output_dir/history.json."""
        history_path = self.output_dir / 'history.json'
        with open(history_path, 'w') as f:
            json.dump(self.history, f, indent=4)

    def train(self):
        """Run training for all epochs.

        After every epoch: evaluates, appends metrics to ``self.history``,
        saves ``checkpoint_epoch_{N}.pt`` and rewrites ``history.json``.

        Returns:
            The ``self.history`` dict.
        """
        self.logger.info("Starting training...")

        for epoch in range(self.config.num_epochs):
            self.current_epoch = epoch
            train_loss = self.train_epoch()
            self.history['train_loss'].append(train_loss)

            val_metrics = self.evaluate()
            self.history['val_loss'].append(val_metrics['loss'])
            self.history['val_f1'].append(val_metrics['f1'])
            self.history['val_precision'].append(val_metrics['precision'])
            self.history['val_recall'].append(val_metrics['recall'])

            self.save_model(f'checkpoint_epoch_{epoch+1}.pt', val_metrics)
            # Persist after each epoch so an interrupted run keeps its history.
            self._save_history()

        # Final write also guarantees the file exists when num_epochs == 0.
        self._save_history()

        self.logger.info("Training completed!")
        return self.history


if __name__ == "__main__":
    import sys

    sys.path.append("../../")
    from scripts.models.model import NarrativeClassifier
    from scripts.models.dataset import NarrativeDataset
    from scripts.config.config import TrainingConfig
    from scripts.data_processing.data_preparation import AdvancedNarrativeProcessor

    # Initialize training configuration
    config = TrainingConfig(
        output_dir=Path("./output"),
        num_epochs=5,
        batch_size=32,
        learning_rate=5e-5,
        weight_decay=0.01,
        warmup_ratio=0.1,
        max_grad_norm=1.0,
        eval_steps=100,
    )

    # Load and process data
    processor = AdvancedNarrativeProcessor(
        annotations_file="../../data/subtask-2-annotations.txt",
        raw_dir="../../data/raw",
    )
    processed_data = processor.load_and_process_data()

    # Create datasets
    train_dataset = NarrativeDataset(processed_data['train'])
    val_dataset = NarrativeDataset(processed_data['val'])

    # Initialize model
    model = NarrativeClassifier(num_labels=train_dataset.get_num_labels())

    # Initialize trainer
    trainer = NarrativeTrainer(
        model=model,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        config=config,
    )

    # Start full training
    print("\n=== Starting Training ===")
    trainer.train()
    print("\nTraining completed successfully!")