Spaces:
Sleeping
Sleeping
from typing import List, Callable, Any, Tuple | |
import numpy as np | |
import torch | |
from sympy import pprint | |
from torch import nn, Tensor | |
from .collate import default_collate | |
class EarlyStop:
    """Track a monitored loss, checkpoint on improvement, and flag when to stop.

    Calling the instance with a new loss does one of three things:

    * loss strictly below the best seen so far: record it, reset the patience
      counter and save ``model.state_dict()`` to ``path``;
    * loss above ``best_loss + delta``: increment the patience counter and set
      ``stop`` once the counter reaches ``patience``;
    * anything else (for a non-negative ``delta`` this is the closed band
      ``[best_loss, best_loss + delta]``): leave all state untouched.

    ``stop`` is never cleared again by this class once it has been set.
    """

    def __init__(self, patience: int, delta: float):
        """Store the tolerance settings and initialise the tracking state.

        Args:
            patience: Number of counted "worse" losses after which ``stop``
                becomes True.
            delta: Margin above the best loss that a loss must exceed before
                it is counted against the patience budget.
        """
        self.patience: int = patience
        self.delta: float = delta
        # Counted non-improving calls since the last improvement.
        self.counter: int = 0
        # ``np.Inf`` was removed in NumPy 2.0; the builtin float infinity is
        # the same value and works with every NumPy version.
        self.best_loss: float = float("inf")
        # Set to True once the patience budget is exhausted.
        self.stop: bool = False

    def __call__(self, loss: float, model: nn.Module, path: str) -> None:
        """Update the tracker with a new loss value.

        Args:
            loss: Latest value of the monitored loss.
            model: Module whose ``state_dict()`` is saved on improvement.
            path: Destination handed to ``torch.save``.
        """
        if loss < self.best_loss:
            # New best: remember it, reset patience and checkpoint the weights.
            self.best_loss = loss
            self.counter = 0
            torch.save(model.state_dict(), path)
        elif loss > self.best_loss + self.delta:
            # Worse than the best by more than the tolerated margin.
            self.counter += 1
            if self.counter >= self.patience:
                self.stop = True
class ExpLikeliLoss(nn.Module):
    """Negative log-sum-exp loss over groups of ``num_samples`` predictions.

    Inputs are regrouped so that the last axis has length ``num_samples``.
    For every group the loss takes ``-logsumexp`` over that axis of

        (-L / 2) * logvar - sum_over_axis_1((true - pred) ** 2) / (2 * exp(logvar))

    where ``L = pred.size(1)``, and the result is averaged over groups.
    No ``log(num_samples)`` term is subtracted.
    """

    def __init__(self, num_samples: int = 100):
        """Remember how many consecutive entries form one group."""
        super().__init__()
        self.num_samples: int = num_samples

    def forward(self, pred: Tensor, true: Tensor, logvar: Tensor) -> Tensor:
        """Compute the scalar loss.

        Args:
            pred: Predictions; must have at least three dimensions.
                Presumably (batch, length, features) - confirm with callers.
            true: Targets, regrouped with the same reshape as ``pred``.
            logvar: Log-variances, reshaped to ``(-1, num_samples)``.

        Returns:
            The mean over groups of the negative log-sum-exp term.
        """
        # Dimension 2 is read as well so that tensors with fewer than three
        # dimensions are rejected right here.
        seq_len, _ = pred.size(1), pred.size(2)
        group = self.num_samples

        def regroup(series: Tensor) -> Tensor:
            # (B, L, D) -> (L, B, D) -> (L, -1, group) -> (-1, L, group)
            return series.transpose(0, 1).reshape(seq_len, -1, group).transpose(0, 1)

        true = regroup(true)
        pred = regroup(pred)
        logvar = logvar.reshape(-1, group)

        log_norm = (-seq_len / 2) * logvar
        neg_half_precision = -1 / (2 * torch.exp(logvar))
        squared_error = torch.sum((true - pred) ** 2, dim=1)
        log_terms = log_norm + neg_half_precision * squared_error

        return torch.mean((-1) * torch.logsumexp(log_terms, dim=1))
def modify_collate(num_samples: int) -> Callable[[List[Any]], Any]:
    """Create a collate function that repeats each sample before collating.

    Args:
        num_samples: How many consecutive copies of every sample to place in
            the batch handed to ``default_collate``.

    Returns:
        A function taking a list of samples and returning whatever
        ``default_collate`` produces for the expanded list.
    """

    def wrapper(batch: List[Any]) -> Any:
        # Copies of one sample stay adjacent: [a, a, ..., b, b, ...].
        repeated: List[Any] = []
        for item in batch:
            repeated.extend([item] * num_samples)
        return default_collate(repeated)

    return wrapper
def adjust_learning_rate(model_optim: torch.optim.Optimizer, epoch: int, lr: float) -> None:
    """Set every parameter group's learning rate to ``lr * 0.5 ** epoch``.

    The new rate is printed and then written to the ``'lr'`` entry of each
    group in ``model_optim.param_groups``.

    Args:
        model_optim: Optimizer whose parameter groups are updated in place.
        epoch: Exponent applied to the halving factor.
        lr: Base learning rate before decay.
    """
    decay = 0.5 ** epoch
    lr = lr * decay

    print("Learning rate halving...")
    print(f"New lr: {lr:.7f}")

    for group in model_optim.param_groups:
        group['lr'] = lr
def process_batch(
    subj_id: Tensor,
    batch_x: Tensor,
    batch_y: Tensor,
    batch_x_mark: Tensor,
    batch_y_mark: Tensor,
    len_pred: int,
    len_label: int,
    model: nn.Module,
    device: torch.device
) -> Tuple[Tensor, Tensor, Tensor]:
    """Prepare one batch, run the model and return its outputs with the target.

    The decoder input is the first ``len_label`` steps of ``batch_y`` followed
    by ``len_pred`` steps of zeros along axis 1. The target is the last
    ``len_pred`` steps of ``batch_y``.

    Args:
        subj_id: Cast to ``long`` and moved to ``device``.
        batch_x: Cast to ``float`` and moved to ``device``.
        batch_y: Cast to ``float``; indexed as a 3-D tensor. Only the slices
            that are needed are moved to ``device``.
        batch_x_mark: Cast to ``float`` and moved to ``device``.
        batch_y_mark: Cast to ``float`` and moved to ``device``.
        len_pred: Number of trailing steps used as the target and zero-padded
            in the decoder input.
        len_label: Number of leading steps of ``batch_y`` fed to the decoder.
        model: Callable invoked as
            ``model(subj_id, batch_x, batch_x_mark, dec_inp, batch_y_mark)``
            and expected to return a ``(pred, logvar)`` pair.
        device: Device on which the model inputs are placed.

    Returns:
        ``(pred, true, logvar)`` where ``true`` is the target slice on
        ``device``.
    """
    # Model inputs that are used whole go straight to the target device.
    subj_id = subj_id.long().to(device)
    batch_x = batch_x.float().to(device)
    batch_x_mark = batch_x_mark.float().to(device)
    batch_y_mark = batch_y_mark.float().to(device)

    # batch_y is only cast here; just its slices are transferred below.
    batch_y = batch_y.float()
    true = batch_y[:, -len_pred:, :].to(device)

    n_rows, n_feats = batch_y.shape[0], batch_y.shape[-1]
    known_part = batch_y[:, :len_label, :].to(device)
    zero_part = torch.zeros(
        [n_rows, len_pred, n_feats], dtype=torch.float, device=device
    )
    dec_inp = torch.cat([known_part, zero_part], dim=1)

    pred, logvar = model(subj_id, batch_x, batch_x_mark, dec_inp, batch_y_mark)
    return pred, true, logvar