# Mohammaderfan koupaei
# second
# 660777d
import torch
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup
from tqdm import tqdm
import logging
from pathlib import Path
import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score
import json
from datetime import datetime
from torch.cuda.amp import autocast, GradScaler
class NarrativeTrainer:
    """Training/evaluation loop for a multi-label classifier.

    The trainer wires together data loaders, a class-weighted
    ``BCEWithLogitsLoss``, AdamW, a linear warmup/decay schedule, optional CUDA
    mixed precision, per-label decision thresholds and checkpointing.

    Attributes read from ``config``: ``fp16``, ``output_dir``,
    ``pos_weight_multiplier``, ``batch_size``, ``gradient_accumulation_steps``,
    ``learning_rate``, ``weight_decay``, ``num_epochs``, ``warmup_ratio``,
    ``max_grad_norm``, ``eval_steps`` and, optionally, ``num_workers``
    (defaults to 4).

    The training dataset must expose ``labels`` (summed along dim 0, so
    presumably a ``(num_samples, num_labels)`` tensor -- confirm against the
    dataset class), ``__len__`` and ``get_num_labels()``. Batches are dicts
    containing ``input_ids``, ``attention_mask``, ``features`` and ``labels``.
    """

    def __init__(
        self,
        model,
        train_dataset,
        val_dataset,
        config,
    ):
        """Build every training component and create the run directory.

        Args:
            model: Module called as ``model(input_ids=..., attention_mask=...,
                features=...)`` whose output is fed to ``BCEWithLogitsLoss``.
            train_dataset: Dataset used for optimisation and class weights.
            val_dataset: Dataset used by :meth:`evaluate`.
            config: Object carrying the hyper-parameters listed on the class.
        """
        # Setup basics
        self.setup_logging()
        self.logger = logging.getLogger(__name__)

        # Store config first: every later step reads from it.
        self.config = config

        # Setup device
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.logger.info(f"Using device: {self.device}")

        # Clear GPU cache if using CUDA
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Initialize model and components
        self.model = model.to(self.device)
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset

        # Initialize training state
        self.current_epoch = 0
        self.global_step = 0
        self.best_val_f1 = 0.0

        # Mixed precision is only meaningful on CUDA. The scaler doubles as the
        # "AMP enabled" flag for the rest of the class (None means disabled).
        if self.config.fp16 and self.device.type == 'cuda':
            self.scaler = torch.cuda.amp.GradScaler()
        else:
            if self.config.fp16:
                self.logger.warning(
                    "fp16 requested but CUDA is unavailable; training in full precision."
                )
            self.scaler = None

        # Setup training components
        self.setup_training()

        # Setup output directory (one sub-directory per run, keyed by start time)
        self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.output_dir = Path(config.output_dir) / self.timestamp
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Save config and initialize history
        self.save_config()
        self.history = {
            'train_loss': [],
            'val_loss': [],
            'metrics': [],
            'thresholds': []
        }

    @property
    def _amp_enabled(self):
        """Whether mixed precision is active (a gradient scaler exists)."""
        return self.scaler is not None

    def setup_logging(self):
        """Configure the root logger with a timestamped INFO-level format."""
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )

    def calculate_class_weights(self):
        """Compute per-label ``pos_weight`` values for imbalanced labels.

        Each weight is ``negatives / positives * config.pos_weight_multiplier``
        clamped to ``[1.0, 50.0]``.

        Returns:
            A float tensor with one weight per label, on ``self.device``.
        """
        pos_counts = self.train_dataset.labels.sum(dim=0).float()
        neg_counts = len(self.train_dataset) - pos_counts
        # A label with no positive sample would divide by zero and produce
        # inf (or NaN once multiplied by 0); treat it as having one positive.
        safe_pos = pos_counts.clamp(min=1.0)
        pos_weight = (neg_counts / safe_pos) * self.config.pos_weight_multiplier
        return torch.clamp(pos_weight, min=1.0, max=50.0).to(self.device)

    def setup_training(self):
        """Create data loaders, loss, optimizer, scheduler and thresholds."""
        # Pinned host memory only helps host-to-GPU copies.
        pin_memory = self.device.type == 'cuda'
        num_workers = getattr(self.config, 'num_workers', 4)

        # Create dataloaders
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.config.batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=pin_memory
        )
        self.val_loader = DataLoader(
            self.val_dataset,
            batch_size=self.config.batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory
        )

        # Setup loss function with class weights only
        self.criterion = torch.nn.BCEWithLogitsLoss(
            pos_weight=self.calculate_class_weights()
        )

        # Setup optimizer
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay
        )

        # Setup scheduler. Ceil division: train_epoch also applies the trailing
        # partial accumulation group, so it counts as one optimizer update.
        accum_steps = self.config.gradient_accumulation_steps
        num_update_steps_per_epoch = (len(self.train_loader) + accum_steps - 1) // accum_steps
        num_training_steps = num_update_steps_per_epoch * self.config.num_epochs
        num_warmup_steps = int(num_training_steps * self.config.warmup_ratio)
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps
        )

        # Initialize thresholds: 0.5 for every label until they are tuned.
        self.label_thresholds = torch.ones(self.train_dataset.get_num_labels()).to(self.device) * 0.5

    def save_config(self):
        """Write the stringified config attributes to ``config.json``."""
        config_dict = {k: str(v) for k, v in vars(self.config).items()}
        config_path = self.output_dir / 'config.json'
        with open(config_path, 'w') as f:
            json.dump(config_dict, f, indent=4)

    def find_optimal_thresholds(self, val_outputs, val_labels):
        """Grid-search the F1-maximising decision threshold of each label.

        Args:
            val_outputs: Logits, indexed as ``[sample, label]``.
            val_labels: Binary targets with the same layout.

        Returns:
            A float32 tensor of thresholds on ``self.device``. Labels without
            any positive sample keep the default 0.5.
        """
        probs = torch.sigmoid(val_outputs).cpu().numpy()
        labels = val_labels.cpu().numpy()
        candidates = np.arange(0.1, 0.9, 0.05)

        thresholds = []
        for i in range(labels.shape[1]):
            best_f1 = 0
            best_threshold = 0.5
            if labels[:, i].sum() > 0:  # Only if we have positive samples
                for threshold in candidates:
                    preds = (probs[:, i] > threshold).astype(int)
                    # zero_division=0 keeps the default fallback value but
                    # silences the warning when nothing is predicted positive.
                    f1 = f1_score(labels[:, i], preds, zero_division=0)
                    if f1 > best_f1:
                        best_f1 = f1
                        best_threshold = float(threshold)
            thresholds.append(best_threshold)

        # Explicit dtype: matches the float32 thresholds created in
        # setup_training instead of depending on list-content inference.
        return torch.tensor(thresholds, dtype=torch.float32).to(self.device)

    def calculate_detailed_metrics(self, all_labels, all_preds, all_probs=None):
        """Compute aggregate and per-label precision/recall/F1.

        Args:
            all_labels: Binary ground truth, indexed as ``[sample, label]``.
            all_preds: Binary predictions with the same layout.
            all_probs: Accepted for interface compatibility; currently unused.

        Returns:
            Dict with ``'micro'``, ``'macro'`` and ``'weighted'`` entries (each
            holding ``precision``/``recall``/``f1``) plus ``'per_class'``,
            which maps ``class_<i>`` to its scores and positive ``support``.
        """
        scorers = (
            ('precision', precision_score),
            ('recall', recall_score),
            ('f1', f1_score),
        )

        # Aggregate metrics
        metrics = {
            average: {
                name: scorer(all_labels, all_preds, average=average, zero_division=0)
                for name, scorer in scorers
            }
            for average in ('micro', 'macro', 'weighted')
        }

        # Per-class metrics
        per_label = {
            name: scorer(all_labels, all_preds, average=None, zero_division=0)
            for name, scorer in scorers
        }
        metrics['per_class'] = {
            f'class_{i}': {
                'precision': float(per_label['precision'][i]),
                'recall': float(per_label['recall'][i]),
                'f1': float(per_label['f1'][i]),
                'support': int(all_labels[:, i].sum())
            }
            for i in range(len(per_label['f1']))
        }
        return metrics

    def _optimizer_step(self):
        """Clip gradients, apply one optimizer/scheduler update, reset grads."""
        if self.scaler is not None:
            # Clipping must see the true gradient magnitudes, not scaled ones.
            self.scaler.unscale_(self.optimizer)
        torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            self.config.max_grad_norm
        )
        if self.scaler is not None:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()
        self.scheduler.step()
        self.optimizer.zero_grad()

    def train_epoch(self):
        """Train for one epoch with gradient accumulation.

        Runs :meth:`evaluate` whenever ``global_step`` (counted per batch) is a
        multiple of ``config.eval_steps``.

        Returns:
            Mean un-scaled training loss over the epoch's batches.
        """
        self.model.train()
        accum_steps = self.config.gradient_accumulation_steps
        num_batches = len(self.train_loader)
        total_loss = 0.0
        self.optimizer.zero_grad()

        pbar = tqdm(enumerate(self.train_loader),
                    total=num_batches,
                    desc=f'Epoch {self.current_epoch + 1}/{self.config.num_epochs}')
        for step, batch in pbar:
            batch = {k: v.to(self.device, non_blocking=True) for k, v in batch.items()}

            # Mixed precision training
            with torch.cuda.amp.autocast(enabled=self._amp_enabled):
                outputs = self.model(
                    input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask'],
                    features=batch['features']
                )
                loss = self.criterion(outputs, batch['labels'])
                # Scale so the accumulated gradient averages over the group.
                loss = loss / accum_steps

            # Backward pass with scaler if mixed precision is enabled
            if self.scaler is not None:
                self.scaler.scale(loss).backward()
            else:
                loss.backward()

            # Update on every full accumulation group and on the final batch,
            # so a trailing partial group is applied instead of being wiped by
            # the next epoch's zero_grad().
            group_complete = (step + 1) % accum_steps == 0
            last_batch = (step + 1) == num_batches
            if group_complete or last_batch:
                self._optimizer_step()

            total_loss += loss.item() * accum_steps
            avg_loss = total_loss / (step + 1)
            pbar.set_postfix({'loss': f'{avg_loss:.4f}'})

            self.global_step += 1
            if self.global_step % self.config.eval_steps == 0:
                self.evaluate()
                # evaluate() switches the model to eval mode; switch back so
                # dropout/batch-norm behave correctly for the remaining batches.
                self.model.train()

            if step % 10 == 0:
                torch.cuda.empty_cache()
            del outputs
            del loss

        return total_loss / max(num_batches, 1)

    @torch.no_grad()
    def evaluate(self):
        """Evaluate on the validation set and track the best model.

        Thresholds are re-tuned on the validation outputs only when
        ``global_step`` is a multiple of ``2 * config.eval_steps``. When micro
        F1 improves on ``best_val_f1`` the model is saved as ``best_model.pt``.
        The model is left in eval mode; callers that keep training must call
        ``model.train()`` again.

        Returns:
            The metrics dict from :meth:`calculate_detailed_metrics` with an
            extra ``'loss'`` entry (mean validation loss per batch).
        """
        self.model.eval()
        total_loss = 0
        all_outputs, all_labels = [], []

        for batch in tqdm(self.val_loader, desc="Evaluating"):
            batch = {k: v.to(self.device, non_blocking=True) for k, v in batch.items()}
            with torch.cuda.amp.autocast(enabled=self._amp_enabled):
                outputs = self.model(
                    input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask'],
                    features=batch['features']
                )
                loss = self.criterion(outputs, batch['labels'])
            total_loss += loss.item()
            # Under autocast the logits may be half precision; keep float32
            # copies so the CPU-side sigmoid/metrics run in full precision.
            all_outputs.append(outputs.float().cpu())
            all_labels.append(batch['labels'].cpu())
            del outputs
            del loss
            torch.cuda.empty_cache()

        all_outputs = torch.cat(all_outputs, dim=0)
        all_labels = torch.cat(all_labels, dim=0)

        if self.global_step % (self.config.eval_steps * 2) == 0:
            self.label_thresholds = self.find_optimal_thresholds(all_outputs, all_labels)

        all_probs = torch.sigmoid(all_outputs).numpy()
        # Broadcast the per-label thresholds across all samples.
        all_preds = (all_probs > self.label_thresholds.cpu().unsqueeze(0).numpy())
        all_labels = all_labels.numpy()

        metrics = self.calculate_detailed_metrics(all_labels, all_preds, all_probs)
        metrics['loss'] = total_loss / len(self.val_loader)

        self.logger.info(f"Step {self.global_step} - Validation metrics:")
        self.logger.info(f"Loss: {metrics['loss']:.4f}")
        self.logger.info(f"Micro F1: {metrics['micro']['f1']:.4f}")
        self.logger.info(f"Macro F1: {metrics['macro']['f1']:.4f}")

        if metrics['micro']['f1'] > self.best_val_f1:
            self.best_val_f1 = metrics['micro']['f1']
            self.save_model('best_model.pt', metrics)
        return metrics

    def save_model(self, filename: str, metrics: dict = None):
        """Write a checkpoint to ``self.output_dir / filename``.

        Args:
            filename: Checkpoint file name inside the run directory.
            metrics: Optional metrics dict stored alongside the weights.
        """
        save_path = self.output_dir / filename
        # The scaler only exists when mixed precision is enabled; store None
        # otherwise instead of failing on ``None.state_dict()``.
        scaler_state = self.scaler.state_dict() if self.scaler is not None else None
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'scaler_state_dict': scaler_state,
            'epoch': self.current_epoch,
            'global_step': self.global_step,
            'best_val_f1': self.best_val_f1,
            'metrics': metrics,
            'thresholds': self.label_thresholds
        }, save_path)
        self.logger.info(f"Model saved to {save_path}")

    def train(self):
        """Run the complete training loop.

        After every epoch the model is evaluated, a checkpoint is written and
        ``history.json`` is refreshed.

        Returns:
            The history dict (``train_loss``, ``val_loss``, ``metrics``,
            ``thresholds``), one entry per epoch.

        Raises:
            Exception: Any error raised during training is logged and re-raised.
        """
        self.logger.info("Starting training...")
        try:
            for epoch in range(self.config.num_epochs):
                self.current_epoch = epoch
                self.logger.info(f"Starting epoch {epoch + 1}/{self.config.num_epochs}")

                train_loss = self.train_epoch()
                self.history['train_loss'].append(train_loss)

                val_metrics = self.evaluate()
                # Keep the validation loss curve next to the training one.
                self.history['val_loss'].append(val_metrics['loss'])
                self.history['metrics'].append(val_metrics)
                self.history['thresholds'].append(self.label_thresholds.cpu().tolist())

                self.save_model(f'checkpoint_epoch_{epoch+1}.pt', val_metrics)

                history_path = self.output_dir / 'history.json'
                with open(history_path, 'w') as f:
                    json.dump(self.history, f, indent=4)

                self.logger.info(f"Epoch {epoch + 1} completed. Train loss: {train_loss:.4f}")

            self.logger.info("Training completed successfully!")
            return self.history
        except Exception as e:
            self.logger.error(f"Training failed with error: {str(e)}")
            raise