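# Training utilities: early stopping, a sample-based likelihood loss, a collate
# wrapper, learning-rate decay, and batch preparation.
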
from typing import List, Callable, Any, Tuple

import numpy as np
import torch
from torch import nn, Tensor

from .collate import default_collate

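# Early stopping helper: checkpoints the model whenever the monitored loss
# improves, and sets `stop` once the loss has exceeded the best loss by more
# than `delta` on `patience` evaluations since the last improvement.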
class EarlyStop:
  def __init__(self, patience: int, delta: float):
    self.patience: int = patience
    self.delta: float = delta
    self.counter: int = 0
    self.best_loss: float = np.inf
    self.stop: bool = False

  def __call__(self, loss: float, model: nn.Module, path: str) -> None:
    if loss < self.best_loss:
      self.best_loss = loss
      self.counter = 0
      torch.save(model.state_dict(), path)
    elif loss > self.best_loss + self.delta:
      self.counter += 1
      if self.counter >= self.patience:
        self.stop = True

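# Sample-based negative log-likelihood: each of `num_samples` predictions per
# series is scored under a Gaussian with its own log-variance, the scores are
# combined with logsumexp (a mixture over samples, up to additive constants),
# and the result is averaged over the batch.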
class ExpLikeliLoss(nn.Module):
  def __init__(self, num_samples: int = 100):
    super(ExpLikeliLoss, self).__init__()
    self.num_samples: int = num_samples

  def forward(self, pred: Tensor, true: Tensor, logvar: Tensor) -> Tensor:
    b, l, d = pred.size(0), pred.size(1), pred.size(2)
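    # Fold the flattened batch so the last dimension indexes the `num_samples` samples.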
    true = true.transpose(0,1).reshape(l, -1, self.num_samples).transpose(0, 1)
    pred = pred.transpose(0,1).reshape(l, -1, self.num_samples).transpose(0, 1)
    logvar = logvar.reshape(-1, self.num_samples)

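    # Per-sample Gaussian log-likelihood over the sequence (up to constants),
    # mixed across samples with logsumexp, negated and averaged over the batch.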
    loss = torch.mean(
      (-1) * torch.logsumexp(
        (-l / 2) * logvar
        + (-1 / (2 * torch.exp(logvar))) * torch.sum((true - pred) ** 2, dim=1),
        dim=1,
      )
    )
    return loss

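# Build a collate function that repeats every sample in the batch `num_samples`
# consecutive times before applying the default collation.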
def modify_collate(num_samples: int) -> Callable[[List[Any]], Any]:
  def wrapper(batch: List[Any]) -> Any:
    batch_rep = [sample for sample in batch for _ in range(num_samples)]
    result = default_collate(batch_rep)
    return result
  return wrapper

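# Exponentially decay the learning rate: the rate used at `epoch` is
# lr * 0.5**epoch, i.e. it halves each epoch relative to the initial value.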
def adjust_learning_rate(model_optim: torch.optim.Optimizer, epoch: int, lr: float) -> None:
  lr = lr * (0.5 ** epoch)
  print("Learning rate halving...")
  print(f"New lr: {lr:.7f}")
  for param_group in model_optim.param_groups:
    param_group['lr'] = lr

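# Move one batch to `device`, build the decoder input (the last `len_label`
# observed target steps followed by `len_pred` zero placeholders), run the
# model, and return the prediction, the ground truth for the prediction
# window, and the predicted log-variance.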
def process_batch(
        subj_id: Tensor,
        batch_x: Tensor,
        batch_y: Tensor,
        batch_x_mark: Tensor,
        batch_y_mark: Tensor,
        len_pred: int,
        len_label: int,
        model: nn.Module,
        device: torch.device
) -> Tuple[Tensor, Tensor, Tensor]:
  subj_id = subj_id.long().to(device)
  batch_x = batch_x.float().to(device)
  batch_y = batch_y.float()
  batch_x_mark = batch_x_mark.float().to(device)
  batch_y_mark = batch_y_mark.float().to(device)

  true = batch_y[:, -len_pred:, :].to(device)

  dec_inp = torch.zeros([batch_y.shape[0], len_pred, batch_y.shape[-1]], dtype=torch.float, device=device)
  dec_inp = torch.cat([batch_y[:, :len_label, :].to(device), dec_inp], dim=1)

  pred, logvar = model(subj_id, batch_x, batch_x_mark, dec_inp, batch_y_mark)

  return pred, true, logvar