import copy
import logging

import torch
import torch.nn as nn
from torch.optim import Adam


class Trainer:
    """Handles training and evaluation for the LSTM, VAE and Transformer models."""

    def __init__(self):
        self.logger = logging.getLogger(__name__)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.batch_size = None
        self.model = None
        self.model_type = None
        self.optimizer = None
        self.criterion = None
        self.train_loader = None
        self.val_loader = None
        self.test_loader = None
        self.n_epochs = None
        self.train_history = {'train_loss': [], 'val_loss': []}
        self.best_model = None
        self.best_val_loss = float('inf')

    def init_model(self, model, model_type):
        """
        Initialize the model and record its type.

        :param model: The model architecture
        :param model_type: The type of the model: "lstm", "vae" or "transformer"
        """
        self.logger.info("Initializing the model...")
        if model_type not in ("lstm", "vae", "transformer"):
            raise ValueError(f"Model type '{model_type}' not supported")
        self.model = model.to(self.device)
        self.model_type = model_type

    def config_train(self, batch_size=32, n_epochs=20, lr=0.001):
        """
        Configure the training parameters and build the optimizer and loss function.

        :param batch_size: The batch size, default is 32
        :param n_epochs: The number of epochs, default is 20
        :param lr: The learning rate, default is 0.001
        """
        self.logger.info("Configuring the training parameters...")
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.optimizer = Adam(self.model.parameters(), lr=lr)
        self.criterion = nn.MSELoss()

    def train(self, train_loader, val_loader):
        """
        Train the model, tracking the best checkpoint by validation loss.

        :param train_loader: The training data loader
        :param val_loader: The validation data loader
        :return: The best model and the training history
        """
        self.logger.info("Start training...")
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.best_val_loss = float('inf')
        self.best_model = None

        for epoch in range(self.n_epochs):
            train_loss = self._train_epoch()
            val_loss = self._val_epoch()
            self.train_history['train_loss'].append(train_loss)
            self.train_history['val_loss'].append(val_loss)
            self.logger.info(
                f"Epoch {epoch + 1}/{self.n_epochs}, "
                f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}"
            )

        self.logger.info("Training completed!")
        return self.best_model, self.train_history

    def _train_epoch(self):
        """Train the model for one epoch and return the mean batch loss."""
        self.model.train()
        train_loss = 0.0
        for seq in self.train_loader:
            self.optimizer.zero_grad()
            if self.model_type == "lstm":
                # Predict the final timestep from all preceding timesteps.
                X_train = seq[:, :-1, :].to(self.device)  # all timesteps except the last
                y_train = seq[:, -1, :].to(self.device)   # the final timestep
                output = self.model(X_train)
                loss = self.criterion(output, y_train)
            elif self.model_type == "vae":
                # Reconstruction loss plus a weighted KL divergence penalty (beta = 0.2).
                X = seq.to(self.device)
                recon_X, mu, logvar = self.model(X)
                recon_loss = self.criterion(recon_X, X)
                kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / X.size(0)
                loss = recon_loss + 0.2 * kl_div
            elif self.model_type == "transformer":
                # Plain reconstruction of the full sequence.
                X = seq.to(self.device)
                recon_X = self.model(X)
                loss = self.criterion(recon_X, X)
            else:
                raise ValueError("Model type not supported")
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
        return train_loss / len(self.train_loader)

    def _val_epoch(self):
        """Validate the model for one epoch and return the mean batch loss."""
        self.model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for seq in self.val_loader:
                if self.model_type == "lstm":
                    X_val = seq[:, :-1, :].to(self.device)
                    y_val = seq[:, -1, :].to(self.device)
                    output = self.model(X_val)
                    loss = self.criterion(output, y_val)
                elif self.model_type == "vae":
                    X = seq.to(self.device)
                    recon_X, mu, logvar = self.model(X)
                    recon_loss = self.criterion(recon_X, X)
                    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / X.size(0)
                    loss = recon_loss + 0.2 * kl_div
                elif self.model_type == "transformer":
                    X_val = seq.to(self.device)
                    recon_X = self.model(X_val)
                    loss = self.criterion(recon_X, X_val)
                else:
                    raise ValueError("Model type not supported")
                val_loss += loss.item()
        # Compare the averaged loss so checkpointing does not depend on the
        # number of validation batches.
        avg_val_loss = val_loss / len(self.val_loader)
        if avg_val_loss < self.best_val_loss:
            self.best_model = copy.deepcopy(self.model)
            self.best_val_loss = avg_val_loss
        return avg_val_loss
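

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). `_ToyLSTM`, the synthetic data,
# and all shapes below are assumptions for demonstration; a real run would
# use one of the models from src.pipeline (e.g. VanillaLSTM) and the
# project's own data loaders.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    logging.basicConfig(level=logging.INFO)

    class _ToyLSTM(nn.Module):
        """Stand-in next-step predictor with the interface the "lstm" branch
        expects: (batch, time, features) in, (batch, features) out."""

        def __init__(self, n_features=8, hidden_size=32):
            super().__init__()
            self.lstm = nn.LSTM(n_features, hidden_size, batch_first=True)
            self.head = nn.Linear(hidden_size, n_features)

        def forward(self, x):
            out, _ = self.lstm(x)
            return self.head(out[:, -1, :])

    # Synthetic sliding windows: 256/64 sequences of 30 timesteps x 8 features.
    # A DataLoader over a raw tensor yields (batch, time, features) batches,
    # which is what Trainer's epoch loops expect.
    train_loader = DataLoader(torch.randn(256, 30, 8), batch_size=32, shuffle=True)
    val_loader = DataLoader(torch.randn(64, 30, 8), batch_size=32)

    trainer = Trainer()
    trainer.init_model(_ToyLSTM(), "lstm")
    trainer.config_train(batch_size=32, n_epochs=2, lr=1e-3)
    best_model, history = trainer.train(train_loader, val_loader)
    trainer.logger.info(f"Best val loss: {trainer.best_val_loss:.4f}")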