# File: scripts/models/dataset.py
"""Dataset wrapper for narrative classification model training."""

from torch.utils.data import Dataset
import torch
from typing import Dict, Any
from pathlib import Path


class NarrativeDataset(Dataset):
    """
    Dataset class for narrative classification.

    Wraps the tensors produced by AdvancedNarrativeProcessor so they can be
    consumed by a PyTorch DataLoader for model training.
    """

    def __init__(self, data_dict: Dict[str, Any]):
        """
        Initialize the dataset with processed data.

        Args:
            data_dict: Dictionary containing processed data from
                AdvancedNarrativeProcessor. Must hold the keys 'input_ids',
                'attention_mask', 'labels' and 'features', each a tensor whose
                first dimension is the number of samples.

        Raises:
            KeyError: If a required key is missing from data_dict.
            ValueError: If the per-sample tensors have mismatched lengths.
        """
        self.input_ids = data_dict['input_ids']
        self.attention_mask = data_dict['attention_mask']
        # Convert labels and features to float — float targets are what
        # multi-label losses such as BCEWithLogitsLoss expect.
        self.labels = data_dict['labels'].float()
        self.features = data_dict['features'].float()

        # Verify data consistency across ALL per-sample tensors. Raise an
        # explicit exception rather than using `assert`, which is silently
        # stripped when Python runs with -O.
        n = len(self.input_ids)
        if not (len(self.attention_mask) == n
                and len(self.labels) == n
                and len(self.features) == n):
            raise ValueError("Mismatch between inputs and labels length")

    def __len__(self) -> int:
        """Return the total number of samples."""
        return len(self.input_ids)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Get a single sample from the dataset.

        Args:
            idx: Index of the sample to get

        Returns:
            Dictionary containing all features for the sample
        """
        return {
            'input_ids': self.input_ids[idx],
            'attention_mask': self.attention_mask[idx],
            'labels': self.labels[idx],
            'features': self.features[idx],
        }

    def get_num_labels(self) -> int:
        """Return the number of labels in the dataset."""
        return self.labels.shape[1]


# Real test with our preprocessed data
if __name__ == "__main__":
    # Import our preprocessor
    import sys

    sys.path.append("../../")  # Add root to path
    from scripts.data_processing.data_preparation import AdvancedNarrativeProcessor

    # Initialize preprocessor
    processor = AdvancedNarrativeProcessor(
        annotations_file="../../data/subtask-2-annotations.txt",
        raw_dir="../../data/raw",
    )

    # Get processed data
    processed_data = processor.load_and_process_data()

    # Create train and validation datasets
    train_dataset = NarrativeDataset(processed_data['train'])
    val_dataset = NarrativeDataset(processed_data['val'])

    # Print information about the datasets
    print("\n=== Dataset Statistics ===")
    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")
    print(f"Number of labels: {train_dataset.get_num_labels()}")

    # Look at a sample
    sample = train_dataset[0]
    print("\n=== Sample Details ===")
    print(f"Input IDs shape: {sample['input_ids'].shape}")
    print(f"Attention mask shape: {sample['attention_mask'].shape}")
    print(f"Labels shape: {sample['labels'].shape}")
    print(f"Features shape: {sample['features'].shape}")