Mohammaderfan koupaei
Add application file
fb2cd67
import torch
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup
from tqdm import tqdm
import logging
from pathlib import Path
import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score
import json
from datetime import datetime
class NarrativeTrainer:
    """
    Comprehensive trainer for narrative classification with GPU support.

    The model is called as ``model(input_ids=..., attention_mask=...,
    features=...)`` and its output is treated as one logit per label.
    Training uses ``BCEWithLogitsLoss`` (multi-label), AdamW and a linear
    warmup/decay schedule. The config, checkpoints and metric history are
    written to ``<config.output_dir>/<timestamp>/``.

    Required ``config`` attributes: ``output_dir``, ``batch_size``,
    ``learning_rate``, ``weight_decay``, ``num_epochs``, ``warmup_ratio``,
    ``max_grad_norm``.
    Optional ``config`` attributes:
        ``eval_steps``: mid-epoch evaluation interval in optimizer steps;
            0 or absent disables it.
        ``num_workers``: DataLoader worker count, default 4.
        ``threshold``: sigmoid decision threshold, default 0.5.
    """

    def __init__(
        self,
        model,
        train_dataset,
        val_dataset,
        config,
    ):
        """
        Args:
            model: torch.nn.Module to train; it is moved to the selected device.
            train_dataset: Training dataset. Collated batches must be dicts with
                'input_ids', 'attention_mask', 'features' and 'labels'.
            val_dataset: Validation dataset with the same batch layout.
            config: Training configuration object (see the class docstring).
        """
        self.setup_logging()
        self.logger = logging.getLogger(__name__)

        # Prefer the GPU when one is available.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.logger.info(f"Using device: {self.device}")

        # Model, data and bookkeeping state.
        self.model = model.to(self.device)
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.config = config
        self.current_epoch = 0
        self.global_step = 0
        self.best_val_f1 = 0.0

        self.setup_training()

        # Every run writes into its own timestamped directory.
        self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.output_dir = Path(config.output_dir) / self.timestamp
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.save_config()

        # Per-epoch metric history, dumped to history.json.
        self.history = {
            'train_loss': [],
            'val_loss': [],
            'val_f1': [],
            'val_precision': [],
            'val_recall': [],
        }

    def setup_logging(self):
        """Configure root logging (no-op if the root logger is already configured)."""
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )

    def setup_training(self):
        """Initialize dataloaders, optimizer, scheduler and loss function."""
        num_workers = getattr(self.config, 'num_workers', 4)
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.config.batch_size,
            shuffle=True,
            num_workers=num_workers,
        )
        self.val_loader = DataLoader(
            self.val_dataset,
            batch_size=self.config.batch_size,
            num_workers=num_workers,
        )
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay
        )
        # The scheduler steps once per batch, so its horizon is batches * epochs.
        num_training_steps = len(self.train_loader) * self.config.num_epochs
        num_warmup_steps = int(num_training_steps * self.config.warmup_ratio)
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps
        )
        # Multi-label objective: an independent sigmoid + BCE per label.
        self.criterion = torch.nn.BCEWithLogitsLoss()

    def save_config(self):
        """Save the training configuration (all values stringified) to config.json."""
        config_dict = {k: str(v) for k, v in vars(self.config).items()}
        config_path = self.output_dir / 'config.json'
        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(config_dict, f, indent=4)

    def _move_to_device(self, batch):
        """Return a copy of ``batch`` with every tensor value moved to ``self.device``.

        Non-tensor values (e.g. string ids) are passed through unchanged
        instead of failing on ``.to``.
        """
        return {
            key: value.to(self.device) if isinstance(value, torch.Tensor) else value
            for key, value in batch.items()
        }

    def _forward(self, batch):
        """Run the model on a device-resident ``batch`` and return ``(logits, loss)``."""
        outputs = self.model(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            features=batch['features'],
        )
        # BCEWithLogitsLoss requires floating-point targets matching the logits;
        # the cast lets integer multi-hot label tensors work as well.
        loss = self.criterion(outputs, batch['labels'].to(outputs.dtype))
        return outputs, loss

    def train_epoch(self):
        """Train the model for one epoch and return the mean training loss.

        When ``config.eval_steps`` is set and non-zero, ``evaluate()`` also runs
        every ``eval_steps`` optimizer steps.
        """
        self.model.train()
        total_loss = 0.0
        eval_steps = getattr(self.config, 'eval_steps', 0)
        pbar = tqdm(self.train_loader, desc=f'Epoch {self.current_epoch + 1}/{self.config.num_epochs}')
        for step, batch in enumerate(pbar, start=1):
            batch = self._move_to_device(batch)
            self.optimizer.zero_grad()
            _, loss = self._forward(batch)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
            self.optimizer.step()
            self.scheduler.step()

            total_loss += loss.item()
            # Count batches ourselves: tqdm refreshes ``pbar.n`` only on display
            # updates, so it is not a reliable batch counter inside the loop.
            pbar.set_postfix({'loss': total_loss / step})

            self.global_step += 1
            if eval_steps and self.global_step % eval_steps == 0:
                # evaluate() restores train mode before it returns.
                self.evaluate()
        # max(..., 1) guards against an empty training loader.
        return total_loss / max(len(self.train_loader), 1)

    @torch.no_grad()
    def evaluate(self):
        """Evaluate on the validation set and return a metrics dict.

        Keys: 'loss' (mean batch loss), plus 'f1', 'precision' and 'recall',
        all micro-averaged over labels. Predictions are
        ``sigmoid(logits) > threshold``. When F1 beats the best value seen so
        far, the model is saved as ``best_model.pt``.

        The model's previous train/eval mode is restored afterwards, so a
        mid-epoch call does not leave dropout etc. disabled for the remaining
        training steps.

        Raises:
            ValueError: if the validation loader yields no batches.
        """
        if len(self.val_loader) == 0:
            raise ValueError("Validation dataset is empty; cannot compute metrics.")
        threshold = getattr(self.config, 'threshold', 0.5)

        was_training = self.model.training
        self.model.eval()
        try:
            total_loss = 0.0
            all_preds, all_labels = [], []
            for batch in tqdm(self.val_loader, desc="Evaluating"):
                batch = self._move_to_device(batch)
                outputs, loss = self._forward(batch)
                total_loss += loss.item()
                preds = torch.sigmoid(outputs) > threshold
                all_preds.append(preds.cpu().numpy())
                all_labels.append(batch['labels'].cpu().numpy())
        finally:
            self.model.train(was_training)

        all_preds = np.concatenate(all_preds, axis=0)
        all_labels = np.concatenate(all_labels, axis=0)
        # zero_division=0 keeps the default value (0) but silences sklearn's
        # warning when a run has no positive predictions/labels.
        metrics = {
            'loss': total_loss / len(self.val_loader),
            'f1': float(f1_score(all_labels, all_preds, average='micro', zero_division=0)),
            'precision': float(precision_score(all_labels, all_preds, average='micro', zero_division=0)),
            'recall': float(recall_score(all_labels, all_preds, average='micro', zero_division=0)),
        }
        self.logger.info(f"Step {self.global_step} - Validation metrics: {metrics}")
        if metrics['f1'] > self.best_val_f1:
            self.best_val_f1 = metrics['f1']
            self.save_model('best_model.pt', metrics)
        return metrics

    def save_model(self, filename: str, metrics: dict = None):
        """Save model/optimizer/scheduler state plus progress counters.

        Args:
            filename: File name inside the run's output directory.
            metrics: Optional metrics dict stored alongside the weights.
        """
        save_path = self.output_dir / filename
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'epoch': self.current_epoch,
            'global_step': self.global_step,
            'best_val_f1': self.best_val_f1,
            'metrics': metrics
        }, save_path)
        self.logger.info(f"Model saved to {save_path}")

    def _save_history(self):
        """Write ``self.history`` to history.json in the run's output directory."""
        history_path = self.output_dir / 'history.json'
        with open(history_path, 'w', encoding='utf-8') as f:
            json.dump(self.history, f, indent=4)

    def train(self):
        """Run training for all epochs and return the metric history.

        After each epoch, this method:
        - evaluates the model;
        - appends the metrics to ``self.history``;
        - saves ``checkpoint_epoch_<n>.pt``;
        - rewrites history.json, so progress survives an interrupted run.
        """
        self.logger.info("Starting training...")
        for epoch in range(self.config.num_epochs):
            self.current_epoch = epoch
            train_loss = self.train_epoch()
            self.history['train_loss'].append(train_loss)

            val_metrics = self.evaluate()
            self.history['val_loss'].append(val_metrics['loss'])
            self.history['val_f1'].append(val_metrics['f1'])
            self.history['val_precision'].append(val_metrics['precision'])
            self.history['val_recall'].append(val_metrics['recall'])

            self.save_model(f'checkpoint_epoch_{epoch+1}.pt', val_metrics)
            self._save_history()
        # Final write also covers num_epochs == 0.
        self._save_history()
        self.logger.info("Training completed!")
        return self.history
def main():
    """Train a NarrativeClassifier end to end with the default configuration.

    Paths are resolved relative to this file rather than the current working
    directory: the project root is two directories up, as the original
    "../../" paths assumed. The script therefore works no matter where it is
    launched from.
    """
    import sys

    script_dir = Path(__file__).resolve().parent
    project_root = script_dir.parents[1]  # equivalent of "../../" from this file

    # Make the top-level ``scripts`` package importable.
    sys.path.append(str(project_root))
    from scripts.models.model import NarrativeClassifier
    from scripts.models.dataset import NarrativeDataset
    from scripts.config.config import TrainingConfig
    from scripts.data_processing.data_preparation import AdvancedNarrativeProcessor

    # Initialize training configuration
    config = TrainingConfig(
        output_dir=script_dir / "output",
        num_epochs=5,
        batch_size=32,
        learning_rate=5e-5,
        weight_decay=0.01,
        warmup_ratio=0.1,
        max_grad_norm=1.0,
        eval_steps=100
    )

    # Load and process data (paths passed as str, as in the original call)
    data_dir = project_root / "data"
    processor = AdvancedNarrativeProcessor(
        annotations_file=str(data_dir / "subtask-2-annotations.txt"),
        raw_dir=str(data_dir / "raw")
    )
    processed_data = processor.load_and_process_data()

    # Create datasets
    train_dataset = NarrativeDataset(processed_data['train'])
    val_dataset = NarrativeDataset(processed_data['val'])

    # Initialize model
    model = NarrativeClassifier(num_labels=train_dataset.get_num_labels())

    # Initialize trainer
    trainer = NarrativeTrainer(
        model=model,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        config=config
    )

    # Start full training
    print("\n=== Starting Training ===")
    trainer.train()
    print("\nTraining completed successfully!")


if __name__ == "__main__":
    main()