# tgd1115's picture
# Upload 12 files
# 8474315 verified
import os
import logging
import copy
import torch
import torch.nn as nn
from torch.optim import Adam
from src.pipeline import VanillaLSTM, VAE, Transformer
# Class for model training and evaluation
# Class for model training and evaluation
class Trainer:
    """Train and evaluate LSTM / VAE / Transformer models.

    Intended workflow: ``init_model()`` -> ``config_train()`` -> ``train()``.
    Tracks a per-epoch train/validation loss history and keeps a deep copy
    of the model that achieved the lowest validation loss.
    """

    # Weight of the KL-divergence term in the VAE loss (recon + w * KL).
    KL_WEIGHT = 0.2

    def __init__(self):
        self.logger = logging.getLogger(__name__)
        # Prefer GPU when available; model and batches are moved to this device.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Fixed: original had a stray trailing comma, making this the tuple (None,).
        self.batch_size = None
        self.model = None
        self.model_type = None
        self.optimizer = None
        self.criterion = None
        self.train_loader = None
        self.val_loader = None
        self.test_loader = None
        self.n_epochs = None
        self.train_history = {'train_loss': [], 'val_loss': []}
        self.best_model = None
        self.best_val_loss = float('inf')

    def init_model(self, model, model_type):
        """
        Initialize the model, optimizer and loss function
        :param model: The model architecture
        :param model_type: The type of the model ("lstm", "vae" or "transformer")
        :raises ValueError: if model_type is not one of the supported types
        """
        self.logger.info("Initialize the model...")
        # Validate BEFORE mutating state so a bad call leaves the trainer unchanged
        # (original moved the model to the device first).
        if model_type not in ("lstm", "vae", "transformer"):
            raise ValueError("Model type not supported")
        self.model = model.to(self.device)
        self.model_type = model_type

    def config_train(self, batch_size=32, n_epochs=20, lr=0.001):
        """
        Configure the training parameters. Must be called after init_model(),
        since the optimizer binds to the current model's parameters.
        :param batch_size: The batch size, default is 32
        :param n_epochs: The number of epochs, default is 20
        :param lr: The learning rate, default is 0.001
        """
        self.logger.info("Configure the training parameters...")
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.optimizer = Adam(self.model.parameters(), lr=lr)
        self.criterion = nn.MSELoss()

    def train(self, train_loader, val_loader):
        """
        Train the model for n_epochs, validating after every epoch.
        :param train_loader: The training data loader
        :param val_loader: The validation data loader
        :return: (best_model, train_history) — the deep-copied model with the
                 lowest validation loss and the per-epoch loss history dict
        """
        print("Training the model...")
        self.logger.info("Start training...")
        self.train_loader = train_loader
        self.val_loader = val_loader
        # Reset best-model tracking so repeated train() calls start fresh.
        self.best_val_loss = float('inf')
        self.best_model = None
        for epoch in range(self.n_epochs):
            train_loss = self._train_epoch()
            val_loss = self._val_epoch()
            self.train_history['train_loss'].append(train_loss)
            self.train_history['val_loss'].append(val_loss)
            self.logger.info(f"Epoch {epoch + 1}/{self.n_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
        self.logger.info("Training completed!")
        print("Training completed!")
        return self.best_model, self.train_history

    def _compute_loss(self, seq):
        """
        Compute the loss for one batch according to self.model_type.
        Shared by _train_epoch and _val_epoch (the original duplicated this logic).
        :param seq: one batch of sequences (assumed (batch, time, features) for
                    the LSTM path — confirm against the data pipeline)
        :return: a scalar loss tensor (differentiable)
        :raises ValueError: if self.model_type is unrecognized (defensive;
                            init_model already validates it)
        """
        if self.model_type == "lstm":
            # Predict the final timestep from all preceding timesteps.
            X = seq[:, :-1, :].to(self.device)
            y = seq[:, -1, :].to(self.device)
            return self.criterion(self.model(X), y)
        if self.model_type == "vae":
            X = seq.to(self.device)
            recon_X, mu, logvar = self.model(X)
            recon_loss = self.criterion(recon_X, X)
            # Closed-form KL divergence to a unit Gaussian, averaged per sample.
            kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / X.size(0)
            return recon_loss + self.KL_WEIGHT * kl_div
        if self.model_type == "transformer":
            # Reconstruction objective: model output is compared to its input.
            X = seq.to(self.device)
            return self.criterion(self.model(X), X)
        raise ValueError("Model type not supported")

    def _train_epoch(self):
        """
        Train the model for one epoch.
        :return: mean batch loss over the training loader
        """
        self.model.train()
        total_loss = 0.0
        for seq in self.train_loader:
            self.optimizer.zero_grad()
            loss = self._compute_loss(seq)
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
        return total_loss / len(self.train_loader)

    def _val_epoch(self):
        """
        Validate the model for one epoch; snapshot it if validation improved.
        :return: mean batch loss over the validation loader
        """
        self.model.eval()
        total_loss = 0.0
        with torch.no_grad():
            for seq in self.val_loader:
                total_loss += self._compute_loss(seq).item()
        # NOTE: best-model comparison uses the SUMMED loss (original behavior),
        # while the method returns the mean. Consistent across epochs as long
        # as the validation loader length is fixed.
        if total_loss < self.best_val_loss:
            self.best_model = copy.deepcopy(self.model)
            self.best_val_loss = total_loss
        return total_loss / len(self.val_loader)