# NOTE(review): removed extraction artifacts ("Spaces:" / "Runtime error" lines) that were not part of the source file.
# File: scripts/models/dataset.py
from torch.utils.data import Dataset
import torch
from typing import Dict, Any
from pathlib import Path
class NarrativeDataset(Dataset):
    """
    Dataset class for narrative classification.

    Wraps the tensors produced by ``AdvancedNarrativeProcessor`` so they can
    be consumed by a ``torch.utils.data.DataLoader`` during model training.
    """

    def __init__(self, data_dict: Dict[str, Any]):
        """
        Initialize the dataset with processed data.

        Args:
            data_dict: Dictionary containing processed data from
                AdvancedNarrativeProcessor. Must contain the keys
                'input_ids', 'attention_mask', 'labels' and 'features',
                each an indexable tensor with matching first dimension.

        Raises:
            KeyError: If a required key is missing from ``data_dict``.
            ValueError: If inputs and labels have different lengths.
        """
        self.input_ids = data_dict['input_ids']
        self.attention_mask = data_dict['attention_mask']
        # Convert labels and features to float: float targets are required by
        # the usual multi-label losses (e.g. BCEWithLogitsLoss).
        self.labels = data_dict['labels'].float()
        self.features = data_dict['features'].float()

        # Explicit check instead of `assert`: asserts are stripped when
        # Python runs with -O, which would silently skip this validation.
        if len(self.input_ids) != len(self.labels):
            raise ValueError("Mismatch between inputs and labels length")

    def __len__(self) -> int:
        """Return the total number of samples."""
        return len(self.input_ids)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Get a single sample from the dataset.

        Args:
            idx: Index of the sample to get

        Returns:
            Dictionary containing all features for the sample
        """
        return {
            'input_ids': self.input_ids[idx],
            'attention_mask': self.attention_mask[idx],
            'labels': self.labels[idx],
            'features': self.features[idx],
        }

    def get_num_labels(self) -> int:
        """Return the number of labels (width of the 2-D label matrix)."""
        return self.labels.shape[1]
# Real test with our preprocessed data
if __name__ == "__main__":
    # Import our preprocessor (repo root must be on sys.path for the
    # `scripts.*` package import below to resolve).
    import sys
    sys.path.append("../../")  # Add root to path
    from scripts.data_processing.data_preparation import AdvancedNarrativeProcessor

    # Run the full preprocessing pipeline on the raw annotation files.
    processor = AdvancedNarrativeProcessor(
        annotations_file="../../data/subtask-2-annotations.txt",
        raw_dir="../../data/raw",
    )
    processed_data = processor.load_and_process_data()

    # Wrap each split in a NarrativeDataset.
    datasets = {split: NarrativeDataset(processed_data[split]) for split in ("train", "val")}
    train_dataset = datasets["train"]
    val_dataset = datasets["val"]

    # Print information about the datasets
    print("\n=== Dataset Statistics ===")
    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")
    print(f"Number of labels: {train_dataset.get_num_labels()}")

    # Inspect the tensor shapes of a single training sample.
    sample = train_dataset[0]
    print("\n=== Sample Details ===")
    for label, key in (
        ("Input IDs", "input_ids"),
        ("Attention mask", "attention_mask"),
        ("Labels", "labels"),
        ("Features", "features"),
    ):
        print(f"{label} shape: {sample[key].shape}")