import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from dgllife.model.gnn import GCN
from torch.nn.utils.weight_norm import weight_norm
class DrugBAN(nn.Module):
    """Drug-target interaction model: a GCN encodes the drug graph, a CNN
    encodes the protein sequence, a bilinear attention network (BAN) fuses
    the two, and an MLP head produces the interaction score."""

    def __init__(
            self,
            drug_in_feats,
            drug_embedding,
            drug_hidden_feats,
            protein_emb_dim,
            num_filters,
            kernel_size,
            mlp_in_dim,
            mlp_hidden_dim,
            mlp_out_dim,
            drug_padding,
            protein_padding,
            ban_heads,
    ):
        super().__init__()
        self.drug_extractor = MolecularGCN(in_feats=drug_in_feats, dim_embedding=drug_embedding,
                                           padding=drug_padding, hidden_feats=drug_hidden_feats)
        self.protein_extractor = ProteinCNN(protein_emb_dim, num_filters, kernel_size, protein_padding)
        self.bcn = weight_norm(
            BANLayer(v_dim=drug_hidden_feats[-1], q_dim=num_filters[-1], h_dim=mlp_in_dim, h_out=ban_heads),
            name='h_mat', dim=None)
        self.mlp_classifier = MLPDecoder(mlp_in_dim, mlp_hidden_dim, mlp_out_dim)

    def forward(self, bg_d, v_p):
        v_d = self.drug_extractor(bg_d)    # (batch, max_nodes, drug_hidden_feats[-1])
        v_p = self.protein_extractor(v_p)  # (batch, conv_len, num_filters[-1])
        f, att = self.bcn(v_d, v_p)        # fused vector and per-head attention maps
        score = self.mlp_classifier(f)
        # if mode == "train":
        #     return v_d, v_p, f, score
        # elif mode == "eval":
        #     return v_d, v_p, score, att
        return score
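# A minimal construction sketch. The hyperparameter values below are
# illustrative assumptions, not verified project defaults; in practice they
# come from the experiment config.
# model = DrugBAN(
#     drug_in_feats=75,                  # 74 canonical atom features + 1 virtual-node bit
#     drug_embedding=128,
#     drug_hidden_feats=[128, 128, 128],
#     protein_emb_dim=128,
#     num_filters=[128, 128, 128],
#     kernel_size=[3, 6, 9],
#     mlp_in_dim=256,
#     mlp_hidden_dim=512,
#     mlp_out_dim=128,
#     drug_padding=True,
#     protein_padding=True,
#     ban_heads=2,
# )
# score = model(batched_drug_graph, protein_token_ids)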
class MolecularGCN(nn.Module):
    """GCN encoder over a batched DGL molecular graph. With `padding=True`,
    the last row of the initial linear transform is zero-initialized to
    account for the appended virtual-node bit."""

    def __init__(self, in_feats, dim_embedding=128, padding=True, hidden_feats=None, activation=None):
        super().__init__()
        self.init_transform = nn.Linear(in_feats, dim_embedding, bias=False)
        if padding:
            with torch.no_grad():
                self.init_transform.weight[-1].fill_(0)
        self.gnn = GCN(in_feats=dim_embedding, hidden_feats=hidden_feats, activation=activation)
        self.output_feats = hidden_feats[-1]

    def forward(self, batch_graph):
        node_feats = batch_graph.ndata.pop('h')
        node_feats = self.init_transform(node_feats)
        node_feats = self.gnn(batch_graph, node_feats)
        batch_size = batch_graph.batch_size
        # every graph is padded to the same node count, so the flat node tensor
        # reshapes cleanly into (batch, max_nodes, output_feats)
        node_feats = node_feats.view(batch_size, -1, self.output_feats)
        return node_feats
class ProteinCNN(nn.Module):
    """Three-layer 1D CNN over an integer-encoded protein sequence
    (vocabulary of 25 amino-acid tokens plus a padding index)."""

    def __init__(self, embedding_dim, num_filters, kernel_size, padding=True):
        super().__init__()
        if padding:
            self.embedding = nn.Embedding(26, embedding_dim, padding_idx=0)
        else:
            self.embedding = nn.Embedding(26, embedding_dim)
        in_ch = [embedding_dim] + num_filters
        self.in_ch = in_ch[-1]
        kernels = kernel_size
        self.conv1 = nn.Conv1d(in_channels=in_ch[0], out_channels=in_ch[1], kernel_size=kernels[0])
        self.bn1 = nn.BatchNorm1d(in_ch[1])
        self.conv2 = nn.Conv1d(in_channels=in_ch[1], out_channels=in_ch[2], kernel_size=kernels[1])
        self.bn2 = nn.BatchNorm1d(in_ch[2])
        self.conv3 = nn.Conv1d(in_channels=in_ch[2], out_channels=in_ch[3], kernel_size=kernels[2])
        self.bn3 = nn.BatchNorm1d(in_ch[3])

    def forward(self, v):
        v = self.embedding(v.long())  # (batch, seq_len, embedding_dim)
        v = v.transpose(2, 1)         # (batch, embedding_dim, seq_len) for Conv1d
        v = self.bn1(F.relu(self.conv1(v)))
        v = self.bn2(F.relu(self.conv2(v)))
        v = self.bn3(F.relu(self.conv3(v)))
        # transpose (not view) back to (batch, seq_len', channels) so each
        # position keeps its own channel vector intact
        v = v.transpose(2, 1)
        return v
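# Shape walk-through with assumed values (embedding_dim=128,
# num_filters=[128, 128, 128], kernel_size=[3, 6, 9]); each unpadded Conv1d
# shortens the sequence by kernel_size - 1:
#   (B, 1200) -> embed -> (B, 1200, 128) -> transpose -> (B, 128, 1200)
#   -> conv1 -> (B, 128, 1198) -> conv2 -> (B, 128, 1193) -> conv3 -> (B, 128, 1185)
#   -> transpose -> (B, 1185, 128)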
class MLPDecoder(nn.Module):
    """Three fully connected layers with ReLU and batch norm, mapping the
    fused BAN representation to the final score."""

    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, out_dim)
        self.bn3 = nn.BatchNorm1d(out_dim)
        # self.fc4 = nn.Linear(out_dim, binary)

    def forward(self, x):
        x = self.bn1(F.relu(self.fc1(x)))
        x = self.bn2(F.relu(self.fc2(x)))
        x = self.bn3(F.relu(self.fc3(x)))
        # x = self.fc4(x)
        return x
# noinspection PyTypeChecker
class SimpleClassifier(nn.Module):
    def __init__(self, in_dim, hid_dim, out_dim, dropout):
        super().__init__()
        layers = [
            weight_norm(nn.Linear(in_dim, hid_dim), dim=None),
            nn.ReLU(),
            nn.Dropout(dropout, inplace=True),
            weight_norm(nn.Linear(hid_dim, out_dim), dim=None),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.main(x)
        return logits
class RandomLayer(nn.Module):
    """Fixed random projections for randomized multilinear fusion: each input
    is projected to `output_dim` and the projections are combined by
    element-wise multiplication."""

    def __init__(self, input_dim_list, output_dim=256):
        super().__init__()
        self.input_num = len(input_dim_list)
        self.output_dim = output_dim
        self.random_matrix = [torch.randn(input_dim_list[i], output_dim) for i in range(self.input_num)]

    def forward(self, input_list):
        return_list = [torch.mm(input_list[i], self.random_matrix[i]) for i in range(self.input_num)]
        return_tensor = return_list[0] / math.pow(float(self.output_dim), 1.0 / len(return_list))
        for single in return_list[1:]:
            return_tensor = torch.mul(return_tensor, single)
        return return_tensor

    def cuda(self, *args):
        # the random matrices are plain tensors, not registered buffers or
        # parameters, so they must be moved to the GPU explicitly
        super().cuda(*args)
        self.random_matrix = [val.cuda(*args) for val in self.random_matrix]
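# Illustrative fusion sketch (dimensions assumed): project two feature
# tensors into a shared 256-dim space and combine them multiplicatively.
# random_layer = RandomLayer([256, 2], output_dim=256)
# fused = random_layer([features, class_probs])  # (batch, 256)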
# noinspection PyTypeChecker
class BANLayer(nn.Module):
    """Bilinear attention network layer: computes h_out pairwise attention
    maps between drug nodes (v) and protein positions (q), pools a joint
    vector from each map, and sums over the heads."""

    def __init__(self, v_dim, q_dim, h_dim, h_out, act='ReLU', dropout=0.2, k=3):
        super().__init__()
        self.c = 32
        self.k = k
        self.v_dim = v_dim
        self.q_dim = q_dim
        self.h_dim = h_dim
        self.h_out = h_out
        self.v_net = FCNet([v_dim, h_dim * self.k], act=act, dropout=dropout)
        self.q_net = FCNet([q_dim, h_dim * self.k], act=act, dropout=dropout)
        # self.dropout = nn.Dropout(dropout[1])
        if 1 < k:
            self.p_net = nn.AvgPool1d(self.k, stride=self.k)
        if h_out <= self.c:
            # low-rank bilinear attention with a shared (1, h_out, 1, h_dim * k) weight
            self.h_mat = nn.Parameter(torch.Tensor(1, h_out, 1, h_dim * self.k).normal_())
            self.h_bias = nn.Parameter(torch.Tensor(1, h_out, 1, 1).normal_())
        else:
            # too many heads for the low-rank path: use a linear projection instead
            self.h_net = weight_norm(nn.Linear(h_dim * self.k, h_out), dim=None)
        self.bn = nn.BatchNorm1d(h_dim)

    def attention_pooling(self, v, q, att_map):
        # attention-weighted sum over all (v, q) pairs -> (batch, h_dim * k)
        fusion_logits = torch.einsum('bvk,bvq,bqk->bk', (v, att_map, q))
        if 1 < self.k:
            fusion_logits = fusion_logits.unsqueeze(1)  # b x 1 x d
            fusion_logits = self.p_net(fusion_logits).squeeze(1) * self.k  # sum-pooling over k
        return fusion_logits

    def forward(self, v, q, softmax=False):
        v_num = v.size(1)
        q_num = q.size(1)
        if self.h_out <= self.c:
            v_ = self.v_net(v)
            q_ = self.q_net(q)
            att_maps = torch.einsum('xhyk,bvk,bqk->bhvq', (self.h_mat, v_, q_)) + self.h_bias
        else:
            v_ = self.v_net(v).transpose(1, 2).unsqueeze(3)
            q_ = self.q_net(q).transpose(1, 2).unsqueeze(2)
            d_ = torch.matmul(v_, q_)  # b x h_dim x v x q
            att_maps = self.h_net(d_.transpose(1, 2).transpose(2, 3))  # b x v x q x h_out
            att_maps = att_maps.transpose(2, 3).transpose(1, 2)  # b x h_out x v x q
        if softmax:
            p = F.softmax(att_maps.view(-1, self.h_out, v_num * q_num), dim=2)
            att_maps = p.view(-1, self.h_out, v_num, q_num)
        # pool a joint vector from each attention head and sum the heads
        logits = self.attention_pooling(v_, q_, att_maps[:, 0, :, :])
        for i in range(1, self.h_out):
            logits_i = self.attention_pooling(v_, q_, att_maps[:, i, :, :])
            logits += logits_i
        logits = self.bn(logits)
        return logits, att_maps
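# Minimal shape sketch (sizes assumed): two heads of drug-protein attention.
# ban = BANLayer(v_dim=128, q_dim=128, h_dim=256, h_out=2)
# v = torch.randn(8, 290, 128)   # drug node embeddings
# q = torch.randn(8, 1185, 128)  # protein position embeddings
# f, att = ban(v, q)             # f: (8, 256); att: (8, 2, 290, 1185)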
# noinspection PyTypeChecker
class FCNet(nn.Module):
    """Simple class for a non-linear fully connected network.
    Modified from https://github.com/jnhwkim/ban-vqa/blob/master/fc.py
    """

    def __init__(self, dims, act='ReLU', dropout=0.0):
        super().__init__()
        layers = []
        for i in range(len(dims) - 2):
            in_dim = dims[i]
            out_dim = dims[i + 1]
            if dropout > 0:
                layers.append(nn.Dropout(dropout))
            layers.append(weight_norm(nn.Linear(in_dim, out_dim), dim=None))
            if act != '':
                layers.append(getattr(nn, act)())
        if dropout > 0:
            layers.append(nn.Dropout(dropout))
        layers.append(weight_norm(nn.Linear(dims[-2], dims[-1]), dim=None))
        if act != '':
            layers.append(getattr(nn, act)())
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        return self.main(x)
class BCNet(nn.Module):
    """Simple class for a non-linear bilinear connection network.
    Modified from https://github.com/jnhwkim/ban-vqa/blob/master/bc.py
    """

    # noinspection PyTypeChecker
    def __init__(self, v_dim, q_dim, h_dim, h_out, act='ReLU', dropout=(0.2, 0.5), k=3):
        super().__init__()
        self.c = 32
        self.k = k
        self.v_dim = v_dim
        self.q_dim = q_dim
        self.h_dim = h_dim
        self.h_out = h_out
        self.v_net = FCNet([v_dim, h_dim * self.k], act=act, dropout=dropout[0])
        self.q_net = FCNet([q_dim, h_dim * self.k], act=act, dropout=dropout[0])
        self.dropout = nn.Dropout(dropout[1])  # attention
        if 1 < k:
            self.p_net = nn.AvgPool1d(self.k, stride=self.k)
        if h_out is None:
            pass
        elif h_out <= self.c:
            self.h_mat = nn.Parameter(torch.Tensor(1, h_out, 1, h_dim * self.k).normal_())
            self.h_bias = nn.Parameter(torch.Tensor(1, h_out, 1, 1).normal_())
        else:
            self.h_net = weight_norm(nn.Linear(h_dim * self.k, h_out), dim=None)

    def forward(self, v, q):
        if self.h_out is None:
            # no head projection: return the full joint representation
            v_ = self.v_net(v)
            q_ = self.q_net(q)
            logits = torch.einsum('bvk,bqk->bvqk', (v_, q_))
            return logits
        # low-rank bilinear pooling using einsum
        elif self.h_out <= self.c:
            v_ = self.dropout(self.v_net(v))
            q_ = self.q_net(q)
            logits = torch.einsum('xhyk,bvk,bqk->bhvq', (self.h_mat, v_, q_)) + self.h_bias
            return logits  # b x h_out x v x q
        # batch outer product then linear projection:
        # memory efficient but slower to compute
        else:
            v_ = self.dropout(self.v_net(v)).transpose(1, 2).unsqueeze(3)
            q_ = self.q_net(q).transpose(1, 2).unsqueeze(2)
            d_ = torch.matmul(v_, q_)  # b x h_dim x v x q
            logits = self.h_net(d_.transpose(1, 2).transpose(2, 3))  # b x v x q x h_out
            return logits.transpose(2, 3).transpose(1, 2)  # b x h_out x v x q

    def forward_with_weights(self, v, q, w):
        v_ = self.v_net(v)  # b x v x d
        q_ = self.q_net(q)  # b x q x d
        logits = torch.einsum('bvk,bvq,bqk->bk', (v_, w, q_))
        if 1 < self.k:
            logits = logits.unsqueeze(1)  # b x 1 x d
            logits = self.p_net(logits).squeeze(1) * self.k  # sum-pooling over k
        return logits
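# Illustrative (sizes assumed): with h_out=None, BCNet skips the head
# projection and returns the full joint tensor.
# bc = BCNet(v_dim=128, q_dim=128, h_dim=256, h_out=None)
# joint = bc(torch.randn(4, 10, 128), torch.randn(4, 20, 128))  # (4, 10, 20, 768)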
def drug_featurizer(smiles, max_drug_nodes=290):
    """Convert a SMILES string to a DGL bigraph with canonical atom/bond
    features, padded with virtual nodes up to `max_drug_nodes`. Returns None
    if featurization fails."""
    from dgllife.utils import smiles_to_bigraph, CanonicalAtomFeaturizer, CanonicalBondFeaturizer
    from deepscreen.utils import get_logger
    log = get_logger(__name__)
    try:
        v_d = smiles_to_bigraph(smiles=smiles,
                                node_featurizer=CanonicalAtomFeaturizer(),
                                edge_featurizer=CanonicalBondFeaturizer(self_loop=True),
                                add_self_loop=True)
        if v_d is None:
            return None
        actual_node_feats = v_d.ndata.pop('h')
        num_actual_nodes = actual_node_feats.shape[0]
        # note: molecules with more than max_drug_nodes atoms are not handled here
        num_virtual_nodes = max_drug_nodes - num_actual_nodes
        # real nodes get a 0 flag bit appended; virtual nodes get zero features
        # plus a 1 flag bit (74 = CanonicalAtomFeaturizer feature size)
        virtual_node_bit = torch.zeros([num_actual_nodes, 1])
        actual_node_feats = torch.cat((actual_node_feats, virtual_node_bit), 1)
        v_d.ndata['h'] = actual_node_feats
        virtual_node_feat = torch.cat(
            (torch.zeros(num_virtual_nodes, 74), torch.ones(num_virtual_nodes, 1)), 1
        )
        v_d.add_nodes(num_virtual_nodes, {"h": virtual_node_feat})
        v_d = v_d.add_self_loop()
        return v_d
    except Exception as e:
        log.warning(f"Failed to featurize SMILES ({smiles}) to graph due to {str(e)}")
        return None
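# Usage sketch: featurize one SMILES string into a padded graph.
# g = drug_featurizer("CCO")     # ethanol; returns a DGLGraph or None
# if g is not None:
#     print(g.num_nodes())       # 290 (padded to max_drug_nodes)
#     print(g.ndata['h'].shape)  # torch.Size([290, 75])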