import torch
import torch.nn as nn
from transformers import AutoModel, AutoConfig
from typing import Optional
import logging
from torch.utils.data import Dataset, DataLoader
# NarrativeClassifier Model Definition
class NarrativeClassifier(nn.Module):
    """
    Narrative classification model that combines a transformer encoder
    with additional numerical features.
    """

    def __init__(
        self,
        model_name: str = "microsoft/deberta-v3-large",
        num_labels: int = 84,
        dropout: float = 0.1,
        freeze_encoder: bool = False,
        device: Optional[str] = None
    ):
        super().__init__()
        self.logger = logging.getLogger(__name__)
        self.device = device if device else ('cuda' if torch.cuda.is_available() else 'cpu')
        self.logger.info(f"Using device: {self.device}")

        # Load the pretrained transformer backbone.
        self.config = AutoConfig.from_pretrained(model_name)
        try:
            self.transformer = AutoModel.from_pretrained(model_name, config=self.config)
        except Exception as e:
            self.logger.error(f"Error loading transformer model: {e}")
            raise

        if freeze_encoder:
            self.logger.info("Freezing transformer encoder")
            for param in self.transformer.parameters():
                param.requires_grad = False

        self.transformer_dim = self.transformer.config.hidden_size
        self.num_features = 5  # Number of additional numerical features per example

        # Project the numerical features into a 64-dimensional representation.
        self.feature_processor = nn.Sequential(
            nn.Linear(self.num_features, 64),
            nn.LayerNorm(64),
            nn.ReLU(),
            nn.Dropout(dropout)
        )

        # Fuse the transformer [CLS] embedding with the processed features.
        self.pre_classifier = nn.Sequential(
            nn.Linear(self.transformer_dim + 64, 512),
            nn.LayerNorm(512),
            nn.ReLU(),
            nn.Dropout(dropout)
        )
        self.classifier = nn.Linear(512, num_labels)

        self._init_weights()
        self.to(self.device)
    def _init_weights(self):
        """Initialize the added (non-pretrained) layers with Xavier-uniform weights."""
        for module in [self.feature_processor, self.pre_classifier, self.classifier]:
            for layer in module.modules():
                if isinstance(layer, nn.Linear):
                    torch.nn.init.xavier_uniform_(layer.weight)
                    if layer.bias is not None:
                        torch.nn.init.zeros_(layer.bias)
    def forward(self, input_ids, attention_mask, features, return_dict=False):
        """Forward pass; returns logits, or a dict of intermediates if return_dict=True."""
        transformer_outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        # Use the first ([CLS]) token's hidden state as the sequence representation.
        sequence_output = transformer_outputs.last_hidden_state[:, 0, :]
        processed_features = self.feature_processor(features)
        # Concatenate text and numerical-feature representations, then classify.
        combined = torch.cat([sequence_output, processed_features], dim=1)
        intermediate = self.pre_classifier(combined)
        logits = self.classifier(intermediate)
        if return_dict:
            return {
                'logits': logits,
                'hidden_states': intermediate,
                'transformer_output': sequence_output
            }
        return logits
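
# A minimal usage sketch (not part of the original pipeline): the random
# inputs and shapes below are illustrative assumptions, showing only that the
# model maps a tokenized batch plus 5 numerical features to per-label logits.
def _example_forward_pass():
    """Hypothetical smoke test; never called on import."""
    model = NarrativeClassifier(num_labels=84)
    input_ids = torch.randint(0, model.config.vocab_size, (2, 128), device=model.device)
    attention_mask = torch.ones(2, 128, dtype=torch.long, device=model.device)
    features = torch.rand(2, 5, device=model.device)
    logits = model(input_ids, attention_mask, features)  # shape: (2, 84)
    return logits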
# Dataset Definition
class NarrativeDataset(Dataset):
    def __init__(self, data):
        """
        Initialize the dataset.

        Args:
            data: List of dictionaries, each containing input_ids,
                attention_mask, features, and label.
        """
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return one sample as a dict of tensors."""
        sample = self.data[idx]
        return {
            'input_ids': torch.tensor(sample['input_ids'], dtype=torch.long),
            'attention_mask': torch.tensor(sample['attention_mask'], dtype=torch.long),
            'features': torch.tensor(sample['features'], dtype=torch.float),
            # Labels are present in the data (get_num_labels relies on them),
            # so return them with the sample for training.
            'label': torch.tensor(sample['label'], dtype=torch.long)
        }

    def get_num_labels(self):
        """Return the number of labels, assuming 0-indexed integer class ids."""
        return max(item['label'] for item in self.data) + 1
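
# A sketch of the record schema NarrativeDataset expects; all values below are
# hypothetical, and real records come from AdvancedNarrativeProcessor in the
# test section further down.
def _example_dataset_usage():
    """Builds a one-record dataset from a made-up sample; never called on import."""
    record = {
        'input_ids': [1, 2034, 87, 2],          # token ids from a tokenizer
        'attention_mask': [1, 1, 1, 1],         # 1 = real token, 0 = padding
        'features': [0.2, 0.0, 1.5, 0.3, 0.7],  # the 5 numerical features
        'label': 12,                            # integer class id
    }
    dataset = NarrativeDataset([record])
    loader = DataLoader(dataset, batch_size=1)
    return next(iter(loader))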
# Main Testing Section
if __name__ == "__main__":
    import sys
    sys.path.append("../../")
    from scripts.data_processing.data_preparation import AdvancedNarrativeProcessor

    # Set up logging
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)
    logger.info(f"CUDA available: {torch.cuda.is_available()}")

    try:
        # Load real data
        processor = AdvancedNarrativeProcessor(
            annotations_file="../../data/subtask-2-annotations.txt",
            raw_dir="../../data/raw"
        )
        processed_data = processor.load_and_process_data()

        # Create dataset and dataloader (NarrativeDataset and DataLoader are
        # already defined/imported at the top of this file)
        train_dataset = NarrativeDataset(processed_data['train'])
        train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)

        # Initialize model
        model = NarrativeClassifier(num_labels=train_dataset.get_num_labels())
        logger.info(f"Model initialized on device: {next(model.parameters()).device}")

        # Test the forward pass with one real batch
        for batch in train_loader:
            batch = {k: v.to(model.device) for k, v in batch.items()}
            outputs = model(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                features=batch['features'],
                return_dict=True
            )
            logger.info("\n=== Model Test Results ===")
            logger.info(f"Input shape: {batch['input_ids'].shape}")
            logger.info(f"Output logits shape: {outputs['logits'].shape}")
            logger.info(f"Hidden states shape: {outputs['hidden_states'].shape}")
            logger.info("Forward pass successful!")
            break  # Test only one batch
    except Exception as e:
        logger.error(f"Error during model test: {e}")
        raise
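
# A minimal single-step training sketch (not in the original module). It
# assumes single-label classification with integer class ids, as implied by
# get_num_labels, and uses cross-entropy loss; the optimizer and learning
# rate are placeholder choices.
def _example_training_step(model, batch):
    """One hypothetical optimization step over a NarrativeDataset batch."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    criterion = nn.CrossEntropyLoss()
    batch = {k: v.to(model.device) for k, v in batch.items()}
    logits = model(batch['input_ids'], batch['attention_mask'], batch['features'])
    loss = criterion(logits, batch['label'])
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()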