libokj's picture
Upload 110 files
c0ec7e6
raw
history blame
12.5 kB
import torch.nn as nn
import torch.nn.functional as F
import torch
import math
from dgllife.model.gnn import GCN
from torch.nn.utils.weight_norm import weight_norm
class DrugBAN(nn.Module):
    """Drug-target model: molecular GCN + protein CNN fused by a BAN layer, then an MLP decoder.

    Args:
        drug_in_feats: Size of the per-node input features of the drug graph.
        drug_embedding: Width of the initial linear projection inside ``MolecularGCN``.
        drug_hidden_feats: Per-layer GCN output sizes; the last entry is the drug feature
            size fed to the BAN layer.
        protein_emb_dim: Token embedding size used by ``ProteinCNN``.
        num_filters: Output channels of the protein convolutions; the last entry is the
            protein feature size fed to the BAN layer.
        kernel_size: Kernel sizes of the protein convolutions.
        mlp_in_dim: Fused feature size produced by the BAN layer and consumed by the MLP.
        mlp_hidden_dim: Hidden width of the MLP decoder.
        mlp_out_dim: Output size of the MLP decoder.
        drug_padding: Forwarded to ``MolecularGCN`` as ``padding``.
        protein_padding: Forwarded to ``ProteinCNN`` as ``padding``.
        ban_heads: Number of bilinear attention glimpses (``h_out`` of ``BANLayer``).
    """

    def __init__(
            self,
            drug_in_feats,
            drug_embedding,
            drug_hidden_feats,
            protein_emb_dim,
            num_filters,
            kernel_size,
            mlp_in_dim,
            mlp_hidden_dim,
            mlp_out_dim,
            drug_padding,
            protein_padding,
            ban_heads,
    ):
        super().__init__()
        self.drug_extractor = MolecularGCN(in_feats=drug_in_feats, dim_embedding=drug_embedding,
                                           padding=drug_padding,
                                           hidden_feats=drug_hidden_feats)
        self.protein_extractor = ProteinCNN(protein_emb_dim, num_filters, kernel_size, protein_padding)
        bcn = BANLayer(v_dim=drug_hidden_feats[-1], q_dim=num_filters[-1], h_dim=mlp_in_dim, h_out=ban_heads)
        # BANLayer only owns an ``h_mat`` parameter when ban_heads <= its ``c`` threshold; otherwise it
        # uses ``h_net``, which BANLayer already wraps in weight_norm. Calling weight_norm on a
        # missing ``h_mat`` raises AttributeError, so only normalise it when it exists.
        if hasattr(bcn, 'h_mat'):
            bcn = weight_norm(bcn, name='h_mat', dim=None)
        self.bcn = bcn
        self.mlp_classifier = MLPDecoder(mlp_in_dim, mlp_hidden_dim, mlp_out_dim)

    def forward(self, bg_d, v_p):
        """Score a batch of drugs against a batch of proteins.

        Args:
            bg_d: Batched drug graph whose node features are stored under ``ndata['h']``
                (consumed by ``MolecularGCN``).
            v_p: Protein token indices (cast to ``long`` inside ``ProteinCNN``).

        Returns:
            The MLP decoder output for the fused drug/protein representation.
        """
        v_d = self.drug_extractor(bg_d)
        v_p = self.protein_extractor(v_p)
        f, _att = self.bcn(v_d, v_p)  # attention maps are computed but not returned
        return self.mlp_classifier(f)
class MolecularGCN(nn.Module):
    """Encode a batch of molecular graphs into per-graph node-feature tensors.

    Node features are projected to ``dim_embedding`` by a bias-free linear layer and
    then passed through a dgllife ``GCN``.

    Args:
        in_feats: Size of the input node features stored under ``ndata['h']``.
        dim_embedding: Width of the initial linear projection.
        padding: If true, zero the last row of the projection weight at construction.
        hidden_feats: Per-layer GCN output sizes. ``None`` means dgllife's GCN default
            of two 64-unit layers.
        activation: Forwarded unchanged to the dgllife ``GCN``.
    """

    def __init__(self, in_feats, dim_embedding=128, padding=True, hidden_feats=None, activation=None):
        super().__init__()
        if hidden_feats is None:
            # Resolve dgllife's GCN default here so ``output_feats`` below is defined;
            # ``hidden_feats[-1]`` previously raised TypeError for the default argument.
            hidden_feats = [64, 64]
        self.init_transform = nn.Linear(in_feats, dim_embedding, bias=False)
        if padding:
            # nn.Linear.weight is (out_features, in_features), so this zeroes the last *output*
            # unit's weights. NOTE(review): if the intent was to ignore the trailing virtual-node
            # input bit, that would be weight[:, -1] — left unchanged to preserve training behaviour.
            with torch.no_grad():
                self.init_transform.weight[-1].fill_(0)
        self.gnn = GCN(in_feats=dim_embedding, hidden_feats=hidden_feats, activation=activation)
        self.output_feats = hidden_feats[-1]

    def forward(self, batch_graph):
        """Return GCN node features reshaped to ``(batch_size, -1, output_feats)``.

        Side effect: removes ``'h'`` from ``batch_graph.ndata``.
        Assumes every graph in the batch has the same node count (e.g. padded by the
        featurizer); otherwise the reshape does not align with graph boundaries.
        """
        node_feats = batch_graph.ndata.pop('h')
        node_feats = self.init_transform(node_feats)
        node_feats = self.gnn(batch_graph, node_feats)
        batch_size = batch_graph.batch_size
        return node_feats.view(batch_size, -1, self.output_feats)
class ProteinCNN(nn.Module):
    """Three-layer 1-D CNN over embedded protein token sequences.

    Args:
        embedding_dim: Size of the token embedding (vocabulary of 26 indices).
        num_filters: List of output channels for the convolutions (first three are used).
        kernel_size: Kernel sizes for the three convolutions.
        padding: If true, index 0 is the embedding's ``padding_idx``.
    """

    def __init__(self, embedding_dim, num_filters, kernel_size, padding=True):
        super().__init__()
        self.embedding = nn.Embedding(26, embedding_dim, padding_idx=0 if padding else None)
        widths = [embedding_dim] + num_filters  # input channels, then each conv's output channels
        self.in_ch = widths[-1]
        self.conv1 = nn.Conv1d(in_channels=widths[0], out_channels=widths[1], kernel_size=kernel_size[0])
        self.bn1 = nn.BatchNorm1d(widths[1])
        self.conv2 = nn.Conv1d(in_channels=widths[1], out_channels=widths[2], kernel_size=kernel_size[1])
        self.bn2 = nn.BatchNorm1d(widths[2])
        self.conv3 = nn.Conv1d(in_channels=widths[2], out_channels=widths[3], kernel_size=kernel_size[2])
        self.bn3 = nn.BatchNorm1d(widths[3])

    def forward(self, v):
        """Embed tokens, then apply conv -> ReLU -> BatchNorm three times.

        Assumes ``v`` is (batch, seq_len) — TODO confirm against callers.
        NOTE(review): the final ``view`` reinterprets the (batch, C_out, L_out) conv output
        as (batch, L_out, C_out) without transposing, so channel and position values are
        interleaved rather than swapped; ``transpose(1, 2)`` would be the conventional
        layout. Kept as-is because trained weights depend on this exact layout.
        """
        x = self.embedding(v.long()).transpose(2, 1)
        blocks = ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3))
        for conv, norm in blocks:
            x = norm(F.relu(conv(x)))
        return x.view(x.size(0), x.size(2), -1)
class MLPDecoder(nn.Module):
    """Three linear layers, each followed by ReLU and then BatchNorm.

    The output layer is also passed through ReLU and BatchNorm; there is no separate
    projection to class logits.

    Args:
        in_dim: Input feature size.
        hidden_dim: Width of the two hidden layers.
        out_dim: Output feature size.
    """

    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        widths = (in_dim, hidden_dim, hidden_dim, out_dim)
        self.fc1 = nn.Linear(widths[0], widths[1])
        self.bn1 = nn.BatchNorm1d(widths[1])
        self.fc2 = nn.Linear(widths[1], widths[2])
        self.bn2 = nn.BatchNorm1d(widths[2])
        self.fc3 = nn.Linear(widths[2], widths[3])
        self.bn3 = nn.BatchNorm1d(widths[3])

    def forward(self, x):
        """Apply ``bn(relu(fc(x)))`` for each of the three layers in order."""
        for linear, norm in ((self.fc1, self.bn1), (self.fc2, self.bn2), (self.fc3, self.bn3)):
            x = norm(F.relu(linear(x)))
        return x
# noinspection PyTypeChecker
class SimpleClassifier(nn.Module):
    """Two weight-normalised linear layers with ReLU and in-place dropout between them.

    Args:
        in_dim: Input feature size.
        hid_dim: Hidden layer width.
        out_dim: Output size.
        dropout: Dropout probability applied after the ReLU.
    """

    def __init__(self, in_dim, hid_dim, out_dim, dropout):
        super().__init__()
        hidden_layer = weight_norm(nn.Linear(in_dim, hid_dim), dim=None)
        output_layer = weight_norm(nn.Linear(hid_dim, out_dim), dim=None)
        self.main = nn.Sequential(
            hidden_layer,
            nn.ReLU(),
            nn.Dropout(dropout, inplace=True),
            output_layer,
        )

    def forward(self, x):
        """Return the logits produced by ``self.main``."""
        return self.main(x)
class RandomLayer(nn.Module):
    """Fuse several inputs by multiplying their fixed random projections element-wise.

    Input ``i`` is matrix-multiplied by ``torch.randn(input_dim_list[i], output_dim)``.
    The projections are multiplied together, and the first one is divided by
    ``output_dim ** (1 / number_of_inputs)``.

    Args:
        input_dim_list: Feature size of each input, in the order passed to ``forward``.
        output_dim: Size of the shared projection space.

    Note:
        ``random_matrix`` is a plain Python list, not registered parameters or buffers.
        It is therefore absent from ``state_dict`` and is moved to GPU only by the
        ``cuda`` override; ``.to(device)`` does not move it.
    """

    def __init__(self, input_dim_list, output_dim=256):
        super().__init__()
        self.input_num = len(input_dim_list)
        self.output_dim = output_dim
        self.random_matrix = [torch.randn(input_dim_list[i], output_dim) for i in range(self.input_num)]

    def forward(self, input_list):
        """Project each input and return the scaled element-wise product of the projections."""
        return_list = [torch.mm(input_list[i], self.random_matrix[i]) for i in range(self.input_num)]
        return_tensor = return_list[0] / math.pow(float(self.output_dim), 1.0 / len(return_list))
        for single in return_list[1:]:
            return_tensor = torch.mul(return_tensor, single)
        return return_tensor

    def cuda(self, *args):
        """Move the module and its random matrices to GPU; return ``self`` like ``nn.Module.cuda``."""
        super().cuda(*args)
        self.random_matrix = [val.cuda(*args) for val in self.random_matrix]
        # The original returned None, which broke the common ``layer = layer.cuda()`` pattern.
        return self
# noinspection PyTypeChecker
class BANLayer(nn.Module):
    """Bilinear attention layer fusing two sets of feature vectors.

    Both inputs are projected to ``h_dim * k`` features by ``FCNet``s. ``h_out`` bilinear
    attention maps are built between every ``v``/``q`` pair, and each map pools the
    projected features. The pooled glimpses are summed and batch-normalised.

    Args:
        v_dim: Feature size of ``v``.
        q_dim: Feature size of ``q``.
        h_dim: Fused output size (projections have ``h_dim * k`` features and are
            sum-pooled back to ``h_dim`` when ``k > 1``).
        h_out: Number of attention glimpses. Up to ``self.c`` (32) they come from the
            low-rank ``h_mat``/``h_bias`` parameters; above that from the weight-normalised
            linear layer ``h_net``.
        act: Name of the ``torch.nn`` activation used in the projections.
        dropout: Dropout probability used in the projections.
        k: Pooling factor (see ``h_dim``).
    """

    def __init__(self, v_dim, q_dim, h_dim, h_out, act='ReLU', dropout=0.2, k=3):
        super().__init__()
        self.c = 32
        self.k = k
        self.v_dim = v_dim
        self.q_dim = q_dim
        self.h_dim = h_dim
        self.h_out = h_out
        self.v_net = FCNet([v_dim, h_dim * self.k], act=act, dropout=dropout)
        self.q_net = FCNet([q_dim, h_dim * self.k], act=act, dropout=dropout)
        if 1 < k:
            self.p_net = nn.AvgPool1d(self.k, stride=self.k)
        if h_out <= self.c:
            self.h_mat = nn.Parameter(torch.Tensor(1, h_out, 1, h_dim * self.k).normal_())
            self.h_bias = nn.Parameter(torch.Tensor(1, h_out, 1, 1).normal_())
        else:
            self.h_net = weight_norm(nn.Linear(h_dim * self.k, h_out), dim=None)
        self.bn = nn.BatchNorm1d(h_dim)

    def attention_pooling(self, v, q, att_map):
        """Pool ``v`` and ``q`` with one attention map.

        Computes ``sum_{i,j} v[b,i,d] * att_map[b,i,j] * q[b,j,d]`` for v (b x v x d),
        q (b x q x d) and att_map (b x v x q). Returns b x d, or b x (d / k) after
        sum-pooling when ``k > 1``.
        """
        fusion_logits = torch.einsum('bvk,bvq,bqk->bk', (v, att_map, q))
        if 1 < self.k:
            fusion_logits = fusion_logits.unsqueeze(1)  # b x 1 x d
            # AvgPool1d then "* k" turns each group average into a group sum.
            fusion_logits = self.p_net(fusion_logits).squeeze(1) * self.k
        return fusion_logits

    def forward(self, v, q, softmax=False):
        """Fuse ``v`` and ``q`` with ``h_out`` bilinear attention glimpses.

        Args:
            v: Tensor indexed as b x v_num x v_dim.
            q: Tensor indexed as b x q_num x q_dim.
            softmax: If true, normalise each attention map over all v_num * q_num pairs.

        Returns:
            ``(logits, att_maps)``: batch-normalised fused features (b x h_dim) and the
            attention maps (b x h_out x v_num x q_num).
        """
        v_num = v.size(1)
        q_num = q.size(1)
        # Keep the projections in b x n x d layout: attention_pooling needs that layout in
        # both branches. The h_out > c branch previously overwrote them with 4-D views,
        # which made the einsum in attention_pooling fail.
        v_ = self.v_net(v)
        q_ = self.q_net(q)
        if self.h_out <= self.c:
            att_maps = torch.einsum('xhyk,bvk,bqk->bhvq', (self.h_mat, v_, q_)) + self.h_bias
        else:
            # Batched outer product over the feature dim, then a linear map to h_out glimpses.
            d_ = torch.matmul(v_.transpose(1, 2).unsqueeze(3), q_.transpose(1, 2).unsqueeze(2))  # b x d x v x q
            att_maps = self.h_net(d_.transpose(1, 2).transpose(2, 3))  # b x v x q x h_out
            att_maps = att_maps.transpose(2, 3).transpose(1, 2)  # b x h_out x v x q
        if softmax:
            # reshape (not view): the transposed maps from the h_net branch are non-contiguous.
            p = nn.functional.softmax(att_maps.reshape(-1, self.h_out, v_num * q_num), 2)
            att_maps = p.view(-1, self.h_out, v_num, q_num)
        logits = self.attention_pooling(v_, q_, att_maps[:, 0, :, :])
        for i in range(1, self.h_out):
            logits += self.attention_pooling(v_, q_, att_maps[:, i, :, :])
        logits = self.bn(logits)
        return logits, att_maps
# noinspection PyTypeChecker
class FCNet(nn.Module):
    """Simple class for a non-linear fully connected network.

    Modified from https://github.com/jnhwkim/ban-vqa/blob/master/fc.py

    ``dims`` lists the layer widths; one weight-normalised ``nn.Linear`` is built for
    each consecutive pair ``dims[i] -> dims[i + 1]`` (so ``dims`` needs at least two
    entries, otherwise IndexError). Every linear layer, including the last, is preceded
    by ``nn.Dropout(dropout)`` when ``dropout > 0`` and followed by ``getattr(nn, act)()``
    when ``act`` is not the empty string.
    """

    def __init__(self, dims, act='ReLU', dropout=0.0):
        super().__init__()
        # Hidden pairs first, then the final (dims[-2], dims[-1]) pair built with the same recipe.
        pairs = [(dims[i], dims[i + 1]) for i in range(len(dims) - 2)]
        pairs.append((dims[-2], dims[-1]))
        modules = []
        for fan_in, fan_out in pairs:
            if dropout > 0:
                modules.append(nn.Dropout(dropout))
            modules.append(weight_norm(nn.Linear(fan_in, fan_out), dim=None))
            if act != '':
                modules.append(getattr(nn, act)())
        self.main = nn.Sequential(*modules)

    def forward(self, x):
        """Run ``x`` through the layer stack."""
        return self.main(x)
class BCNet(nn.Module):
    """Simple class for a non-linear bilinear connect network.

    Modified from https://github.com/jnhwkim/ban-vqa/blob/master/bc.py

    ``v`` and ``q`` are projected to ``h_dim * k`` features by ``FCNet``s and combined
    bilinearly. ``h_out`` selects the form returned by ``forward``:

    * ``None``: the full per-pair product, b x v x q x (h_dim * k);
    * ``<= self.c`` (32): low-rank bilinear maps from ``h_mat``/``h_bias``, b x h_out x v x q;
    * larger: batch outer product followed by the linear layer ``h_net``, b x h_out x v x q.

    ``dropout`` is a pair: ``dropout[0]`` is used inside both projections and
    ``dropout[1]`` is applied to the ``v`` projection in the two ``h_out`` branches.
    """

    # noinspection PyTypeChecker
    def __init__(self, v_dim, q_dim, h_dim, h_out, act='ReLU', dropout=(0.2, 0.5), k=3):
        super().__init__()
        self.c = 32
        self.k = k
        self.v_dim = v_dim
        self.q_dim = q_dim
        self.h_dim = h_dim
        self.h_out = h_out
        proj_dim = h_dim * k
        self.v_net = FCNet([v_dim, proj_dim], act=act, dropout=dropout[0])
        self.q_net = FCNet([q_dim, proj_dim], act=act, dropout=dropout[0])
        self.dropout = nn.Dropout(dropout[1])  # attention-side dropout on the v projection
        if k > 1:
            self.p_net = nn.AvgPool1d(k, stride=k)
        if h_out is not None:
            if h_out <= self.c:
                self.h_mat = nn.Parameter(torch.Tensor(1, h_out, 1, proj_dim).normal_())
                self.h_bias = nn.Parameter(torch.Tensor(1, h_out, 1, 1).normal_())
            else:
                self.h_net = weight_norm(nn.Linear(proj_dim, h_out), dim=None)

    def forward(self, v, q):
        """Return the bilinear interaction between ``v`` and ``q`` (shape depends on ``h_out``)."""
        if self.h_out is None:
            # No glimpse reduction: b x v x q x d
            v_proj = self.v_net(v)
            q_proj = self.q_net(q)
            return torch.einsum('bvk,bqk->bvqk', (v_proj, q_proj))
        if self.h_out <= self.c:
            # Low-rank bilinear pooling via einsum: b x h_out x v x q
            v_proj = self.dropout(self.v_net(v))
            q_proj = self.q_net(q)
            return torch.einsum('xhyk,bvk,bqk->bhvq', (self.h_mat, v_proj, q_proj)) + self.h_bias
        # Batch outer product + linear projection (memory efficient but slow computation).
        v_cols = self.dropout(self.v_net(v)).transpose(1, 2).unsqueeze(3)  # b x d x v x 1
        q_rows = self.q_net(q).transpose(1, 2).unsqueeze(2)  # b x d x 1 x q
        outer = torch.matmul(v_cols, q_rows)  # b x d x v x q
        per_pair = self.h_net(outer.transpose(1, 2).transpose(2, 3))  # b x v x q x h_out
        return per_pair.transpose(2, 3).transpose(1, 2)  # b x h_out x v x q

    def forward_with_weights(self, v, q, w):
        """Pool the projections with weights ``w`` (b x v x q); sum-pool by ``k`` when ``k > 1``."""
        v_proj = self.v_net(v)  # b x v x d
        q_proj = self.q_net(q)  # b x q x d
        pooled = torch.einsum('bvk,bvq,bqk->bk', (v_proj, w, q_proj))
        if self.k <= 1:
            return pooled
        # AvgPool1d followed by "* k" is a sum over each group of k features.
        return self.p_net(pooled.unsqueeze(1)).squeeze(1) * self.k
def drug_featurizer(smiles, max_drug_nodes=290):
    """Convert a SMILES string into a graph padded to ``max_drug_nodes`` nodes.

    Real atoms get their ``CanonicalAtomFeaturizer`` features plus a trailing 0
    ("virtual node" bit). Padding nodes get all-zero atom features of the same width
    plus a trailing 1. ``add_self_loop`` is applied to the padded graph at the end.

    Args:
        smiles: SMILES string of the molecule.
        max_drug_nodes: Total node count of the returned graph.

    Returns:
        The padded graph, or ``None`` if ``smiles_to_bigraph`` returns ``None``, the
        molecule has more than ``max_drug_nodes`` atoms, or featurization raises. A
        warning is logged in the last two cases.
    """
    from dgllife.utils import smiles_to_bigraph, CanonicalAtomFeaturizer, CanonicalBondFeaturizer
    from deepscreen.utils import get_logger
    log = get_logger(__name__)
    try:
        v_d = smiles_to_bigraph(smiles=smiles,
                                node_featurizer=CanonicalAtomFeaturizer(),
                                edge_featurizer=CanonicalBondFeaturizer(self_loop=True),
                                add_self_loop=True)
        if v_d is None:
            return None
        actual_node_feats = v_d.ndata.pop('h')
        num_actual_nodes = actual_node_feats.shape[0]
        # Derive the padding width from the real features instead of hard-coding 74.
        atom_feat_dim = actual_node_feats.shape[1]
        num_virtual_nodes = max_drug_nodes - num_actual_nodes
        if num_virtual_nodes < 0:
            # Previously surfaced only as an opaque negative-dimension error from torch.zeros.
            log.warning(f"Failed to featurize SMILES ({smiles}) to graph: {num_actual_nodes} atoms "
                        f"exceed max_drug_nodes={max_drug_nodes}")
            return None
        virtual_node_bit = torch.zeros([num_actual_nodes, 1])
        v_d.ndata['h'] = torch.cat((actual_node_feats, virtual_node_bit), 1)
        virtual_node_feat = torch.cat(
            (torch.zeros(num_virtual_nodes, atom_feat_dim), torch.ones(num_virtual_nodes, 1)), 1
        )
        v_d.add_nodes(num_virtual_nodes, {"h": virtual_node_feat})
        # NOTE(review): smiles_to_bigraph was already called with add_self_loop=True, so real
        # atoms likely end up with two self-loops here (DGL adds loops regardless of existing
        # ones). Left unchanged to preserve the existing featurization — confirm if intended.
        v_d = v_d.add_self_loop()
        return v_d
    except Exception as e:
        log.warning(f"Failed to featurize SMILES ({smiles}) to graph due to {str(e)}")
        return None