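"""Entry point for training the narrative classifier.

Loads the annotated dataset via AdvancedNarrativeProcessor, wraps the
resulting train/val splits in NarrativeDataset objects, and fine-tunes a
NarrativeClassifier through NarrativeTrainer. Run this script from the
project root so the relative ./data and ./output paths resolve.
"""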
import logging
import subprocess
import sys
from pathlib import Path

import spacy
from transformers import set_seed

# Make the project root (the directory containing scripts/) importable,
# so the scripts.* package imports below resolve
sys.path.append(str(Path(__file__).resolve().parent))  # Adjust if this file is not at the root
from scripts.models.model import NarrativeClassifier
from scripts.models.dataset import NarrativeDataset
from scripts.config.config import TrainingConfig
from scripts.data_processing.data_preparation import AdvancedNarrativeProcessor
from scripts.training.trainer import NarrativeTrainer

def main():
    # Set up logging
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)
    logger.info("Initializing training process...")

    # Ensure the SpaCy model is available, downloading it on first run
    try:
        spacy.load("en_core_web_sm")
    except OSError:
        logger.info("Downloading SpaCy model 'en_core_web_sm'...")
        subprocess.run(
            [sys.executable, "-m", "spacy", "download", "en_core_web_sm"],
            check=True,
        )

    # Set a random seed for reproducibility
    set_seed(42)
    
    # Load and process the dataset
    annotations_file = "./data/subtask-2-annotations.txt"  # Adjust path as needed
    raw_dir = "./data/raw"  # Adjust path as needed
    logger.info("Loading and processing dataset...")
    
    processor = AdvancedNarrativeProcessor(
        annotations_file=annotations_file,
        raw_dir=raw_dir
    )
    processed_data = processor.load_and_process_data()
    
    # Wrap the processor's pre-split train/val data in dataset objects
    train_dataset = NarrativeDataset(processed_data['train'])
    val_dataset = NarrativeDataset(processed_data['val'])
    logger.info(f"Loaded dataset with {len(train_dataset)} training samples and {len(val_dataset)} validation samples.")
    
    # Initialize the model
    logger.info("Initializing the model...")
    model = NarrativeClassifier(num_labels=train_dataset.get_num_labels())
    
    # Define training configuration
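    # (The values below are typical fine-tuning settings for a pretrained
    # transformer; eval_steps/save_steps are assumed to count optimizer steps,
    # depending on how NarrativeTrainer interprets them.)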
    config = TrainingConfig(
        output_dir=Path("./output"),  # Save outputs in this directory
        num_epochs=5,
        batch_size=16,
        learning_rate=2e-5,
        warmup_ratio=0.1,
        weight_decay=0.01,
        max_grad_norm=1.0,
        eval_steps=100,
        save_steps=100
    )
    logger.info(f"Training configuration: {config}")
    
    # Initialize the trainer
    trainer = NarrativeTrainer(
        model=model,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        config=config
    )
    
    # Start the training process
    logger.info("Starting the training process...")
    trainer.train()
    logger.info("Training completed successfully!")

if __name__ == "__main__":
    main()