import logging
from typing import Dict, Optional

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoConfig, AutoModel


# NarrativeClassifier Model Definition
class NarrativeClassifier(nn.Module):
    """
    Production-ready model for narrative classification combining transformer with additional features.

    The transformer's first-token embedding (``last_hidden_state[:, 0, :]``) is
    concatenated with a 64-dim projection of ``num_features`` extra numeric
    features, passed through a 512-unit hidden layer, and mapped to
    ``num_labels`` logits.

    Args:
        model_name: Hugging Face model identifier or local path for the encoder.
        num_labels: Number of output logits.
        dropout: Dropout probability used in the added (non-transformer) layers.
        freeze_encoder: If True, all transformer parameters get
            ``requires_grad = False``; only the added layers are trainable.
        device: Device string; defaults to ``"cuda"`` when available, else ``"cpu"``.
        num_features: Width of the extra numeric feature vector. Defaults to 5,
            the value previously hard-coded.

    Raises:
        ValueError: If ``num_labels`` or ``num_features`` is smaller than 1.
        Exception: Whatever ``AutoConfig``/``AutoModel.from_pretrained`` raises
            is logged and re-raised unchanged.
    """

    def __init__(
        self,
        model_name: str = "microsoft/deberta-v3-large",
        num_labels: int = 84,
        dropout: float = 0.1,
        freeze_encoder: bool = False,
        device: Optional[str] = None,
        num_features: int = 5,
    ):
        super().__init__()
        if num_labels < 1:
            raise ValueError(f"num_labels must be >= 1, got {num_labels}")
        if num_features < 1:
            raise ValueError(f"num_features must be >= 1, got {num_features}")

        self.logger = logging.getLogger(__name__)
        # NOTE: a plain string fixed at construction time; it does not track
        # later .to()/.cuda() calls on the module.
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        self.logger.info("Using device: %s", self.device)

        # Config loading is inside the try as well, so every download/parse
        # failure is logged before propagating.
        try:
            self.config = AutoConfig.from_pretrained(model_name)
            self.transformer = AutoModel.from_pretrained(model_name, config=self.config)
        except Exception as e:
            self.logger.error(f"Error loading transformer model: {str(e)}")
            raise

        if freeze_encoder:
            self.logger.info("Freezing transformer encoder")
            for param in self.transformer.parameters():
                param.requires_grad = False

        self.transformer_dim = self.transformer.config.hidden_size
        self.num_features = num_features  # Additional numerical features

        self.feature_processor = nn.Sequential(
            nn.Linear(self.num_features, 64),
            nn.LayerNorm(64),
            nn.ReLU(),
            nn.Dropout(dropout),
        )
        self.pre_classifier = nn.Sequential(
            nn.Linear(self.transformer_dim + 64, 512),
            nn.LayerNorm(512),
            nn.ReLU(),
            nn.Dropout(dropout),
        )
        self.classifier = nn.Linear(512, num_labels)

        self._init_weights()
        self.to(self.device)

    def _init_weights(self):
        """Xavier-uniform init for every Linear in the added layers; zero biases.

        The pretrained transformer weights are left untouched.
        """
        for module in [self.feature_processor, self.pre_classifier, self.classifier]:
            for layer in module.modules():
                if isinstance(layer, nn.Linear):
                    torch.nn.init.xavier_uniform_(layer.weight)
                    if layer.bias is not None:
                        torch.nn.init.zeros_(layer.bias)

    def forward(self, input_ids, attention_mask, features, return_dict=False):
        """Compute classification logits.

        Args:
            input_ids: Token ids passed to the transformer (presumably
                ``(batch, seq_len)`` — TODO confirm against the tokenizer output).
            attention_mask: Attention mask matching ``input_ids``.
            features: 2-D tensor ``(batch, num_features)``; it is concatenated
                along dim 1 with the 2-D first-token embedding. It is cast to
                the feature layer's dtype, so float64 inputs are accepted.
            return_dict: If True, return a dict with ``logits``,
                ``hidden_states`` (the 512-dim pre-classifier output) and
                ``transformer_output`` (the first-token embedding).

        Returns:
            ``(batch, num_labels)`` logits, or the dict described above.
        """
        transformer_outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        # First-token embedding of the final layer as the sequence summary.
        sequence_output = transformer_outputs.last_hidden_state[:, 0, :]

        # Match the Linear weight dtype (e.g. float64 features from numpy
        # would otherwise raise a dtype mismatch).
        features = features.to(dtype=self.feature_processor[0].weight.dtype)
        processed_features = self.feature_processor(features)

        combined = torch.cat([sequence_output, processed_features], dim=1)
        intermediate = self.pre_classifier(combined)
        logits = self.classifier(intermediate)

        if return_dict:
            return {
                "logits": logits,
                "hidden_states": intermediate,
                "transformer_output": sequence_output,
            }
        return logits


# Dataset Definition
class NarrativeDataset(Dataset):
    """Map-style dataset over a list of pre-tokenized sample dictionaries."""

    def __init__(self, data):
        """
        Initialize dataset with data.

        Args:
            data: List of dictionaries containing input_ids, attention_mask,
                features and (optionally) an integer label.
        """
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        """Return one sample as tensors.

        ``input_ids``/``attention_mask`` become ``torch.long`` and ``features``
        becomes ``torch.float``. ``label`` (``torch.long``) is included
        whenever the underlying sample has one, so the dataset can drive
        supervised training.
        """
        sample = self.data[idx]
        item = {
            "input_ids": torch.tensor(sample["input_ids"], dtype=torch.long),
            "attention_mask": torch.tensor(sample["attention_mask"], dtype=torch.long),
            "features": torch.tensor(sample["features"], dtype=torch.float),
        }
        if "label" in sample:
            item["label"] = torch.tensor(sample["label"], dtype=torch.long)
        return item

    def get_num_labels(self) -> int:
        """Return the number of labels, inferred as ``max(label) + 1``.

        Assumes labels are non-negative integer class indices.

        Raises:
            ValueError: If the dataset is empty.
            KeyError: If a sample has no ``label`` key.
        """
        if len(self.data) == 0:
            raise ValueError("Cannot infer the number of labels from an empty dataset")
        return max(item["label"] for item in self.data) + 1


# Main Testing Section
if __name__ == "__main__":
    import sys

    sys.path.append("../../")
    from scripts.data_processing.data_preparation import AdvancedNarrativeProcessor
    # NOTE(review): this import shadows the NarrativeDataset class defined in
    # this file; confirm which implementation the smoke test should exercise.
    from scripts.models.dataset import NarrativeDataset

    # Set up logging
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)
    logger.info(f"CUDA available: {torch.cuda.is_available()}")

    try:
        # Load real data
        processor = AdvancedNarrativeProcessor(
            annotations_file="../../data/subtask-2-annotations.txt",
            raw_dir="../../data/raw",
        )
        processed_data = processor.load_and_process_data()

        # Create dataset and dataloader
        train_dataset = NarrativeDataset(processed_data["train"])
        train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)

        # Initialize model
        model = NarrativeClassifier(num_labels=train_dataset.get_num_labels())
        logger.info(f"Model initialized on device: {next(model.parameters()).device}")

        # Test with a single real batch. Inference mode: no dropout, and no
        # autograd graph is built for the (large) encoder.
        batch = next(iter(train_loader))
        batch = {k: v.to(model.device) for k, v in batch.items()}
        model.eval()
        with torch.no_grad():
            outputs = model(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                features=batch["features"],
                return_dict=True,
            )

        logger.info("\n=== Model Test Results ===")
        logger.info(f"Input shape: {batch['input_ids'].shape}")
        logger.info(f"Output logits shape: {outputs['logits'].shape}")
        logger.info(f"Hidden states shape: {outputs['hidden_states'].shape}")
        logger.info("Forward pass successful!")

    except Exception as e:
        logger.error(f"Error during model test: {str(e)}")
        raise