import torch
from torch import nn
import torch.nn.functional as F

# some predefined parameters
elem_list = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'As', 'Al', 'I', 'B', 'V', 'K',
             'Tl', 'Yb', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In',
             'Mn', 'Zr', 'Cr', 'Pt', 'Hg', 'Pb', 'W', 'Ru', 'Nb', 'Re', 'Te', 'Rh', 'Tc', 'Ba', 'Bi', 'Hf', 'Mo', 'U',
             'Sm', 'Os', 'Ir', 'Ce', 'Gd', 'Ga', 'Cs', 'unknown']
atom_fdim = len(elem_list) + 6 + 6 + 6 + 1
bond_fdim = 6
max_nb = 6
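# atom_fdim = 63 + 6 + 6 + 6 + 1 = 82: an element one-hot over elem_list plus,
# presumably (the featurizer lives outside this file), three 6-way one-hots
# such as degree and explicit/implicit valence, and one aromaticity flag.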

class MONN(nn.Module):
    # init_A, init_B, init_W = loading_emb(measure)
    # net = Net(init_A, init_B, init_W, params)
    def __init__(self, init_atom_features, init_bond_features, init_word_features, params):
        super().__init__()

        self.init_atom_features = init_atom_features
        self.init_bond_features = init_bond_features
        self.init_word_features = init_word_features

        """hyper part"""
        GNN_depth, inner_CNN_depth, DMA_depth, k_head, kernel_size, hidden_size1, hidden_size2 = params
        self.GNN_depth = GNN_depth
        self.inner_CNN_depth = inner_CNN_depth
        self.DMA_depth = DMA_depth
        self.k_head = k_head
        self.kernel_size = kernel_size
        self.hidden_size1 = hidden_size1
        self.hidden_size2 = hidden_size2
"""GraphConv Module""" | |
self.vertex_embedding = nn.Linear(atom_fdim, | |
self.hidden_size1) # first transform vertex features into hidden representations | |
# GWM parameters | |
self.W_a_main = nn.ModuleList( | |
[nn.ModuleList([nn.Linear(self.hidden_size1, self.hidden_size1) for i in range(self.k_head)]) for i in | |
range(self.GNN_depth)]) | |
self.W_a_super = nn.ModuleList( | |
[nn.ModuleList([nn.Linear(self.hidden_size1, self.hidden_size1) for i in range(self.k_head)]) for i in | |
range(self.GNN_depth)]) | |
self.W_main = nn.ModuleList( | |
[nn.ModuleList([nn.Linear(self.hidden_size1, self.hidden_size1) for i in range(self.k_head)]) for i in | |
range(self.GNN_depth)]) | |
self.W_bmm = nn.ModuleList( | |
[nn.ModuleList([nn.Linear(self.hidden_size1, 1) for i in range(self.k_head)]) for i in | |
range(self.GNN_depth)]) | |
self.W_super = nn.ModuleList([nn.Linear(self.hidden_size1, self.hidden_size1) for i in range(self.GNN_depth)]) | |
self.W_main_to_super = nn.ModuleList( | |
[nn.Linear(self.hidden_size1 * self.k_head, self.hidden_size1) for i in range(self.GNN_depth)]) | |
self.W_super_to_main = nn.ModuleList( | |
[nn.Linear(self.hidden_size1, self.hidden_size1) for i in range(self.GNN_depth)]) | |
self.W_zm1 = nn.ModuleList([nn.Linear(self.hidden_size1, self.hidden_size1) for i in range(self.GNN_depth)]) | |
self.W_zm2 = nn.ModuleList([nn.Linear(self.hidden_size1, self.hidden_size1) for i in range(self.GNN_depth)]) | |
self.W_zs1 = nn.ModuleList([nn.Linear(self.hidden_size1, self.hidden_size1) for i in range(self.GNN_depth)]) | |
self.W_zs2 = nn.ModuleList([nn.Linear(self.hidden_size1, self.hidden_size1) for i in range(self.GNN_depth)]) | |
self.GRU_main = nn.GRUCell(self.hidden_size1, self.hidden_size1) | |
self.GRU_super = nn.GRUCell(self.hidden_size1, self.hidden_size1) | |
# WLN parameters | |
self.label_U2 = nn.ModuleList([nn.Linear(self.hidden_size1 + bond_fdim, self.hidden_size1) for i in | |
range(self.GNN_depth)]) # assume no edge feature transformation | |
self.label_U1 = nn.ModuleList( | |
[nn.Linear(self.hidden_size1 * 2, self.hidden_size1) for i in range(self.GNN_depth)]) | |
"""CNN-RNN Module""" | |
# CNN parameters | |
self.embed_seq = nn.Embedding(len(self.init_word_features), 20, padding_idx=0) | |
self.embed_seq.weight = nn.Parameter(self.init_word_features) | |
self.embed_seq.weight.requires_grad = False | |
self.conv_first = nn.Conv1d(20, self.hidden_size1, kernel_size=self.kernel_size, | |
padding=(self.kernel_size - 1) / 2) | |
self.conv_last = nn.Conv1d(self.hidden_size1, self.hidden_size1, kernel_size=self.kernel_size, | |
padding=(self.kernel_size - 1) / 2) | |
self.plain_CNN = nn.ModuleList([]) | |
for i in range(self.inner_CNN_depth): | |
self.plain_CNN.append(nn.Conv1d(self.hidden_size1, self.hidden_size1, kernel_size=self.kernel_size, | |
padding=(self.kernel_size - 1) / 2)) | |
"""Affinity Prediction Module""" | |
self.super_final = nn.Linear(self.hidden_size1, self.hidden_size2) | |
self.c_final = nn.Linear(self.hidden_size1, self.hidden_size2) | |
self.p_final = nn.Linear(self.hidden_size1, self.hidden_size2) | |
# DMA parameters | |
self.mc0 = nn.Linear(hidden_size2, hidden_size2) | |
self.mp0 = nn.Linear(hidden_size2, hidden_size2) | |
self.mc1 = nn.ModuleList([nn.Linear(self.hidden_size2, self.hidden_size2) for i in range(self.DMA_depth)]) | |
self.mp1 = nn.ModuleList([nn.Linear(self.hidden_size2, self.hidden_size2) for i in range(self.DMA_depth)]) | |
self.hc0 = nn.ModuleList([nn.Linear(self.hidden_size2, self.hidden_size2) for i in range(self.DMA_depth)]) | |
self.hp0 = nn.ModuleList([nn.Linear(self.hidden_size2, self.hidden_size2) for i in range(self.DMA_depth)]) | |
self.hc1 = nn.ModuleList([nn.Linear(self.hidden_size2, 1) for i in range(self.DMA_depth)]) | |
self.hp1 = nn.ModuleList([nn.Linear(self.hidden_size2, 1) for i in range(self.DMA_depth)]) | |
self.c_to_p_transform = nn.ModuleList( | |
[nn.Linear(self.hidden_size2, self.hidden_size2) for i in range(self.DMA_depth)]) | |
self.p_to_c_transform = nn.ModuleList( | |
[nn.Linear(self.hidden_size2, self.hidden_size2) for i in range(self.DMA_depth)]) | |
self.GRU_dma = nn.GRUCell(self.hidden_size2, self.hidden_size2) | |
# Output layer | |
self.W_out = nn.Linear(self.hidden_size2 * self.hidden_size2 * 2, 1) | |
"""Pairwise Interaction Prediction Module""" | |
self.pairwise_compound = nn.Linear(self.hidden_size1, self.hidden_size1) | |
self.pairwise_protein = nn.Linear(self.hidden_size1, self.hidden_size1) | |

    def mask_softmax(self, a, mask, dim=-1):
        # numerically stable softmax that assigns zero weight to masked-out
        # positions; 1e-6 avoids division by zero when a row is fully masked
        a_max = torch.max(a, dim, keepdim=True)[0]
        a_exp = torch.exp(a - a_max)
        a_exp = a_exp * mask
        a_softmax = a_exp / (torch.sum(a_exp, dim, keepdim=True) + 1e-6)
        return a_softmax
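
    # a minimal sketch of mask_softmax semantics (hypothetical values):
    #   a    = torch.tensor([[1.0, 2.0, 3.0]])
    #   mask = torch.tensor([[1.0, 1.0, 0.0]])
    #   mask_softmax(a, mask)  ->  approx. [[0.269, 0.731, 0.000]]
    # masked positions get zero weight and the remainder renormalizes over the
    # valid entries only.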

    def graph_conv_module(self, batch_size, vertex_mask, vertex, edge, atom_adj, bond_adj, nbs_mask):
        n_vertex = vertex_mask.size(1)

        # initial features
        vertex_initial = torch.index_select(self.init_atom_features, 0, vertex.view(-1))
        vertex_initial = vertex_initial.view(batch_size, -1, atom_fdim)
        edge_initial = torch.index_select(self.init_bond_features, 0, edge.view(-1))
        edge_initial = edge_initial.view(batch_size, -1, bond_fdim)

        vertex_feature = F.leaky_relu(self.vertex_embedding(vertex_initial), 0.1)
        super_feature = torch.sum(vertex_feature * vertex_mask.view(batch_size, -1, 1), dim=1, keepdim=True)

        for GWM_iter in range(self.GNN_depth):
            # prepare main node features
            for k in range(self.k_head):
                a_main = torch.tanh(self.W_a_main[GWM_iter][k](vertex_feature))
                a_super = torch.tanh(self.W_a_super[GWM_iter][k](super_feature))  # NOTE: currently unused
                a = self.W_bmm[GWM_iter][k](a_main * super_feature)
                attn = self.mask_softmax(a.view(batch_size, -1), vertex_mask).view(batch_size, -1, 1)
                k_main_to_super = torch.bmm(attn.transpose(1, 2), self.W_main[GWM_iter][k](vertex_feature))
                if k == 0:
                    m_main_to_super = k_main_to_super
                else:
                    m_main_to_super = torch.cat([m_main_to_super, k_main_to_super], dim=-1)  # concat k heads
            main_to_super = torch.tanh(self.W_main_to_super[GWM_iter](m_main_to_super))
            main_self = self.wln_unit(batch_size, vertex_mask, vertex_feature, edge_initial, atom_adj, bond_adj,
                                      nbs_mask, GWM_iter)
            super_to_main = torch.tanh(self.W_super_to_main[GWM_iter](super_feature))
            super_self = torch.tanh(self.W_super[GWM_iter](super_feature))

            # warp gate and GRU to update the main node features, using main_self and super_to_main
            z_main = torch.sigmoid(self.W_zm1[GWM_iter](main_self) + self.W_zm2[GWM_iter](super_to_main))
            hidden_main = (1 - z_main) * main_self + z_main * super_to_main
            vertex_feature = self.GRU_main(hidden_main.view(-1, self.hidden_size1),
                                           vertex_feature.view(-1, self.hidden_size1))
            vertex_feature = vertex_feature.view(batch_size, n_vertex, self.hidden_size1)

            # warp gate and GRU to update the super node features
            z_super = torch.sigmoid(self.W_zs1[GWM_iter](super_self) + self.W_zs2[GWM_iter](main_to_super))
            hidden_super = (1 - z_super) * super_self + z_super * main_to_super
            super_feature = self.GRU_super(hidden_super.view(batch_size, self.hidden_size1),
                                           super_feature.view(batch_size, self.hidden_size1))
            super_feature = super_feature.view(batch_size, 1, self.hidden_size1)

        return vertex_feature, super_feature
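
    # graph_conv_module sketch (B = batch, N = max atoms): vertex indices are
    # embedded to (B, N, hidden_size1); the super node starts as their masked
    # sum, shape (B, 1, hidden_size1); each GWM iteration exchanges messages
    # between atoms and the super node through k_head attention heads plus a
    # WLN step, with warp gates and GRU cells controlling both updates.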

    def wln_unit(self, batch_size, vertex_mask, vertex_features, edge_initial, atom_adj, bond_adj, nbs_mask, GNN_iter):
        n_vertex = vertex_mask.size(1)
        n_nbs = nbs_mask.size(2)
        vertex_mask = vertex_mask.view(batch_size, n_vertex, 1)
        nbs_mask = nbs_mask.view(batch_size, n_vertex, n_nbs, 1)

        # gather neighbor atom/bond features via the flat adjacency indices
        vertex_nei = torch.index_select(vertex_features.view(-1, self.hidden_size1), 0,
                                        atom_adj).view(batch_size, n_vertex, n_nbs, self.hidden_size1)
        edge_nei = torch.index_select(edge_initial.view(-1, bond_fdim), 0,
                                      bond_adj).view(batch_size, n_vertex, n_nbs, bond_fdim)

        # Weisfeiler-Lehman relabelling
        l_nei = torch.cat((vertex_nei, edge_nei), -1)
        nei_label = F.leaky_relu(self.label_U2[GNN_iter](l_nei), 0.1)
        nei_label = torch.sum(nei_label * nbs_mask, dim=-2)
        new_label = torch.cat((vertex_features, nei_label), 2)
        new_label = self.label_U1[GNN_iter](new_label)
        vertex_features = F.leaky_relu(new_label, 0.1)
        return vertex_features
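
    # wln_unit shape sketch: atom_adj and bond_adj are flat 1-D index tensors
    # of length B * N * max_nb pointing into the batch-flattened atom and bond
    # lists, so a single index_select gathers every neighborhood at once;
    # masked summation over the neighbor axis then yields the updated
    # (B, N, hidden_size1) atom labels.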

    def cnn_module(self, sequence):
        ebd = self.embed_seq(sequence)
        ebd = ebd.transpose(1, 2)
        x = F.leaky_relu(self.conv_first(ebd), 0.1)
        for i in range(self.inner_CNN_depth):
            x = self.plain_CNN[i](x)
            x = F.leaky_relu(x, 0.1)
        x = F.leaky_relu(self.conv_last(x), 0.1)
        H = x.transpose(1, 2)
        # H, hidden = self.rnn(H)
        return H
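
    # cnn_module maps token ids (B, L) through the frozen 20-d embeddings and a
    # stack of same-length 1-D convolutions to per-residue features of shape
    # (B, L, hidden_size1); the (kernel_size - 1) // 2 padding keeps L fixed
    # for odd kernel sizes.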

    def pairwise_pred_module(self, batch_size, comp_feature, prot_feature, vertex_mask, seq_mask):
        pairwise_c_feature = F.leaky_relu(self.pairwise_compound(comp_feature), 0.1)
        pairwise_p_feature = F.leaky_relu(self.pairwise_protein(prot_feature), 0.1)
        # sigmoid maps the raw scores to (0, 1) interaction probabilities,
        # following the MONN reference formulation
        pairwise_pred = torch.sigmoid(torch.matmul(pairwise_c_feature, pairwise_p_feature.transpose(1, 2)))
        # TODO: difference between the pairwise_mask here and in the data?
        pairwise_mask = torch.matmul(vertex_mask.view(batch_size, -1, 1), seq_mask.view(batch_size, 1, -1))
        pairwise_pred = pairwise_pred * pairwise_mask
        return pairwise_pred
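
    # pairwise_pred has shape (B, n_atoms, n_residues): one interaction score
    # per atom/residue pair, zeroed outside the valid atom x residue region by
    # the outer product of the two masks.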

    def affinity_pred_module(self, batch_size, comp_feature, prot_feature, super_feature, vertex_mask, seq_mask,
                             pairwise_pred):
        comp_feature = F.leaky_relu(self.c_final(comp_feature), 0.1)
        prot_feature = F.leaky_relu(self.p_final(prot_feature), 0.1)
        super_feature = F.leaky_relu(self.super_final(super_feature.view(batch_size, -1)), 0.1)

        cf, pf = self.dma_gru(batch_size, comp_feature, vertex_mask, prot_feature, seq_mask, pairwise_pred)
        cf = torch.cat([cf.view(batch_size, -1), super_feature.view(batch_size, -1)], dim=1)

        # flattened outer (Kronecker) product of the compound and protein summaries
        kroneck = F.leaky_relu(
            torch.matmul(cf.view(batch_size, -1, 1), pf.view(batch_size, 1, -1)).view(batch_size, -1), 0.1)
        affinity_pred = self.W_out(kroneck)
        return affinity_pred
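
    # dimension check: cf is (B, 2 * hidden_size2) after concatenating the DMA
    # compound summary with the super node, pf is (B, hidden_size2), so kroneck
    # is (B, 2 * hidden_size2 * hidden_size2) -- exactly W_out's input size.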

    def dma_gru(self, batch_size, comp_feats, vertex_mask, prot_feats, seq_mask, pairwise_pred):
        vertex_mask = vertex_mask.view(batch_size, -1, 1)
        seq_mask = seq_mask.view(batch_size, -1, 1)
        cf = torch.Tensor()
        pf = torch.Tensor()

        # masked mean pooling initializes the joint memory m
        c0 = torch.sum(comp_feats * vertex_mask, dim=1) / torch.sum(vertex_mask, dim=1)
        p0 = torch.sum(prot_feats * seq_mask, dim=1) / torch.sum(seq_mask, dim=1)
        m = c0 * p0

        for DMA_iter in range(self.DMA_depth):
            c_to_p = torch.matmul(pairwise_pred.transpose(1, 2),
                                  torch.tanh(self.c_to_p_transform[DMA_iter](comp_feats)))  # batch * n_residue * hidden
            p_to_c = torch.matmul(pairwise_pred,
                                  torch.tanh(self.p_to_c_transform[DMA_iter](prot_feats)))  # batch * n_vertex * hidden
            c_tmp = (torch.tanh(self.hc0[DMA_iter](comp_feats))
                     * torch.tanh(self.mc1[DMA_iter](m)).view(batch_size, 1, -1) * p_to_c)
            p_tmp = (torch.tanh(self.hp0[DMA_iter](prot_feats))
                     * torch.tanh(self.mp1[DMA_iter](m)).view(batch_size, 1, -1) * c_to_p)
            c_att = self.mask_softmax(self.hc1[DMA_iter](c_tmp).view(batch_size, -1), vertex_mask.view(batch_size, -1))
            p_att = self.mask_softmax(self.hp1[DMA_iter](p_tmp).view(batch_size, -1), seq_mask.view(batch_size, -1))
            cf = torch.sum(comp_feats * c_att.view(batch_size, -1, 1), dim=1)
            pf = torch.sum(prot_feats * p_att.view(batch_size, -1, 1), dim=1)
            m = self.GRU_dma(m, cf * pf)
        return cf, pf
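
    # dma_gru runs DMA_depth rounds of dual attention: pairwise_pred routes
    # compound features to residues (c_to_p) and residue features to atoms
    # (p_to_c), attention weights are conditioned on the shared memory m, and
    # GRU_dma updates m from the product of the pooled summaries cf * pf.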

    def forward(self, enc_drug, enc_protein):
        # each drug component is assumed to arrive as a (tensor, auxiliary)
        # pair; the initial vertex_mask is replaced by the mask bundled with
        # vertex, and only the tensors are used below
        vertex_mask, vertex, edge, atom_adj, bond_adj, nbs_mask = enc_drug
        vertex, vertex_mask = vertex
        edge, _ = edge
        atom_adj, _ = atom_adj
        bond_adj, _ = bond_adj
        nbs_mask, _ = nbs_mask
        seq_mask, sequence = enc_protein

        batch_size = vertex.size(0)

        atom_feature, super_feature = self.graph_conv_module(batch_size, vertex_mask, vertex, edge, atom_adj, bond_adj,
                                                             nbs_mask)
        prot_feature = self.cnn_module(sequence)
        pairwise_pred = self.pairwise_pred_module(batch_size, atom_feature, prot_feature, vertex_mask, seq_mask)
        affinity_pred = self.affinity_pred_module(batch_size, atom_feature, prot_feature, super_feature, vertex_mask,
                                                  seq_mask, pairwise_pred)
        return affinity_pred  # , pairwise_pred
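
if __name__ == "__main__":
    # A minimal smoke test with random tensors. Every size, vocabulary, and the
    # (tensor, auxiliary) packing below is a hypothetical stand-in for the real
    # featurization pipeline; it only exercises forward() end to end.
    torch.manual_seed(0)
    B, N, E, L, NB = 2, 9, 10, 30, max_nb          # batch, atoms, bonds, residues, neighbors
    n_atom_rows, n_bond_rows, vocab = 40, 12, 26   # lookup-table sizes (assumed)

    # params = (GNN_depth, inner_CNN_depth, DMA_depth, k_head, kernel_size, hidden_size1, hidden_size2)
    net = MONN(torch.randn(n_atom_rows, atom_fdim),
               torch.randn(n_bond_rows, bond_fdim),
               torch.randn(vocab, 20),
               (2, 2, 2, 1, 7, 32, 32))

    vertex = torch.randint(0, n_atom_rows, (B, N))
    vertex_mask = torch.ones(B, N)
    edge = torch.randint(0, n_bond_rows, (B, E))
    atom_adj = torch.randint(0, B * N, (B * N * NB,))  # flat indices into the batch-flattened atom list
    bond_adj = torch.randint(0, B * E, (B * N * NB,))  # flat indices into the batch-flattened bond list
    nbs_mask = torch.ones(B, N, NB)
    sequence = torch.randint(1, vocab, (B, L))
    seq_mask = torch.ones(B, L)

    enc_drug = (vertex_mask, (vertex, vertex_mask), (edge, None),
                (atom_adj, None), (bond_adj, None), (nbs_mask, None))
    enc_protein = (seq_mask, sequence)
    print(net(enc_drug, enc_protein).shape)  # expected: torch.Size([2, 1])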