Spaces:
Sleeping
Sleeping
File size: 2,754 Bytes
bacf16b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
from typing import List, Callable, Any, Tuple
import numpy as np
import torch
from sympy import pprint
from torch import nn, Tensor
from .collate import default_collate
class EarlyStop:
    """Track a loss across calls, checkpoint on improvement, flag when to stop.

    Each call compares ``loss`` with the best value seen so far:

    * strictly lower  -> becomes the new best, the counter is reset and the
      model's ``state_dict`` is saved to ``path``;
    * higher than ``best_loss + delta`` -> the counter is incremented, and
      once it reaches ``patience`` the ``stop`` flag is set;
    * otherwise (between ``best_loss`` and ``best_loss + delta``) nothing
      changes.

    ``stop`` is never cleared by this class once it has been set.
    """

    def __init__(self, patience: int, delta: float):
        """Store the patience/tolerance settings and reset the tracking state.

        Args:
            patience: Counter value at which ``stop`` becomes ``True``.
            delta: Margin above ``best_loss`` a loss must exceed to count
                against the patience budget.
        """
        self.patience: int = patience
        self.delta: float = delta
        self.counter: int = 0
        # ``np.Inf`` was removed in NumPy 2.0; the builtin float infinity is
        # the same value and works with every NumPy version.
        self.best_loss: float = float("inf")
        self.stop: bool = False

    def __call__(self, loss: float, model: nn.Module, path: str) -> None:
        """Update the tracker with a new loss value.

        Args:
            loss: Loss value to compare with the best one seen so far.
            model: Module whose ``state_dict`` is saved on improvement.
            path: Destination passed to ``torch.save``.
        """
        if loss < self.best_loss:
            # Strict improvement: remember it and checkpoint the weights.
            self.best_loss = loss
            self.counter = 0
            torch.save(model.state_dict(), path)
        elif loss > self.best_loss + self.delta:
            # Worse than the best by more than the tolerated margin.
            self.counter += 1
            if self.counter >= self.patience:
                self.stop = True
class ExpLikeliLoss(nn.Module):
    """Loss that combines per-sample Gaussian-style log-scores by ``logsumexp``.

    The batch and feature axes of ``pred``/``true`` are folded together and
    split into groups of ``num_samples``; ``logvar`` is reshaped to
    ``(-1, num_samples)`` to match. Within each group the scores are combined
    with ``logsumexp``, negated, and the mean over groups is returned.

    NOTE(review): presumably the feature axis has size 1 and each group holds
    ``num_samples`` consecutive replicas of one example -- confirm against
    the caller.
    """

    def __init__(self, num_samples: int = 100):
        """Remember how many entries make up one group."""
        super().__init__()
        self.num_samples: int = num_samples

    def forward(self, pred: Tensor, true: Tensor, logvar: Tensor) -> Tensor:
        """Return the scalar loss for one batch.

        Args:
            pred: Predictions; must have at least three axes, axis 1 is the
                one summed over when accumulating squared errors.
            true: Targets, regrouped exactly like ``pred``.
            logvar: Log-variances, reshaped to ``(-1, num_samples)``.
        """
        # Reading all three sizes keeps the requirement that ``pred`` has at
        # least three axes; only the middle one is used below.
        _, steps, _ = (pred.size(axis) for axis in range(3))
        group = self.num_samples

        def regroup(series: Tensor) -> Tensor:
            # (B, L, D) -> (L, B, D) -> (L, G, group) -> (G, L, group)
            return series.transpose(0, 1).reshape(steps, -1, group).transpose(0, 1)

        true_g = regroup(true)
        pred_g = regroup(pred)
        logvar_g = logvar.reshape(-1, group)

        # Squared error accumulated over axis 1 for every (group, sample).
        sq_err = torch.sum((true_g - pred_g) ** 2, dim=1)
        variance_term = (-steps / 2) * logvar_g
        error_weight = -1 / (2 * torch.exp(logvar_g))
        log_score = variance_term + error_weight * sq_err

        return torch.mean((-1) * torch.logsumexp(log_score, dim=1))
def modify_collate(num_samples: int) -> Callable[[List[Any]], Any]:
    """Build a collate function that repeats every sample before collating.

    Args:
        num_samples: How many consecutive copies of each sample to place in
            the list handed to ``default_collate``.

    Returns:
        A callable taking a list of samples and returning whatever
        ``default_collate`` produces for the replicated list.
    """

    def wrapper(batch: List[Any]) -> Any:
        # Copies of one sample stay adjacent: s0, s0, ..., s1, s1, ...
        replicated: List[Any] = []
        for sample in batch:
            for _ in range(num_samples):
                replicated.append(sample)
        return default_collate(replicated)

    return wrapper
def adjust_learning_rate(model_optim: torch.optim.Optimizer, epoch: int, lr: float) -> None:
    """Set every parameter group's learning rate to ``lr * 0.5 ** epoch``.

    The new value is also reported on standard output.

    Args:
        model_optim: Optimizer whose ``param_groups`` are updated in place.
        epoch: Exponent applied to the 0.5 decay factor.
        lr: Base learning rate before decay.
    """
    decayed = lr * (0.5 ** epoch)
    print("Learning rate halving...")
    print(f"New lr: {decayed:.7f}")
    for group in model_optim.param_groups:
        group['lr'] = decayed
def process_batch(
    subj_id: Tensor,
    batch_x: Tensor,
    batch_y: Tensor,
    batch_x_mark: Tensor,
    batch_y_mark: Tensor,
    len_pred: int,
    len_label: int,
    model: nn.Module,
    device: torch.device
) -> Tuple[Tensor, Tensor, Tensor]:
    """Prepare one batch, run the model, and return its outputs with the target.

    The decoder input is the first ``len_label`` steps of ``batch_y`` (along
    axis 1) followed by ``len_pred`` steps of zeros. The target is the last
    ``len_pred`` steps of ``batch_y``.

    Args:
        subj_id: Cast to ``long`` and moved to ``device``.
        batch_x: Cast to ``float`` and moved to ``device``.
        batch_y: Cast to ``float``; indexed with three axes, so it must be
            rank 3.
        batch_x_mark: Cast to ``float`` and moved to ``device``.
        batch_y_mark: Cast to ``float`` and moved to ``device``.
        len_pred: Number of trailing steps used as the target and zero-filled
            in the decoder input.
        len_label: Number of leading steps of ``batch_y`` given to the decoder.
        model: Called as ``model(subj_id, batch_x, batch_x_mark, dec_inp,
            batch_y_mark)`` and expected to return a ``(pred, logvar)`` pair.
        device: Device on which the model inputs and the target are placed.

    Returns:
        ``(pred, true, logvar)``, where ``true`` is the target slice on
        ``device``.
    """
    subj_id = subj_id.long().to(device)
    batch_x = batch_x.float().to(device)
    # The full target tensor is only cast here; just the slices that are
    # actually needed get transferred to ``device`` below.
    targets = batch_y.float()
    batch_x_mark = batch_x_mark.float().to(device)
    batch_y_mark = batch_y_mark.float().to(device)

    true = targets[:, -len_pred:, :].to(device)

    # Decoder input: known label steps followed by a zero placeholder for
    # the steps to be predicted.
    known_steps = targets[:, :len_label, :].to(device)
    placeholder = torch.zeros(
        (targets.shape[0], len_pred, targets.shape[-1]),
        dtype=torch.float,
        device=device,
    )
    dec_inp = torch.cat((known_steps, placeholder), dim=1)

    pred, logvar = model(subj_id, batch_x, batch_x_mark, dec_inp, batch_y_mark)
    return pred, true, logvar