import torch
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup
from tqdm import tqdm
import logging
from pathlib import Path
import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score
import json
from datetime import datetime


class NarrativeTrainer:
    """
    Comprehensive trainer for multi-label narrative classification with GPU support.
    """

    def __init__(
        self,
        model,
        train_dataset,
        val_dataset,
        config,
    ):
        self.setup_logging()
        self.logger = logging.getLogger(__name__)

        # Set device
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.logger.info(f"Using device: {self.device}")

        # Initialize model and components
        self.model = model.to(self.device)
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.config = config

        # Training state
        self.current_epoch = 0
        self.global_step = 0
        self.best_val_f1 = 0.0

        self.setup_training()

        # Timestamped output directory for config, checkpoints, and metric history
        self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.output_dir = Path(config.output_dir) / self.timestamp
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.save_config()

        self.history = {
            'train_loss': [],
            'val_loss': [],
            'val_f1': [],
            'val_precision': [],
            'val_recall': []
        }
    def setup_logging(self):
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )
    def setup_training(self):
        """Initialize dataloaders, optimizer, and scheduler."""
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.config.batch_size,
            shuffle=True,
            num_workers=4
        )
        self.val_loader = DataLoader(
            self.val_dataset,
            batch_size=self.config.batch_size,
            num_workers=4
        )
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay
        )
        num_training_steps = len(self.train_loader) * self.config.num_epochs
        num_warmup_steps = int(num_training_steps * self.config.warmup_ratio)
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps
        )
        self.criterion = torch.nn.BCEWithLogitsLoss()
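
        # Note (assumption based on how batches are consumed in train_epoch/evaluate
        # below): each dataset item is a dict of tensors with keys 'input_ids',
        # 'attention_mask', 'features', and 'labels', where 'labels' is a multi-hot
        # float vector compatible with BCEWithLogitsLoss for multi-label training.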
    def save_config(self):
        """Save training configuration."""
        config_dict = {k: str(v) for k, v in vars(self.config).items()}
        config_path = self.output_dir / 'config.json'
        with open(config_path, 'w') as f:
            json.dump(config_dict, f, indent=4)
    def train_epoch(self):
        """Train model for one epoch."""
        self.model.train()
        total_loss = 0
        pbar = tqdm(self.train_loader, desc=f'Epoch {self.current_epoch + 1}/{self.config.num_epochs}')

        for batch in pbar:
            batch = {k: v.to(self.device) for k, v in batch.items()}

            self.optimizer.zero_grad()
            outputs = self.model(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                features=batch['features']
            )
            loss = self.criterion(outputs, batch['labels'])

            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
            self.optimizer.step()
            self.scheduler.step()

            total_loss += loss.item()
            pbar.set_postfix({'loss': total_loss / (pbar.n + 1)})
            self.global_step += 1

            # Periodic mid-epoch evaluation; evaluate() puts the model into eval
            # mode, so restore training mode before continuing.
            if self.global_step % self.config.eval_steps == 0:
                self.evaluate()
                self.model.train()

        return total_loss / len(self.train_loader)
    def evaluate(self):
        """Evaluate model performance on the validation set."""
        self.model.eval()
        total_loss = 0
        all_preds, all_labels = [], []

        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc="Evaluating"):
                batch = {k: v.to(self.device) for k, v in batch.items()}
                outputs = self.model(
                    input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask'],
                    features=batch['features']
                )
                loss = self.criterion(outputs, batch['labels'])
                total_loss += loss.item()

                # Multi-label predictions: sigmoid probabilities thresholded at 0.5
                preds = (torch.sigmoid(outputs) > 0.5).long()
                all_preds.append(preds.cpu().numpy())
                all_labels.append(batch['labels'].cpu().numpy())

        all_preds = np.concatenate(all_preds, axis=0)
        all_labels = np.concatenate(all_labels, axis=0)

        # Micro-averaged metrics over all label positions
        metrics = {
            'loss': total_loss / len(self.val_loader),
            'f1': f1_score(all_labels, all_preds, average='micro', zero_division=0),
            'precision': precision_score(all_labels, all_preds, average='micro', zero_division=0),
            'recall': recall_score(all_labels, all_preds, average='micro', zero_division=0)
        }
        self.logger.info(f"Step {self.global_step} - Validation metrics: {metrics}")

        # Keep the best checkpoint by validation micro-F1
        if metrics['f1'] > self.best_val_f1:
            self.best_val_f1 = metrics['f1']
            self.save_model('best_model.pt', metrics)

        return metrics
    def save_model(self, filename: str, metrics: dict = None):
        """Save a full training checkpoint (model, optimizer, scheduler, and state)."""
        save_path = self.output_dir / filename
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'epoch': self.current_epoch,
            'global_step': self.global_step,
            'best_val_f1': self.best_val_f1,
            'metrics': metrics
        }, save_path)
        self.logger.info(f"Model saved to {save_path}")
    def train(self):
        """Run training for all epochs."""
        self.logger.info("Starting training...")

        for epoch in range(self.config.num_epochs):
            self.current_epoch = epoch

            train_loss = self.train_epoch()
            self.history['train_loss'].append(train_loss)

            val_metrics = self.evaluate()
            self.history['val_loss'].append(val_metrics['loss'])
            self.history['val_f1'].append(val_metrics['f1'])
            self.history['val_precision'].append(val_metrics['precision'])
            self.history['val_recall'].append(val_metrics['recall'])

            self.save_model(f'checkpoint_epoch_{epoch+1}.pt', val_metrics)

        history_path = self.output_dir / 'history.json'
        with open(history_path, 'w') as f:
            json.dump(self.history, f, indent=4)

        self.logger.info("Training completed!")
        return self.history


if __name__ == "__main__":
    import sys
    sys.path.append("../../")

    from scripts.models.model import NarrativeClassifier
    from scripts.models.dataset import NarrativeDataset
    from scripts.config.config import TrainingConfig
    from scripts.data_processing.data_preparation import AdvancedNarrativeProcessor

    # Initialize training configuration
    config = TrainingConfig(
        output_dir=Path("./output"),
        num_epochs=5,
        batch_size=32,
        learning_rate=5e-5,
        weight_decay=0.01,
        warmup_ratio=0.1,
        max_grad_norm=1.0,
        eval_steps=100
    )

    # Load and process data
    processor = AdvancedNarrativeProcessor(
        annotations_file="../../data/subtask-2-annotations.txt",
        raw_dir="../../data/raw"
    )
    processed_data = processor.load_and_process_data()

    # Create datasets
    train_dataset = NarrativeDataset(processed_data['train'])
    val_dataset = NarrativeDataset(processed_data['val'])

    # Initialize model
    model = NarrativeClassifier(num_labels=train_dataset.get_num_labels())

    # Initialize trainer
    trainer = NarrativeTrainer(
        model=model,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        config=config
    )

    # Start full training
    print("\n=== Starting Training ===")
    trainer.train()
    print("\nTraining completed successfully!")