import os
import logging
import copy

import torch
import torch.nn as nn
from torch.optim import Adam

from src.pipeline import VanillaLSTM, VAE, Transformer
class Trainer:
    """Class for model training and evaluation."""

    def __init__(self):
        self.logger = logging.getLogger(__name__)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.batch_size = None
        self.model = None
        self.model_type = None
        self.optimizer = None
        self.criterion = None
        self.train_loader = None
        self.val_loader = None
        self.test_loader = None
        self.n_epochs = None
        self.train_history = {'train_loss': [], 'val_loss': []}
        self.best_model = None
        self.best_val_loss = float('inf')
    def init_model(self, model, model_type):
        """
        Initialize the model and validate its type
        :param model: The model architecture
        :param model_type: The type of the model, one of "lstm", "vae" or "transformer"
        """
        self.logger.info("Initializing the model...")
        # Validate the type before moving the model to the device
        if model_type not in ["lstm", "vae", "transformer"]:
            raise ValueError(f"Model type '{model_type}' not supported")
        self.model_type = model_type
        self.model = model.to(self.device)
    def config_train(self, batch_size=32, n_epochs=20, lr=0.001):
        """
        Configure the training parameters, optimizer and loss function
        :param batch_size: The batch size, default is 32
        :param n_epochs: The number of epochs, default is 20
        :param lr: The learning rate, default is 0.001
        """
        self.logger.info("Configuring the training parameters...")
        if self.model is None:
            raise RuntimeError("Call init_model() before config_train()")
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.optimizer = Adam(self.model.parameters(), lr=lr)
        self.criterion = nn.MSELoss()
    def train(self, train_loader, val_loader):
        """
        Train the model
        :param train_loader: The training data loader
        :param val_loader: The validation data loader
        :return: The best model (by validation loss) and the training history
        """
        print("Training the model...")
        self.logger.info("Start training...")
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.best_val_loss = float('inf')
        self.best_model = None
        # Reset the history so repeated calls to train() don't accumulate losses
        self.train_history = {'train_loss': [], 'val_loss': []}
        for epoch in range(self.n_epochs):
            train_loss = self._train_epoch()
            val_loss = self._val_epoch()
            self.train_history['train_loss'].append(train_loss)
            self.train_history['val_loss'].append(val_loss)
            self.logger.info(
                f"Epoch {epoch + 1}/{self.n_epochs}, "
                f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}"
            )
        self.logger.info("Training completed!")
        print("Training completed!")
        return self.best_model, self.train_history
    def _train_epoch(self):
        """
        Train the model for one epoch and return the average training loss
        """
        self.model.train()
        train_loss = 0
        for seq in self.train_loader:
            self.optimizer.zero_grad()
            if self.model_type == "lstm":
                # Next-step prediction: all timesteps except the last are the
                # input, the final timestep is the target
                X_train = seq[:, :-1, :].to(self.device)
                y_train = seq[:, -1, :].to(self.device)
                output = self.model(X_train)
                loss = self.criterion(output, y_train)
            elif self.model_type == "vae":
                X = seq.to(self.device)
                recon_X, mu, logvar = self.model(X)
                recon_loss = self.criterion(recon_X, X)
                # KL divergence between the approximate posterior N(mu, sigma^2)
                # and the standard normal prior, averaged over the batch
                kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / X.size(0)
                # beta-VAE-style objective with the KL term down-weighted to 0.2
                loss = recon_loss + 0.2 * kl_div
            elif self.model_type == "transformer":
                # Reconstruction objective: the transformer autoencodes the sequence
                X = seq.to(self.device)
                recon_X = self.model(X)
                loss = self.criterion(recon_X, X)
            else:
                raise ValueError("Model type not supported")
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
        return train_loss / len(self.train_loader)
    def _val_epoch(self):
        """
        Validate the model for one epoch, track the best model so far,
        and return the average validation loss
        """
        self.model.eval()
        val_loss = 0
        with torch.no_grad():
            for seq in self.val_loader:
                if self.model_type == "lstm":
                    X_val = seq[:, :-1, :].to(self.device)
                    y_val = seq[:, -1, :].to(self.device)
                    output = self.model(X_val)
                    loss = self.criterion(output, y_val)
                elif self.model_type == "vae":
                    X = seq.to(self.device)
                    recon_X, mu, logvar = self.model(X)
                    recon_loss = self.criterion(recon_X, X)
                    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / X.size(0)
                    loss = recon_loss + 0.2 * kl_div
                elif self.model_type == "transformer":
                    X_val = seq.to(self.device)
                    recon_X = self.model(X_val)
                    loss = self.criterion(recon_X, X_val)
                else:
                    raise ValueError("Model type not supported")
                val_loss += loss.item()
        # Compare average losses so best-model tracking matches the logged values
        val_loss /= len(self.val_loader)
        if val_loss < self.best_val_loss:
            self.best_model = copy.deepcopy(self.model)
            self.best_val_loss = val_loss
        return val_loss
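

# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal example of wiring up the Trainer with random data. The
# VanillaLSTM constructor arguments below are hypothetical; the real
# signature lives in src.pipeline. The DataLoaders are built directly over a
# (num_sequences, seq_len, num_features) tensor so each batch arrives as a
# plain tensor, which is what _train_epoch() and _val_epoch() expect.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    logging.basicConfig(level=logging.INFO)

    seq_len, n_features = 20, 8
    train_seqs = torch.randn(256, seq_len, n_features)  # dummy training sequences
    val_seqs = torch.randn(64, seq_len, n_features)     # dummy validation sequences
    train_loader = DataLoader(train_seqs, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_seqs, batch_size=32)

    trainer = Trainer()
    # Hypothetical constructor arguments; adjust to the actual VanillaLSTM API
    model = VanillaLSTM(input_size=n_features, hidden_size=64)
    trainer.init_model(model, model_type="lstm")
    trainer.config_train(batch_size=32, n_epochs=5, lr=1e-3)
    best_model, history = trainer.train(train_loader, val_loader)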