import numpy as np
import torch
import torch.nn as nn
import torch.functional as F


class LogisticActivation(nn.Module):
    r"""
    Implementation of Generalized Sigmoid

    Applies the element-wise function:

    .. math::
        \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-k(x-x_0))}

    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional
          dimensions
        - Output: the element-wise result with every size-1 dimension
          squeezed away (``Tensor.squeeze()``); e.g. a 0-dim input yields a
          0-dim output rather than shape ``(1,)``.

    Parameters:
        - x0: The value of the sigmoid midpoint
        - k: The slope of the sigmoid - trainable only when ``train=True``

    Examples:
        >>> logAct = LogisticActivation(0, 5)
        >>> x = torch.randn(256)
        >>> x = logAct(x)
    """

    def __init__(self, x0=0, k=1, train=False):
        """
        Initialization

        INPUT:
            - x0: The value of the sigmoid midpoint
            - k: The slope of the sigmoid
            - train: Whether to make k a trainable parameter

        x0 and k are initialized to 0, 1 respectively, in which case this
        behaves the same as torch.sigmoid.
        """
        super().__init__()
        self.x0 = x0
        # The autograd flag is ``requires_grad``; the former ``requiresGrad``
        # spelling only created an unused attribute, so k was always trained.
        self.k = nn.Parameter(torch.FloatTensor([float(k)]), requires_grad=train)

    def forward(self, x):
        """
        Applies the function to the input elementwise.

        ``torch.sigmoid`` replaces the hand-written ``1 / (1 + exp(-z))``: the
        manual form overflows ``exp`` for very negative ``z`` and then
        back-propagates NaN, while ``torch.sigmoid`` has a stable gradient. Its
        output already lies in [0, 1], so no extra clamp is needed.
        """
        return torch.sigmoid(self.k * (x - self.x0)).squeeze()

    def clip(self):
        """Clamp the slope k in place so it stays non-negative."""
        self.k.data.clamp_(min=0)


class ModelInteraction(nn.Module):
    """
    Score a pair of inputs with a single interaction probability.

    The pair is optionally embedded, turned into a contact map by ``contact``,
    optionally re-weighted toward the centre of the map (``use_W``),
    max-pooled, and summarised by a thresholded mean passed through a steep
    logistic activation.

    ``contact`` must provide ``cmap(e0, e1)``, ``predict(B)`` and ``clip()``;
    ``embedding`` is either ``None`` (inputs used as-is) or a callable.
    """

    def __init__(self, embedding, contact, use_cuda, pool_size=9, theta_init=1,
                 lambda_init=0, gamma_init=0, use_W=True):
        """
        INPUT:
            - embedding: callable applied to each input, or None for identity
            - contact: object exposing ``cmap``, ``predict`` and ``clip``
            - use_cuda: stored for backward compatibility; the weighting
              matrix is now created on the same device as the contact map
            - pool_size: max-pool kernel size (stride defaults to the kernel
              size); padding is ``pool_size // 2``
            - theta_init, lambda_init: initial weighting parameters (only
              created when ``use_W``)
            - gamma_init: initial threshold multiplier for the pooled map
            - use_W: whether to apply the centre-weighting matrix
        """
        super().__init__()
        self.use_cuda = use_cuda
        self.use_W = use_W
        # Sigmoid centred at 0.5 with a fixed (non-trainable) slope of 20.
        self.activation = LogisticActivation(x0=0.5, k=20)
        self.embedding = embedding
        self.contact = contact

        if self.use_W:
            self.theta = nn.Parameter(torch.FloatTensor([theta_init]))
            self.lambda_ = nn.Parameter(torch.FloatTensor([lambda_init]))

        self.maxPool = nn.MaxPool2d(pool_size, padding=pool_size // 2)
        self.gamma = nn.Parameter(torch.FloatTensor([gamma_init]))

        # Bring the initial values into their valid ranges.
        self.clip()

    def clip(self):
        """
        Clamp parameters in place: theta to [0, 1], lambda_ >= 0, gamma >= 0.
        Also calls ``contact.clip()``.
        """
        self.contact.clip()
        if self.use_W:
            self.theta.data.clamp_(min=0, max=1)
            self.lambda_.data.clamp_(min=0)
        self.gamma.data.clamp_(min=0)

    def embed(self, x):
        """Return ``self.embedding(x)``, or ``x`` unchanged when no embedding is set."""
        if self.embedding is None:
            return x
        return self.embedding(x)

    def cpred(self, z0, z1):
        """Embed both inputs and return the contact prediction ``C`` from ``contact``."""
        e0 = self.embed(z0)
        e1 = self.embed(z1)
        B = self.contact.cmap(e0, e1)
        C = self.contact.predict(B)
        return C

    @staticmethod
    def _position_weights(n, lambda_, device):
        """Return exp(-lambda_ * ((i - c) / c)**2) for i = 1..n with c = (n + 1) / 2."""
        center = (n + 1) / 2
        offsets = (np.arange(n) + 1 - center) / center
        # Computed in float64 by numpy and then cast, exactly as before.
        x = torch.from_numpy(-(offsets ** 2)).float().to(device)
        return torch.exp(lambda_ * x)

    def map_predict(self, z0, z1):
        """
        Return ``(C, phat)``: the raw contact map and the interaction probability.

        NOTE(review): ``C`` is assumed to be (batch, channels, N, M), since
        ``C.shape[2:]`` is unpacked as ``(N, M)`` — confirm against ``contact``.
        """
        C = self.cpred(z0, z1)

        if self.use_W:
            # Contact weighting matrix: centre-peaked outer product, blended
            # with a constant floor theta.
            N, M = C.shape[2:]
            # Build on C's device rather than trusting use_cuda, so the model
            # works wherever C lives.
            x1 = self._position_weights(N, self.lambda_, C.device)
            x2 = self._position_weights(M, self.lambda_, C.device)
            W = x1.unsqueeze(1) * x2
            W = (1 - self.theta) * W + self.theta
            yhat = C * W
        else:
            yhat = C

        yhat = self.maxPool(yhat)

        # Average excess of pooled predictions above mu + gamma*sigma.
        # NOTE: sigma is the (unbiased) variance of the pooled map, not the std.
        mu = torch.mean(yhat)
        if yhat.numel() > 1:
            sigma = torch.var(yhat)
        else:
            # The unbiased variance of a single cell is NaN, which would make
            # phat NaN; a single cell has no spread, so use 0.
            sigma = torch.zeros_like(mu)
        Q = torch.relu(yhat - mu - (self.gamma * sigma))
        # The +1 keeps the denominator non-zero when no cell exceeds the threshold.
        phat = torch.sum(Q) / (torch.sum(torch.sign(Q)) + 1)
        phat = self.activation(phat)
        return C, phat

    def predict(self, z0, z1):
        """Return only the interaction probability ``phat`` from ``map_predict``."""
        _, phat = self.map_predict(z0, z1)
        return phat