import logging

import numpy as np
import torch
from torch.nn.parameter import Parameter

_lim_val = 36.0
eps = np.finfo(np.float64).resolution


def logexp(t):
    """Softplus log(1 + e^t), clamped for numerical stability.

    For t > _lim_val the function is effectively the identity; the added
    eps keeps the result strictly positive.
    """
    return eps + torch.where(
        t > _lim_val, t,
        torch.log(torch.exp(torch.clamp(t, -_lim_val, _lim_val)) + 1.))


def inv_logexp(t):
    """Approximate inverse of logexp: log(e^t - 1), identity for large t."""
    return np.where(t > _lim_val, t, np.log(np.exp(t + eps) - 1))


class Sech(torch.autograd.Function):
    """Implementation of sech(x) = 2 / (e^x + e^(-x))."""

    @staticmethod
    def forward(ctx, x):
        cosh = torch.cosh(x)
        sech = 1. / cosh
        # cosh overflows to inf for large |x|; the correct limit of sech is 0.
        sech = torch.where(torch.isinf(cosh), torch.zeros_like(sech), sech)
        ctx.save_for_backward(x, sech)
        return sech

    @staticmethod
    def backward(ctx, grad_output):
        x, sech = ctx.saved_tensors
        # d/dx sech(x) = -sech(x) * tanh(x).
        return -sech * torch.tanh(x) * grad_output


class TanhSingleWarpingTerm(torch.nn.Module):
    """A tanh mapping with scaling and translation.

    Maps y to a * tanh(b * (y + c)), where a and b are positive scalars and
    c is unconstrained. The learnable parameters are pre_a, pre_b, and c;
    a and b are obtained from pre_a and pre_b via the softplus logexp, which
    keeps them positive.
    """

    def __init__(self):
        super(TanhSingleWarpingTerm, self).__init__()
        self.pre_a = Parameter(torch.Tensor(1))
        self.pre_b = Parameter(torch.Tensor(1))
        self.c = Parameter(torch.Tensor(1))
        # Initialize here so the term is usable standalone; TanhWarpingLayer
        # resets all of its terms again after construction.
        self.reset_parameters()

    def reset_parameters(self):
        # Initialize according to the warpedLMM code: half-normal for pre_a
        # and pre_b, uniform in [-0.5, 0.5] for c.
        torch.nn.init.normal_(self.pre_a)
        torch.nn.init.normal_(self.pre_b)
        with torch.no_grad():
            self.pre_a.abs_()
            self.pre_b.abs_()
        torch.nn.init.uniform_(self.c, -0.5, 0.5)

    def set_parameters(self, a, b, c):
        self.pre_a.data.fill_(inv_logexp(a).item())
        self.pre_b.data.fill_(inv_logexp(b).item())
        self.c.data.fill_(c)

    def get_parameters(self):
        return (logexp(self.pre_a).detach().item(),
                logexp(self.pre_b).detach().item(),
                self.c.detach().item())

    def forward(self, y):
        a = logexp(self.pre_a)
        b = logexp(self.pre_b)
        return a * torch.tanh(b * (y + self.c))

    def jacobian(self, y):
        """Returns df/dy evaluated at y.

        df/dy = a * b * sech^2(b * (y + c)).
        """
        a = logexp(self.pre_a)
        b = logexp(self.pre_b)
        sech = Sech.apply
        return a * b * (sech(b * (y + self.c)) ** 2)


class TanhWarpingLayer(torch.nn.Module):
    """A warping layer combining linear and tanh mappings.

    Maps y to

        d * y + a_1 * tanh(b_1 * (y + c_1)) + ... + a_n * tanh(b_n * (y + c_n)),

    where d and all a_i, b_i are positive scalars.
    """

    def __init__(self, num_warping_terms):
        super(TanhWarpingLayer, self).__init__()
        self.num_warping_terms = num_warping_terms
        self.warping_terms = torch.nn.ModuleList(
            [TanhSingleWarpingTerm() for _ in range(num_warping_terms)])
        self.pre_d = Parameter(torch.Tensor(1))
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.normal_(self.pre_d)
        with torch.no_grad():
            self.pre_d.abs_()
        for t in self.warping_terms:
            t.reset_parameters()

    def set_parameters(self, a, b, c, d):
        """Sets parameters of the warping layer.

        Args:
            a, b, c: arrays of length num_warping_terms.
            d: a positive scalar.
        """
        if (len(a) != self.num_warping_terms or
                len(b) != self.num_warping_terms or
                len(c) != self.num_warping_terms):
            raise ValueError(
                "Expected %d warping terms" % self.num_warping_terms)
        self.pre_d.data.fill_(inv_logexp(d).item())
        for i, t in enumerate(self.warping_terms):
            t.set_parameters(a[i], b[i], c[i])

    def get_parameters(self):
        """Returns the warping parameters a, b, c, d."""
        d = logexp(self.pre_d).detach().item()
        a = np.zeros(self.num_warping_terms)
        b = np.zeros(self.num_warping_terms)
        c = np.zeros(self.num_warping_terms)
        for i, t in enumerate(self.warping_terms):
            a[i], b[i], c[i] = t.get_parameters()
        return a, b, c, d

    def write_parameters(self):
        """Writes parameters to logging.debug."""
        a, b, c, d = self.get_parameters()
        logging.debug(f"a: {a}")
        logging.debug(f"b: {b}")
        logging.debug(f"c: {c}")
        logging.debug(f"d: {d}")

    def forward(self, y):
        s = logexp(self.pre_d) * y
        for warping_term in self.warping_terms:
            s += warping_term.forward(y)
        return s

    def extra_repr(self):
        return "num_warping_terms={}".format(self.num_warping_terms)

    def jacobian(self, y):
        """Returns df/dy evaluated at y."""
        jcb = logexp(self.pre_d) * torch.ones_like(y)
        for warping_term in self.warping_terms:
            jcb += warping_term.jacobian(y)
        return jcb

    def numpy_fn(self):
        """Returns a NumPy version of the warping function, with parameters
        frozen at their current values."""
        a, b, c, d = self.get_parameters()

        def fn(x):
            s = d * x
            for i in range(self.num_warping_terms):
                s += a[i] * np.tanh(b[i] * (x + c[i]))
            return s

        return fn