import os
import logging
import copy
import torch
import torch.nn as nn
from torch.optim import Adam
from src.pipeline import VanillaLSTM, VAE, Transformer


# Class for model training and evaluation
class Trainer:
    def __init__(self):
        self.logger = logging.getLogger(__name__)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.batch_size = None
        self.model = None
        self.model_type = None
        self.optimizer = None
        self.criterion = None
        self.train_loader = None
        self.val_loader = None
        self.test_loader = None
        self.n_epochs = None
        self.train_history = {'train_loss': [], 'val_loss': []}
        self.best_model = None
        self.best_val_loss = float('inf')
    def init_model(self, model, model_type):
        """
        Initialize the model and move it to the target device
        :param model: The model architecture
        :param model_type: The type of the model, one of "lstm", "vae" or "transformer"
        """
        self.logger.info("Initializing the model...")
        if model_type not in ["lstm", "vae", "transformer"]:
            raise ValueError(f"Model type '{model_type}' not supported")
        self.model = model.to(self.device)
        self.model_type = model_type
    def config_train(self, batch_size=32, n_epochs=20, lr=0.001):
        """
        Configure the training parameters
        :param batch_size: The batch size, default is 32
        :param n_epochs: The number of epochs, default is 20
        :param lr: The learning rate, default is 0.001
        """
        self.logger.info("Configuring the training parameters...")
        if self.model is None:
            raise RuntimeError("Call init_model() before config_train()")
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.optimizer = Adam(self.model.parameters(), lr=lr)
        self.criterion = nn.MSELoss()
    def train(self, train_loader, val_loader):
        """
        Train the model
        :param train_loader: The training data loader
        :param val_loader: The validation data loader
        :return: The best model (lowest validation loss) and the training history
        """
        print("Training the model...")
        self.logger.info("Start training...")
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.best_val_loss = float('inf')
        self.best_model = None
        for epoch in range(self.n_epochs):
            train_loss = self._train_epoch()
            val_loss = self._val_epoch()
            self.train_history['train_loss'].append(train_loss)
            self.train_history['val_loss'].append(val_loss)
            self.logger.info(f"Epoch {epoch + 1}/{self.n_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
        self.logger.info("Training completed!")
        print("Training completed!")
        return self.best_model, self.train_history
    def _train_epoch(self):
        """
        Train the model for one epoch
        """
        self.model.train()
        train_loss = 0
        for seq in self.train_loader:
            self.optimizer.zero_grad()
            if self.model_type == "lstm":
                X_train = seq[:, :-1, :]  # All timesteps except the last one
                y_train = seq[:, -1, :]   # Final timestep (one-step-ahead target)
                X_train = X_train.to(self.device)
                y_train = y_train.to(self.device)
                output = self.model(X_train)
                loss = self.criterion(output, y_train)
            elif self.model_type == "vae":
                X = seq.to(self.device)
                recon_X, mu, logvar = self.model(X)
                recon_loss = self.criterion(recon_X, X)
                # Closed-form KL divergence between N(mu, sigma^2) and N(0, 1),
                # summed over latent dimensions and averaged over the batch
                kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / X.size(0)
                loss = recon_loss + 0.2 * kl_div  # KL term down-weighted by a fixed beta of 0.2
            elif self.model_type == "transformer":
                X = seq.to(self.device)
                recon_X = self.model(X)
                loss = self.criterion(recon_X, X)
            else:
                raise ValueError("Model type not supported")
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
        return train_loss / len(self.train_loader)
    def _val_epoch(self):
        """
        Validate the model for one epoch
        """
        self.model.eval()
        val_loss = 0
        with torch.no_grad():
            for seq in self.val_loader:
                if self.model_type == "lstm":
                    X_val = seq[:, :-1, :]
                    y_val = seq[:, -1, :]
                    X_val = X_val.to(self.device)
                    y_val = y_val.to(self.device)
                    output = self.model(X_val)
                    loss = self.criterion(output, y_val)
                elif self.model_type == "vae":
                    X_val = seq.to(self.device)
                    recon_X, mu, logvar = self.model(X_val)
                    recon_loss = self.criterion(recon_X, X_val)
                    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / X_val.size(0)
                    loss = recon_loss + 0.2 * kl_div
                elif self.model_type == "transformer":
                    X_val = seq.to(self.device)
                    recon_X = self.model(X_val)
                    loss = self.criterion(recon_X, X_val)
                else:
                    raise ValueError("Model type not supported")
                val_loss += loss.item()
        val_loss /= len(self.val_loader)
        # Keep a copy of the model with the lowest average validation loss seen so far
        if val_loss < self.best_val_loss:
            self.best_model = copy.deepcopy(self.model)
            self.best_val_loss = val_loss
        return val_loss
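

# A minimal usage sketch, not part of the original pipeline: the tensor shapes,
# the DataLoader setup, and VanillaLSTM's constructor arguments below are
# assumptions; check src.pipeline for the real signature before running.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    logging.basicConfig(level=logging.INFO)

    # Synthetic data: sequences of 30 timesteps with 4 features each
    train_seqs = torch.randn(128, 30, 4)
    val_seqs = torch.randn(32, 30, 4)
    train_loader = DataLoader(train_seqs, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_seqs, batch_size=32)

    trainer = Trainer()
    # Hypothetical constructor arguments for VanillaLSTM
    trainer.init_model(VanillaLSTM(input_dim=4, hidden_dim=64), model_type="lstm")
    trainer.config_train(batch_size=32, n_epochs=5, lr=0.001)
    best_model, history = trainer.train(train_loader, val_loader)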