"""Entry point that prepares the data and trains the narrative classifier.

Running this file as a script loads the annotated dataset, builds the
training and validation datasets, constructs the model, and hands
everything to ``NarrativeTrainer`` for training.
"""

import logging
import subprocess
import sys
from pathlib import Path

from transformers import set_seed

# Import the necessary modules from your project
sys.path.append("./scripts")  # Adjust path if needed
from scripts.models.model import NarrativeClassifier
from scripts.models.dataset import NarrativeDataset
from scripts.config.config import TrainingConfig
from scripts.data_processing.data_preparation import AdvancedNarrativeProcessor
from scripts.training.trainer import NarrativeTrainer

logger = logging.getLogger(__name__)

SPACY_MODEL = "en_core_web_sm"


def _ensure_spacy_model(model_name: str = SPACY_MODEL) -> None:
    """Make sure the SpaCy model *model_name* is installed, downloading it if not.

    The download is best-effort: a failed download is logged as a warning
    and not raised, so the caller decides what to do when the model is
    still unavailable later on.
    """
    # Imported lazily so that merely importing this module does not
    # require SpaCy to be installed.
    import spacy

    try:
        spacy.load(model_name)
    except OSError:
        logger.info("Downloading SpaCy model '%s'...", model_name)
        # Use the running interpreter (not whatever "python" is on PATH) so
        # the model lands in the active environment, and pass an argument
        # list so no shell is involved.
        result = subprocess.run(
            [sys.executable, "-m", "spacy", "download", model_name],
            check=False,
        )
        if result.returncode != 0:
            logger.warning(
                "SpaCy model download exited with status %d",
                result.returncode,
            )


def main() -> None:
    """Run the full training pipeline with the hard-coded default settings.

    Steps, in order: configure logging, ensure the SpaCy model is
    available, seed the random number generators, load and process the
    dataset, build the model and training configuration, and run the
    trainer.
    """
    # Set up logging
    logging.basicConfig(level=logging.INFO)
    logger.info("Initializing training process...")

    # Download and load SpaCy model dynamically
    _ensure_spacy_model()

    # Set a random seed for reproducibility
    set_seed(42)

    # Load and process the dataset
    annotations_file = "./data/subtask-2-annotations.txt"  # Adjust path as needed
    raw_dir = "./data/raw"  # Adjust path as needed

    logger.info("Loading and processing dataset...")
    processor = AdvancedNarrativeProcessor(
        annotations_file=annotations_file,
        raw_dir=raw_dir,
    )
    processed_data = processor.load_and_process_data()

    # Split processed data into training and validation sets
    train_dataset = NarrativeDataset(processed_data['train'])
    val_dataset = NarrativeDataset(processed_data['val'])
    logger.info(
        "Loaded dataset with %d training samples and %d validation samples.",
        len(train_dataset),
        len(val_dataset),
    )

    # Initialize the model
    logger.info("Initializing the model...")
    model = NarrativeClassifier(num_labels=train_dataset.get_num_labels())

    # Define training configuration
    config = TrainingConfig(
        output_dir=Path("./output"),  # Save outputs in this directory
        num_epochs=5,
        batch_size=16,
        learning_rate=2e-5,
        warmup_ratio=0.1,
        weight_decay=0.01,
        max_grad_norm=1.0,
        eval_steps=100,
        save_steps=100,
    )
    logger.info("Training configuration: %s", config)

    # Initialize the trainer
    trainer = NarrativeTrainer(
        model=model,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        config=config,
    )

    # Start the training process
    logger.info("Starting the training process...")
    trainer.train()
    logger.info("Training completed successfully!")


if __name__ == "__main__":
    main()