File size: 3,722 Bytes
8896a5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# Input: C = N x M x H pairwise embedding ("contact") tensor
# Output: S = N x M contact prediction matrix

import torch
import torch.nn as nn
import torch.functional as F

# Choices for f(Z,Z')
class L1(nn.Module):
    """Pairwise L1 distance score f(Z, Z'); produces a single channel (H = 1).

    Assumes ``z`` is (N, D) and ``zpr`` is (M, D) -- TODO confirm with callers;
    the result is then the (N, M) matrix of sum_k |z[i, k] - zpr[j, k]|.
    """

    def forward(self, z, zpr):
        # Under the assumed shapes, (N, 1, D) - (M, D) broadcasts to (N, M, D).
        gap = z.unsqueeze(1) - zpr
        return gap.abs().sum(dim=-1)

class L2(nn.Module):
    """Pairwise Euclidean distance score f(Z, Z'); produces a single channel (H = 1).

    Assumes ``z`` is (N, D) and ``zpr`` is (M, D) -- TODO confirm with callers;
    the result is then the (N, M) matrix of ||z[i] - zpr[j]||_2.
    """

    def forward(self, z, zpr):
        gap = z.unsqueeze(1) - zpr
        # NOTE(review): sqrt has an unbounded derivative at 0, so pairs with
        # identical embeddings can produce NaN gradients in backward.
        return gap.pow(2).sum(dim=-1).sqrt()

class FullyConnected(nn.Module):
    """Pairwise feature map between two embedding sequences (H = hidden_dim).

    For every position pair (i, j) this builds

        h_ij = [ |z0_i - z1_j| ; z0_i * z1_j ]      (2 * embed_dim channels)

    and computes c_ij = BatchNorm(activation(W h_ij + b)), where W, b are a
    1x1 convolution, i.e. one linear map shared across all (i, j).

    Both the absolute difference and the elementwise product are symmetric in
    (z0_i, z1_j), so swapping z0 and z1 transposes the output map.

    Args:
        embed_dim: feature size D (last axis) of z0 and z1.
        hidden_dim: number of output channels H.
        activation: module applied after the convolution. When None (the
            default) a fresh ``nn.ReLU()`` is created for this instance.
    """

    def __init__(self, embed_dim, hidden_dim, activation=None):
        super().__init__()

        self.D = embed_dim
        self.H = hidden_dim

        # A 1x1 conv applies the same linear layer independently at every (i, j).
        self.conv = nn.Conv2d(2 * self.D, self.H, 1)
        torch.nn.init.normal_(self.conv.weight)
        # Zero-initialises the bias. NOTE(review): equivalent in effect to
        # zeros_; kept as-is because uniform_ may draw from the RNG, so swapping
        # it could change seeded initialisation of later layers -- confirm first.
        torch.nn.init.uniform_(self.conv.bias, 0, 0)

        self.batchnorm = nn.BatchNorm2d(self.H)
        # Create a fresh module per instance rather than sharing one default object.
        self.activation = nn.ReLU() if activation is None else activation

    def forward(self, z0, z1):
        """Build the pairwise hidden map.

        Args:
            z0: tensor of shape (B, N, D).
            z1: tensor of shape (B, M, D).

        Returns:
            Tensor of shape (B, H, N, M).
        """
        # Move features to the channel axis: (B, D, N) and (B, D, M).
        z0 = z0.transpose(1, 2)
        z1 = z1.transpose(1, 2)

        # Broadcast (B, D, N, 1) against (B, D, 1, M) -> (B, D, N, M).
        a = z0.unsqueeze(3)
        b = z1.unsqueeze(2)
        z_dif = torch.abs(a - b)
        z_mul = a * b
        z_cat = torch.cat([z_dif, z_mul], 1)  # (B, 2D, N, M)

        c = self.conv(z_cat)
        c = self.activation(c)
        c = self.batchnorm(c)

        return c
    
# Contact Prediction Model
class ContactCNN(nn.Module):
    """Contact-prediction head: two embedding sequences -> contact map.

    ``broadcast`` turns z0 / z1 into a pairwise hidden map with
    ``FullyConnected``. ``predict`` then applies one ``width x width``
    convolution, BatchNorm2d and the output activation.

    Args:
        embed_dim: feature size D of the input embeddings.
        hidden_dim: channels H of the intermediate pairwise map (default 50).
        width: kernel size of the prediction convolution (default 7). With
            ``padding=width // 2`` the spatial size is preserved only for odd
            widths; an even width grows each spatial dimension by one.
        output_dim: number of output channels (default 1).
        activation: output nonlinearity. When None (the default) a fresh
            ``nn.Sigmoid()`` is created for this instance.
    """

    def __init__(self, embed_dim, hidden_dim=50, width=7, output_dim=1, activation=None):
        super().__init__()
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        self.hidden = FullyConnected(self.embed_dim, self.hidden_dim)

        self.conv = nn.Conv2d(self.hidden_dim, self.output_dim, width, padding=width // 2)
        torch.nn.init.normal_(self.conv.weight)
        # Zero-initialises the bias. NOTE(review): kept as uniform_(0, 0) rather
        # than zeros_ because uniform_ may draw from the RNG -- swapping could
        # change seeded initialisation; confirm before changing.
        torch.nn.init.uniform_(self.conv.bias, 0, 0)

        self.batchnorm = nn.BatchNorm2d(self.output_dim)
        # Create a fresh module per instance rather than sharing one default object.
        self.activation = nn.Sigmoid() if activation is None else activation
        self.clip()

    def clip(self):
        """Force the prediction kernel to be transpose invariant.

        Replaces each spatial kernel K with (K + K^T) / 2, in place. With a
        symmetric kernel and symmetric padding, convolving a transposed map
        gives the transposed result. The update goes through ``.data``, so
        autograd records no graph for it.
        """
        w = self.conv.weight.data
        w.copy_(0.5 * (w + w.transpose(2, 3)))

    def forward(self, z0, z1):
        """Predict contacts for z0 (B, N, D) against z1 (B, M, D).

        Returns a tensor of shape (B, output_dim, N, M) when ``width`` is odd.
        """
        B = self.broadcast(z0, z1)
        C = self.predict(B)
        return C

    def broadcast(self, z0, z1):
        """Build the (B, hidden_dim, N, M) pairwise map via ``self.hidden``."""
        B = self.hidden(z0, z1)
        return B

    def predict(self, B):
        """Apply conv -> BatchNorm2d -> activation to a pairwise map."""
        C = self.conv(B)
        C = self.batchnorm(C)
        C = self.activation(C)
        return C
    
class ContactCNN_v2(nn.Module):
    """Contact-prediction head: two embedding sequences -> contact map.

    NOTE(review): this class is currently a verbatim copy of ContactCNN; if the
    two are meant to stay identical, consider subclassing instead.

    ``broadcast`` turns z0 / z1 into a pairwise hidden map with
    ``FullyConnected``. ``predict`` then applies one ``width x width``
    convolution, BatchNorm2d and the output activation.

    Args:
        embed_dim: feature size D of the input embeddings.
        hidden_dim: channels H of the intermediate pairwise map (default 50).
        width: kernel size of the prediction convolution (default 7). With
            ``padding=width // 2`` the spatial size is preserved only for odd
            widths; an even width grows each spatial dimension by one.
        output_dim: number of output channels (default 1).
        activation: output nonlinearity. When None (the default) a fresh
            ``nn.Sigmoid()`` is created for this instance.
    """

    def __init__(self, embed_dim, hidden_dim=50, width=7, output_dim=1, activation=None):
        super().__init__()
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        self.hidden = FullyConnected(self.embed_dim, self.hidden_dim)

        self.conv = nn.Conv2d(self.hidden_dim, self.output_dim, width, padding=width // 2)
        torch.nn.init.normal_(self.conv.weight)
        # Zero-initialises the bias. NOTE(review): kept as uniform_(0, 0) rather
        # than zeros_ because uniform_ may draw from the RNG -- swapping could
        # change seeded initialisation; confirm before changing.
        torch.nn.init.uniform_(self.conv.bias, 0, 0)

        self.batchnorm = nn.BatchNorm2d(self.output_dim)
        # Create a fresh module per instance rather than sharing one default object.
        self.activation = nn.Sigmoid() if activation is None else activation
        self.clip()

    def clip(self):
        """Force the prediction kernel to be transpose invariant.

        Replaces each spatial kernel K with (K + K^T) / 2, in place. With a
        symmetric kernel and symmetric padding, convolving a transposed map
        gives the transposed result. The update goes through ``.data``, so
        autograd records no graph for it.
        """
        w = self.conv.weight.data
        w.copy_(0.5 * (w + w.transpose(2, 3)))

    def forward(self, z0, z1):
        """Predict contacts for z0 (B, N, D) against z1 (B, M, D).

        Returns a tensor of shape (B, output_dim, N, M) when ``width`` is odd.
        """
        B = self.broadcast(z0, z1)
        C = self.predict(B)
        return C

    def broadcast(self, z0, z1):
        """Build the (B, hidden_dim, N, M) pairwise map via ``self.hidden``."""
        B = self.hidden(z0, z1)
        return B

    def predict(self, B):
        """Apply conv -> BatchNorm2d -> activation to a pairwise map."""
        C = self.conv(B)
        C = self.batchnorm(C)
        C = self.activation(C)
        return C