|
from typing import Tuple, Any, Union
|
|
import torch
|
|
from torch import nn
|
|
import numpy as np
|
|
|
|
|
|
class BaseModel(nn.Module):
    """Common base for models: lookup tables, training history, early stopping.

    The constructor only sets up bookkeeping attributes (device, string-keyed
    registries of optimizer / loss / activation classes, an empty parameter
    dictionary and an empty training history). It defines no layers itself.
    """

    def __init__(self):
        """Initialise device selection, registries and training history."""
        super().__init__()

        # Use the first CUDA device when one is available, otherwise the CPU.
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # Configuration dictionary; starts empty and is read by
        # ``print_config`` and ``early_stopping`` (key ``"verbose"``).
        self.param_dict = {}

        # Name -> optimizer class (classes, not instances).
        self.optimizer_dict = {
            "adam": torch.optim.Adam,
            "nadam": torch.optim.NAdam,
            "adamax": torch.optim.Adamax,
            "adamw": torch.optim.AdamW,
            "sgd": torch.optim.SGD,
        }

        # Name -> loss class (classes, not instances).
        self.loss_dict = {
            "mse": nn.MSELoss,
            "l1": nn.L1Loss,
            "smoothl1": nn.SmoothL1Loss,
            "huber": nn.HuberLoss,
            "cel": nn.CrossEntropyLoss,
            "bcel": nn.BCELoss,
        }

        # Name -> activation class (classes, not instances).
        # NOTE(review): "swish" maps to nn.Hardswish rather than nn.SiLU --
        # confirm this is intentional before relying on the name.
        self.activation_dict = {
            "relu": nn.ReLU,
            "swish": nn.Hardswish,
            "leaky_relu": nn.LeakyReLU,
            "elu": nn.ELU,
            "selu": nn.SELU,
        }

        self.tokenizer = None

        # Per-metric lists; ``early_stopping`` reads "validation_loss".
        self.history = {
            "train_loss": [],
            "on_diag_loss": [],
            "off_diag_loss": [],
            "validation_loss": [],
            "learning_rate": [],
        }

        # Number of consecutive ``early_stopping`` calls in which the latest
        # validation loss was worse than the best one recorded.
        self.count = 0

    def print_config(self) -> None:
        """Print the current parameter dictionary to stdout."""
        print("[BT]: Current parameter config:")
        print(self.param_dict)

    def early_stopping(self, patience: int) -> bool:
        """Decide whether training should stop based on validation loss.

        Nothing is evaluated until ``history["validation_loss"]`` holds more
        than ``patience`` entries. After that, each call compares the latest
        validation loss with the minimum recorded so far: a strictly worse
        value increments ``self.count``, anything else resets it to 0.

        Args:
            patience: Number of consecutive non-improving calls tolerated.

        Returns:
            True once ``self.count`` reaches ``patience``, otherwise False.
        """
        val_losses = self.history["validation_loss"]

        # Not enough recorded epochs yet to judge a trend.
        if len(val_losses) <= patience:
            return False

        best_loss = min(val_losses)
        if val_losses[-1] > best_loss:
            self.count += 1
        else:
            self.count = 0

        if self.count < patience:
            return False

        # ``param_dict`` starts empty, so "verbose" may be absent; treat a
        # missing key as non-verbose instead of raising KeyError.
        if self.param_dict.get("verbose") is True:
            print("[BT]: Early stopping")
        return True
|
|
|