|
|
|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.functional as F |
|
|
|
|
|
class L1(nn.Module):
    """Pairwise L1 (Manhattan) distance between two collections of vectors."""

    def forward(self, z, zpr):
        """Sum absolute coordinate differences over the last dimension.

        ``z`` receives a new axis at position 1 and is then broadcast against
        ``zpr``; e.g. ``z`` of shape (n, d) and ``zpr`` of shape (m, d) give a
        result of shape (n, m).
        """
        pairwise_gap = z.unsqueeze(1) - zpr
        return pairwise_gap.abs().sum(-1)
|
|
|
class L2(nn.Module):
    """Pairwise L2 (Euclidean) distance between two collections of vectors."""

    def forward(self, z, zpr):
        """Return the Euclidean norm of the differences over the last dimension.

        ``z`` receives a new axis at position 1 and is then broadcast against
        ``zpr``; e.g. ``z`` of shape (n, d) and ``zpr`` of shape (m, d) give a
        result of shape (n, m).
        """
        diff = z.unsqueeze(1) - zpr
        # A hand-written sqrt(sum(diff ** 2)) back-propagates NaN whenever two
        # vectors coincide (inf from sqrt'(0) times 0 from the square).  The
        # built-in norm returns the same values with a zero subgradient there.
        return torch.linalg.vector_norm(diff, ord=2, dim=-1)
|
|
|
class FullyConnected(nn.Module):
    """Project every pair of positions from two embeddings to a hidden vector.

    For each pair (i, j) the absolute difference and the element-wise product
    of ``z0[:, i]`` and ``z1[:, j]`` are concatenated (``2 * embed_dim``
    features) and passed through a 1x1 convolution, the activation and batch
    normalisation, in that order.

    :param embed_dim: size of the last dimension of both inputs.
    :param hidden_dim: number of output channels.
    :param activation: module applied after the convolution. ``None`` (the
        default) creates a fresh ``nn.ReLU()`` for this instance.
    """

    def __init__(self, embed_dim, hidden_dim, activation=None):
        super(FullyConnected, self).__init__()

        self.D = embed_dim
        self.H = hidden_dim

        # Kernel size 1: a per-pair linear map over the concatenated features.
        self.conv = nn.Conv2d(2 * self.D, self.H, 1)
        torch.nn.init.normal_(self.conv.weight)
        # uniform_ over [0, 0] sets the bias to zero; the call is kept as-is so
        # the sequence of init calls matches earlier versions of this class.
        torch.nn.init.uniform_(self.conv.bias, 0, 0)

        self.batchnorm = nn.BatchNorm2d(self.H)
        # Build the default here rather than in the signature so that models
        # do not all share one module instance created at definition time.
        self.activation = nn.ReLU() if activation is None else activation

    def forward(self, z0, z1):
        """Return the pairwise hidden representation.

        :param z0: tensor of shape (b, N, embed_dim).
        :param z1: tensor of shape (b, M, embed_dim).
        :return: tensor of shape (b, hidden_dim, N, M).
        """
        # Move the feature axis to position 1, where Conv2d expects channels.
        z0 = z0.transpose(1, 2)
        z1 = z1.transpose(1, 2)

        # (b, D, N, 1) against (b, D, 1, M) broadcasts to (b, D, N, M).
        z_dif = torch.abs(z0.unsqueeze(3) - z1.unsqueeze(2))
        z_mul = z0.unsqueeze(3) * z1.unsqueeze(2)
        z_cat = torch.cat([z_dif, z_mul], 1)

        c = self.conv(z_cat)
        c = self.activation(c)
        c = self.batchnorm(c)

        return c
|
|
|
|
|
class ContactCNN(nn.Module):
    """Predict a pairwise (contact) map from two embedded sequences.

    The inputs are first turned into a pairwise hidden representation by
    ``FullyConnected``; a ``width`` x ``width`` convolution, batch
    normalisation and the activation then produce the output map.

    :param embed_dim: size of the last dimension of both inputs.
    :param hidden_dim: number of channels in the pairwise representation.
    :param width: side length of the square convolution kernel. Padding is
        ``width // 2``, so the spatial size is preserved for odd ``width``
        (an even ``width`` makes each spatial dimension one larger).
    :param output_dim: number of output channels.
    :param activation: module applied last. ``None`` (the default) creates a
        fresh ``nn.Sigmoid()`` for this instance.
    """

    def __init__(self, embed_dim, hidden_dim=50, width=7, output_dim=1, activation=None):
        super().__init__()
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        self.hidden = FullyConnected(self.embed_dim, self.hidden_dim)

        self.conv = nn.Conv2d(self.hidden_dim, self.output_dim, width, padding=width // 2)
        torch.nn.init.normal_(self.conv.weight)
        # uniform_ over [0, 0] sets the bias to zero; the call is kept as-is so
        # the sequence of init calls matches earlier versions of this class.
        torch.nn.init.uniform_(self.conv.bias, 0, 0)

        self.batchnorm = nn.BatchNorm2d(self.output_dim)
        # Build the default here rather than in the signature so that models
        # do not all share one module instance created at definition time.
        self.activation = nn.Sigmoid() if activation is None else activation
        self.clip()

    def clip(self):
        """Make the convolution kernel symmetric in its two spatial dimensions.

        Each kernel is replaced in place by the average of itself and its
        spatial transpose. Calling this again on an already symmetric kernel
        leaves it unchanged.
        """
        # Update the parameter without recording the operation in autograd
        # (and without going through ``.data``).
        with torch.no_grad():
            w = self.conv.weight
            w.copy_(0.5 * (w + w.transpose(2, 3)))

    def forward(self, z0, z1):
        """Run ``broadcast`` followed by ``predict``.

        :param z0: tensor of shape (b, N, embed_dim).
        :param z1: tensor of shape (b, M, embed_dim).
        :return: tensor with ``output_dim`` channels; shape
            (b, output_dim, N, M) when ``width`` is odd.
        """
        B = self.broadcast(z0, z1)
        C = self.predict(B)
        return C

    def broadcast(self, z0, z1):
        """Return the pairwise hidden representation of ``z0`` and ``z1``."""
        B = self.hidden(z0, z1)
        return B

    def predict(self, B):
        """Map a pairwise representation to the output map.

        :param B: tensor of shape (b, hidden_dim, N, M).
        :return: convolution, batch norm and activation applied in that order.
        """
        C = self.conv(B)
        C = self.batchnorm(C)
        C = self.activation(C)
        return C
|
|
|
class ContactCNN_v2(nn.Module):
    """Predict a pairwise (contact) map from two embedded sequences.

    The inputs are first turned into a pairwise hidden representation by
    ``FullyConnected``; a ``width`` x ``width`` convolution, batch
    normalisation and the activation then produce the output map.

    NOTE(review): the implementation currently matches ``ContactCNN`` line for
    line; confirm whether the two classes are meant to diverge.

    :param embed_dim: size of the last dimension of both inputs.
    :param hidden_dim: number of channels in the pairwise representation.
    :param width: side length of the square convolution kernel. Padding is
        ``width // 2``, so the spatial size is preserved for odd ``width``
        (an even ``width`` makes each spatial dimension one larger).
    :param output_dim: number of output channels.
    :param activation: module applied last. ``None`` (the default) creates a
        fresh ``nn.Sigmoid()`` for this instance.
    """

    def __init__(self, embed_dim, hidden_dim=50, width=7, output_dim=1, activation=None):
        super().__init__()
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        self.hidden = FullyConnected(self.embed_dim, self.hidden_dim)

        self.conv = nn.Conv2d(self.hidden_dim, self.output_dim, width, padding=width // 2)
        torch.nn.init.normal_(self.conv.weight)
        # uniform_ over [0, 0] sets the bias to zero; the call is kept as-is so
        # the sequence of init calls matches earlier versions of this class.
        torch.nn.init.uniform_(self.conv.bias, 0, 0)

        self.batchnorm = nn.BatchNorm2d(self.output_dim)
        # Build the default here rather than in the signature so that models
        # do not all share one module instance created at definition time.
        self.activation = nn.Sigmoid() if activation is None else activation
        self.clip()

    def clip(self):
        """Make the convolution kernel symmetric in its two spatial dimensions.

        Each kernel is replaced in place by the average of itself and its
        spatial transpose. Calling this again on an already symmetric kernel
        leaves it unchanged.
        """
        # Update the parameter without recording the operation in autograd
        # (and without going through ``.data``).
        with torch.no_grad():
            w = self.conv.weight
            w.copy_(0.5 * (w + w.transpose(2, 3)))

    def forward(self, z0, z1):
        """Run ``broadcast`` followed by ``predict``.

        :param z0: tensor of shape (b, N, embed_dim).
        :param z1: tensor of shape (b, M, embed_dim).
        :return: tensor with ``output_dim`` channels; shape
            (b, output_dim, N, M) when ``width`` is odd.
        """
        B = self.broadcast(z0, z1)
        C = self.predict(B)
        return C

    def broadcast(self, z0, z1):
        """Return the pairwise hidden representation of ``z0`` and ``z1``."""
        B = self.hidden(z0, z1)
        return B

    def predict(self, B):
        """Map a pairwise representation to the output map.

        :param B: tensor of shape (b, hidden_dim, N, M).
        :return: convolution, batch norm and activation applied in that order.
        """
        C = self.conv(B)
        C = self.batchnorm(C)
        C = self.activation(C)
        return C