File size: 4,603 Bytes
8896a5f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
"""
Contact model classes.
"""
import torch
import torch.nn as nn
import torch.functional as F
class FullyConnected(nn.Module):
    """
    Performs part 1 of Contact Prediction Module. Takes embeddings from Projection module and produces broadcast tensor.

    Input embeddings of dimension :math:`d` are combined, for every pair of positions, into a :math:`2d` length
    input :math:`z_{cat}`, where :math:`z_{cat} = [|z_0 - z_1| \\; | \\; z_0 \\odot z_1]` (element-wise absolute
    difference concatenated with element-wise product). A :math:`1 \\times 1` convolution then maps each pair
    from :math:`2d` to :math:`h` channels.

    :param embed_dim: Dimension :math:`d` of the input embeddings (required, no default)
    :type embed_dim: int
    :param hidden_dim: Hidden dimension :math:`h` (required, no default)
    :type hidden_dim: int
    :param activation: Activation function for broadcast tensor. If ``None`` (the default), a new
        ``torch.nn.ReLU()`` is created for this instance.
    :type activation: torch.nn.Module or None
    """

    def __init__(self, embed_dim, hidden_dim, activation=None):
        super().__init__()

        self.D = embed_dim
        self.H = hidden_dim

        # Kernel size 1: the same linear map is applied independently to every (i, j) pair.
        self.conv = nn.Conv2d(2 * self.D, self.H, 1)
        self.batchnorm = nn.BatchNorm2d(self.H)

        # Build the default per instance; a module created in the signature would be
        # a single object shared by every FullyConnected ever constructed.
        self.activation = nn.ReLU() if activation is None else activation

    def forward(self, z0, z1):
        """
        :param z0: Projection module embedding :math:`(b \\times N \\times d)`
        :type z0: torch.Tensor
        :param z1: Projection module embedding :math:`(b \\times M \\times d)`
        :type z1: torch.Tensor
        :return: Predicted broadcast tensor, channels first: :math:`(b \\times h \\times N \\times M)`
        :rtype: torch.Tensor
        """
        # z0 is (b,N,d), z1 is (b,M,d)  ->  (b,d,N) and (b,d,M)
        z0 = z0.transpose(1, 2)
        z1 = z1.transpose(1, 2)

        # Add singleton axes so the pairwise operations broadcast to (b,d,N,M).
        z0_rows = z0.unsqueeze(3)  # (b,d,N,1)
        z1_cols = z1.unsqueeze(2)  # (b,d,1,M)

        z_dif = torch.abs(z0_rows - z1_cols)
        z_mul = z0_rows * z1_cols
        # Stack both pairwise features along the channel axis: (b,2d,N,M)
        z_cat = torch.cat([z_dif, z_mul], 1)

        # Order is conv -> activation -> batchnorm; kept as-is so existing weights stay valid.
        b = self.conv(z_cat)
        b = self.activation(b)
        b = self.batchnorm(b)

        return b
class ContactCNN(nn.Module):
    """
    Residue Contact Prediction Module. Takes embeddings from Projection module and produces contact map, output of Contact module.

    :param embed_dim: Output dimension of `dscript.models.embedding <#module-dscript.models.embedding>`_ model :math:`d` [default: 100]
    :type embed_dim: int
    :param hidden_dim: Hidden dimension :math:`h` [default: 50]
    :type hidden_dim: int
    :param width: Width of convolutional filter :math:`2w+1` [default: 7]. The padding is ``width // 2``,
        so the :math:`N \\times M` spatial size is preserved only for odd widths.
    :type width: int
    :param activation: Activation function for final contact map. If ``None`` (the default), a new
        ``torch.nn.Sigmoid()`` is created for this instance.
    :type activation: torch.nn.Module or None
    """

    def __init__(self, embed_dim=100, hidden_dim=50, width=7, activation=None):
        super().__init__()

        self.hidden = FullyConnected(embed_dim, hidden_dim)
        self.conv = nn.Conv2d(hidden_dim, 1, width, padding=width // 2)
        self.batchnorm = nn.BatchNorm2d(1)

        # Build the default per instance; a module created in the signature would be
        # a single object shared by every ContactCNN ever constructed.
        self.activation = nn.Sigmoid() if activation is None else activation

        # Start from a symmetric kernel.
        self.clip()

    def clip(self):
        """
        Force the convolutional layer to be transpose invariant.

        Replaces the convolution kernel, in place, with the average of itself and its
        spatial transpose.

        :meta private:
        """
        # Work on the detached ``.data`` tensor so the averaging is not recorded by autograd.
        # The right-hand side is fully evaluated into a new tensor before being copied back.
        kernel = self.conv.weight.data
        kernel.copy_(0.5 * (kernel + kernel.transpose(2, 3)))

    def forward(self, z0, z1):
        """
        :param z0: Projection module embedding :math:`(b \\times N \\times d)`
        :type z0: torch.Tensor
        :param z1: Projection module embedding :math:`(b \\times M \\times d)`
        :type z1: torch.Tensor
        :return: Predicted contact map :math:`(b \\times 1 \\times N \\times M)`
        :rtype: torch.Tensor
        """
        B = self.broadcast(z0, z1)
        return self.predict(B)

    def broadcast(self, z0, z1):
        """
        Calls `dscript.models.contact.FullyConnected <#module-dscript.models.contact.FullyConnected>`_.

        :param z0: Projection module embedding :math:`(b \\times N \\times d)`
        :type z0: torch.Tensor
        :param z1: Projection module embedding :math:`(b \\times M \\times d)`
        :type z1: torch.Tensor
        :return: Predicted contact broadcast tensor, channels first: :math:`(b \\times h \\times N \\times M)`
        :rtype: torch.Tensor
        """
        return self.hidden(z0, z1)

    def predict(self, B):
        """
        Predict contact map from broadcast tensor.

        :param B: Predicted contact broadcast, channels first: :math:`(b \\times h \\times N \\times M)`
        :type B: torch.Tensor
        :return: Predicted contact map :math:`(b \\times 1 \\times N \\times M)`
        :rtype: torch.Tensor
        """
        # Order is conv -> batchnorm -> activation.
        C = self.conv(B)
        C = self.batchnorm(C)
        C = self.activation(C)
        return C
|