import json
import logging
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
from sklearn.metrics import f1_score, precision_score, recall_score
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import get_linear_schedule_with_warmup


class NarrativeTrainer:
    """Enhanced trainer with detailed metrics and optimizations."""

    def __init__(self, model, train_dataset, val_dataset, config):
        # Setup basics
        self.setup_logging()
        self.logger = logging.getLogger(__name__)

        # Store config first
        self.config = config

        # Setup device
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.logger.info(f"Using device: {self.device}")

        # Clear GPU cache if using CUDA
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Initialize model and datasets
        self.model = model.to(self.device)
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset

        # Initialize training state
        self.current_epoch = 0
        self.global_step = 0
        self.best_val_f1 = 0.0

        # Gradient scaler for mixed-precision training (None when fp16 is off)
        self.scaler = GradScaler() if self.config.fp16 else None

        # Setup training components
        self.setup_training()

        # Setup output directory
        self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.output_dir = Path(config.output_dir) / self.timestamp
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Save config and initialize history
        self.save_config()
        self.history = {
            'train_loss': [],
            'val_loss': [],
            'metrics': [],
            'thresholds': []
        }

    def setup_logging(self):
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )

    def calculate_class_weights(self):
        """Calculate positive-class weights for imbalanced multi-label data."""
        pos_counts = self.train_dataset.labels.sum(dim=0)
        neg_counts = len(self.train_dataset) - pos_counts
        # Guard against labels with no positive examples (avoids division by zero)
        pos_weight = (neg_counts / pos_counts.clamp(min=1)) * self.config.pos_weight_multiplier
        return torch.clamp(pos_weight, min=1.0, max=50.0).to(self.device)

    def setup_training(self):
        """Initialize dataloaders, loss, optimizer, scheduler, and thresholds."""
        # Create dataloaders
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.config.batch_size,
            shuffle=True,
            num_workers=4,
            pin_memory=True
        )
        self.val_loader = DataLoader(
            self.val_dataset,
            batch_size=self.config.batch_size,
            num_workers=4,
            pin_memory=True
        )

        # Weighted loss to compensate for class imbalance
        pos_weight = self.calculate_class_weights()
        self.criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)

        # Optimizer
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay
        )

        # Linear warmup schedule, counted in optimizer updates (not raw batches)
        num_update_steps_per_epoch = len(self.train_loader) // self.config.gradient_accumulation_steps
        num_training_steps = num_update_steps_per_epoch * self.config.num_epochs
        num_warmup_steps = int(num_training_steps * self.config.warmup_ratio)
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps
        )

        # Per-label decision thresholds, initialized to 0.5
        self.label_thresholds = torch.full(
            (self.train_dataset.get_num_labels(),), 0.5, device=self.device
        )

    def save_config(self):
        config_dict = {k: str(v) for k, v in vars(self.config).items()}
        config_path = self.output_dir / 'config.json'
        with open(config_path, 'w') as f:
            json.dump(config_dict, f, indent=4)

    def find_optimal_thresholds(self, val_outputs, val_labels):
        """Sweep a threshold grid per label and keep the best-F1 threshold."""
        outputs = torch.sigmoid(val_outputs).cpu().numpy()
        labels = val_labels.cpu().numpy()
        thresholds = []
        for i in range(labels.shape[1]):
            best_f1 = 0.0
            best_threshold = 0.5
            if labels[:, i].sum() > 0:  # Only tune labels with positive samples
                for threshold in np.arange(0.1, 0.9, 0.05):
                    preds = (outputs[:, i] > threshold).astype(int)
                    f1 = f1_score(labels[:, i], preds, zero_division=0)
                    if f1 > best_f1:
                        best_f1 = f1
                        best_threshold = threshold
            thresholds.append(best_threshold)
        return torch.tensor(thresholds).to(self.device)

    def calculate_detailed_metrics(self, all_labels, all_preds, all_probs=None):
        """Calculate micro/macro/weighted and per-class evaluation metrics."""
        metrics = {}
        # Aggregate metrics (cast to float so the history stays JSON-serializable)
        for average in ('micro', 'macro', 'weighted'):
            metrics[average] = {
                'precision': float(precision_score(all_labels, all_preds, average=average, zero_division=0)),
                'recall': float(recall_score(all_labels, all_preds, average=average, zero_division=0)),
                'f1': float(f1_score(all_labels, all_preds, average=average, zero_division=0))
            }
        # Per-class metrics
        per_class_metrics = {}
        precisions = precision_score(all_labels, all_preds, average=None, zero_division=0)
        recalls = recall_score(all_labels, all_preds, average=None, zero_division=0)
        f1s = f1_score(all_labels, all_preds, average=None, zero_division=0)
        for i in range(len(f1s)):
            per_class_metrics[f'class_{i}'] = {
                'precision': float(precisions[i]),
                'recall': float(recalls[i]),
                'f1': float(f1s[i]),
                'support': int(all_labels[:, i].sum())
            }
        metrics['per_class'] = per_class_metrics
        return metrics

    def train_epoch(self):
        """Train for one epoch with gradient accumulation and optional AMP."""
        self.model.train()
        total_loss = 0.0
        self.optimizer.zero_grad()

        pbar = tqdm(
            enumerate(self.train_loader),
            total=len(self.train_loader),
            desc=f'Epoch {self.current_epoch + 1}/{self.config.num_epochs}'
        )
        for step, batch in pbar:
            batch = {k: v.to(self.device, non_blocking=True) for k, v in batch.items()}

            # Forward pass under autocast when mixed precision is enabled
            with autocast(enabled=self.config.fp16):
                outputs = self.model(
                    input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask'],
                    features=batch['features']
                )
                loss = self.criterion(outputs, batch['labels'])
                loss = loss / self.config.gradient_accumulation_steps

            # Backward pass, scaled when fp16 is enabled
            if self.config.fp16:
                self.scaler.scale(loss).backward()
            else:
                loss.backward()

            # Step the optimizer only every `gradient_accumulation_steps` batches
            if (step + 1) % self.config.gradient_accumulation_steps == 0:
                if self.config.fp16:
                    self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(),
                    self.config.max_grad_norm
                )
                if self.config.fp16:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()

            total_loss += loss.item() * self.config.gradient_accumulation_steps
            avg_loss = total_loss / (step + 1)
            pbar.set_postfix({'loss': f'{avg_loss:.4f}'})

            self.global_step += 1
            if self.global_step % self.config.eval_steps == 0:
                self.evaluate()
                self.model.train()  # evaluate() switches to eval mode; switch back

            if step % 10 == 0:
                torch.cuda.empty_cache()
            del outputs
            del loss

        return total_loss / len(self.train_loader)

    @torch.no_grad()
    def evaluate(self):
        """Evaluate model with detailed metrics."""
        self.model.eval()
        total_loss = 0.0
        all_outputs, all_labels = [], []

        for batch in tqdm(self.val_loader, desc="Evaluating"):
            batch = {k: v.to(self.device, non_blocking=True) for k, v in batch.items()}
            with autocast(enabled=self.config.fp16):
                outputs = self.model(
                    input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask'],
                    features=batch['features']
                )
                loss = self.criterion(outputs, batch['labels'])
            total_loss += loss.item()
            all_outputs.append(outputs.float().cpu())
            all_labels.append(batch['labels'].cpu())
            del outputs
            del loss
        torch.cuda.empty_cache()

        all_outputs = torch.cat(all_outputs, dim=0)
        all_labels = torch.cat(all_labels, dim=0)

        # Periodically re-tune per-label thresholds on the validation outputs
        if self.global_step % (self.config.eval_steps * 2) == 0:
            self.label_thresholds = self.find_optimal_thresholds(all_outputs, all_labels)

        all_probs = torch.sigmoid(all_outputs).numpy()
        all_preds = (all_probs > self.label_thresholds.cpu().unsqueeze(0).numpy()).astype(int)
        all_labels = all_labels.numpy()

        metrics = self.calculate_detailed_metrics(all_labels, all_preds, all_probs)
        metrics['loss'] = total_loss / len(self.val_loader)

        self.logger.info(f"Step {self.global_step} - Validation metrics:")
        self.logger.info(f"Loss: {metrics['loss']:.4f}")
        self.logger.info(f"Micro F1: {metrics['micro']['f1']:.4f}")
        self.logger.info(f"Macro F1: {metrics['macro']['f1']:.4f}")

        if metrics['micro']['f1'] > self.best_val_f1:
            self.best_val_f1 = metrics['micro']['f1']
            self.save_model('best_model.pt', metrics)

        return metrics

    def save_model(self, filename: str, metrics: dict = None):
        """Save a full training checkpoint."""
        save_path = self.output_dir / filename
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            # The scaler only exists when fp16 is enabled
            'scaler_state_dict': self.scaler.state_dict() if self.scaler is not None else None,
            'epoch': self.current_epoch,
            'global_step': self.global_step,
            'best_val_f1': self.best_val_f1,
            'metrics': metrics,
            'thresholds': self.label_thresholds
        }, save_path)
        self.logger.info(f"Model saved to {save_path}")

    def train(self):
        """Run the complete training loop."""
        self.logger.info("Starting training...")
        try:
            for epoch in range(self.config.num_epochs):
                self.current_epoch = epoch
                self.logger.info(f"Starting epoch {epoch + 1}/{self.config.num_epochs}")

                train_loss = self.train_epoch()
                self.history['train_loss'].append(train_loss)

                val_metrics = self.evaluate()
                self.history['val_loss'].append(val_metrics['loss'])
                self.history['metrics'].append(val_metrics)
                self.history['thresholds'].append(self.label_thresholds.cpu().tolist())

                self.save_model(f'checkpoint_epoch_{epoch + 1}.pt', val_metrics)

                history_path = self.output_dir / 'history.json'
                with open(history_path, 'w') as f:
                    json.dump(self.history, f, indent=4)

                self.logger.info(f"Epoch {epoch + 1} completed. Train loss: {train_loss:.4f}")

            self.logger.info("Training completed successfully!")
            return self.history
        except Exception as e:
            self.logger.error(f"Training failed with error: {str(e)}")
            raise
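

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original trainer). Everything below is a
# hypothetical illustration of the interface the trainer assumes: a model
# whose forward(input_ids, attention_mask, features) returns raw logits of
# shape (batch, num_labels); datasets that expose a `labels` tensor and a
# `get_num_labels()` method and whose items are dicts with 'input_ids',
# 'attention_mask', 'features', and 'labels' keys; and a config object with
# the attributes read throughout this file. `TrainingConfig` is a stand-in
# listing those attributes with plausible, not authoritative, defaults.
# ---------------------------------------------------------------------------
from dataclasses import dataclass


@dataclass
class TrainingConfig:
    output_dir: str = 'outputs'
    batch_size: int = 16
    learning_rate: float = 2e-5
    weight_decay: float = 0.01
    num_epochs: int = 3
    gradient_accumulation_steps: int = 2
    warmup_ratio: float = 0.1
    max_grad_norm: float = 1.0
    pos_weight_multiplier: float = 1.0
    fp16: bool = torch.cuda.is_available()
    eval_steps: int = 200


# Example (model, train_dataset, and val_dataset assumed to follow the
# interface described above):
#
#     config = TrainingConfig()
#     trainer = NarrativeTrainer(model, train_dataset, val_dataset, config)
#     history = trainer.train()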