# Spaces: Runtime error
import sys | |
import logging | |
from pathlib import Path | |
from transformers import set_seed | |
# Import the necessary modules from your project | |
sys.path.append("./scripts") # Adjust path if needed | |
from scripts.models.model import NarrativeClassifier | |
from scripts.models.dataset import NarrativeDataset | |
from scripts.config.config import TrainingConfig | |
from scripts.data_processing.data_preparation import AdvancedNarrativeProcessor | |
from scripts.training.trainer import NarrativeTrainer | |
def _ensure_spacy_model(logger, model_name="en_core_web_sm"):
    """Make sure the SpaCy pipeline *model_name* can be loaded, downloading it if needed.

    The model is loaded once as a probe. If that raises ``OSError`` the model
    is downloaded with ``<current interpreter> -m spacy download``.

    Args:
        logger: Logger used for progress and error messages.
        model_name: Name of the SpaCy pipeline package to check.
    """
    # Local imports keep this helper self-contained; the original code also
    # imported its download dependencies inside the function body.
    import subprocess

    import spacy

    try:
        spacy.load(model_name)
    except OSError:
        logger.info("Downloading SpaCy model '%s'...", model_name)
        # Use the running interpreter (not whatever "python" is on PATH) and
        # pass an argument list so no shell is involved.
        result = subprocess.run(
            [sys.executable, "-m", "spacy", "download", model_name],
            check=False,
        )
        if result.returncode != 0:
            # Stay best-effort like the original os.system call, but make the
            # failure visible instead of silently ignoring the exit status.
            logger.error(
                "SpaCy model download for '%s' exited with code %d",
                model_name,
                result.returncode,
            )


def main():
    """Run the narrative-classifier training pipeline from start to finish.

    Steps, in order: configure logging, make sure the SpaCy model is
    available, seed the RNGs, load and process the dataset, wrap the
    ``'train'`` and ``'val'`` splits in ``NarrativeDataset``, build the
    ``NarrativeClassifier``, assemble the ``TrainingConfig`` and call
    ``NarrativeTrainer.train()``.
    """
    # Set up logging (once; a second basicConfig call would be a no-op).
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)
    logger.info("Initializing training process...")

    # Download and load SpaCy model dynamically
    _ensure_spacy_model(logger)

    # Set a random seed for reproducibility
    set_seed(42)

    # Load and process the dataset
    annotations_file = "./data/subtask-2-annotations.txt"  # Adjust path as needed
    raw_dir = "./data/raw"  # Adjust path as needed
    logger.info("Loading and processing dataset...")
    processor = AdvancedNarrativeProcessor(
        annotations_file=annotations_file,
        raw_dir=raw_dir,
    )
    processed_data = processor.load_and_process_data()

    # Split processed data into training and validation sets
    train_dataset = NarrativeDataset(processed_data['train'])
    val_dataset = NarrativeDataset(processed_data['val'])
    logger.info(
        "Loaded dataset with %s training samples and %s validation samples.",
        len(train_dataset),
        len(val_dataset),
    )

    # Initialize the model; the label count is taken from the training split.
    logger.info("Initializing the model...")
    model = NarrativeClassifier(num_labels=train_dataset.get_num_labels())

    # Define training configuration
    config = TrainingConfig(
        output_dir=Path("./output"),  # Save outputs in this directory
        num_epochs=5,
        batch_size=16,
        learning_rate=2e-5,
        warmup_ratio=0.1,
        weight_decay=0.01,
        max_grad_norm=1.0,
        eval_steps=100,
        save_steps=100,
    )
    logger.info("Training configuration: %s", config)

    # Initialize the trainer
    trainer = NarrativeTrainer(
        model=model,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        config=config,
    )

    # Start the training process
    logger.info("Starting the training process...")
    trainer.train()
    logger.info("Training completed successfully!")
# Script entry point: run the training pipeline only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()