import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.weight_norm import weight_norm
from dgllife.model.gnn import GCN


class DrugBAN(nn.Module):
    """Drug-target interaction model built from four sub-networks.

    Pipeline (see ``forward``): a :class:`MolecularGCN` encodes the drug graph,
    a :class:`ProteinCNN` encodes the protein sequence, a weight-normalised
    :class:`BANLayer` fuses the two with bilinear attention, and an
    :class:`MLPDecoder` maps the fused vector to the output.

    Args:
        drug_in_feats: Node feature width of the drug graphs (``ndata['h']``).
        drug_embedding: Width of the linear projection applied before the GCN.
        drug_hidden_feats: Per-layer GCN output sizes; the last entry is the
            drug feature width fed to the BAN layer.
        protein_emb_dim: Width of the protein token embedding.
        num_filters: Output channels of the three protein conv layers; the last
            entry is the protein feature width fed to the BAN layer.
        kernel_size: Kernel sizes of the three protein conv layers.
        mlp_in_dim: BAN hidden size (``h_dim``), which is also the MLP input width.
        mlp_hidden_dim: MLP hidden width.
        mlp_out_dim: MLP output width.
        drug_padding: Forwarded to ``MolecularGCN(padding=...)``.
        protein_padding: Forwarded to ``ProteinCNN(padding=...)``.
        ban_heads: Number of bilinear attention heads. Must be <= 32, because
            ``weight_norm(..., name='h_mat')`` needs the ``h_mat`` parameter that
            :class:`BANLayer` only creates when ``h_out <= 32``.
    """

    def __init__(
            self,
            drug_in_feats,
            drug_embedding,
            drug_hidden_feats,
            protein_emb_dim,
            num_filters,
            kernel_size,
            mlp_in_dim,
            mlp_hidden_dim,
            mlp_out_dim,
            drug_padding,
            protein_padding,
            ban_heads,
    ):
        super().__init__()
        self.drug_extractor = MolecularGCN(in_feats=drug_in_feats, dim_embedding=drug_embedding,
                                           padding=drug_padding,
                                           hidden_feats=drug_hidden_feats)
        self.protein_extractor = ProteinCNN(protein_emb_dim, num_filters, kernel_size, protein_padding)

        self.bcn = weight_norm(
            BANLayer(v_dim=drug_hidden_feats[-1], q_dim=num_filters[-1], h_dim=mlp_in_dim, h_out=ban_heads),
            name='h_mat', dim=None)
        self.mlp_classifier = MLPDecoder(mlp_in_dim, mlp_hidden_dim, mlp_out_dim)

    def forward(self, bg_d, v_p):
        """Score a batch of drug-protein pairs.

        Args:
            bg_d: Batched drug graph consumed by :class:`MolecularGCN`.
            v_p: Integer-encoded protein sequences consumed by :class:`ProteinCNN`
                (presumably shape (batch, seq_len) -- confirm against the data module).

        Returns:
            The :class:`MLPDecoder` output for the fused representation. The BAN
            attention maps are computed but not returned.
        """
        v_d = self.drug_extractor(bg_d)
        v_p = self.protein_extractor(v_p)
        f, _att = self.bcn(v_d, v_p)
        return self.mlp_classifier(f)


class MolecularGCN(nn.Module):
    """Drug encoder: a bias-free linear projection followed by a dgllife GCN.

    ``forward`` returns node embeddings reshaped to
    ``(batch_size, nodes_per_graph, hidden_feats[-1])``.

    Args:
        in_feats: Width of the input node features ``ndata['h']``.
        dim_embedding: Width of the projection fed to the GCN.
        padding: If true, zero one row of the projection weight at init (see note below).
        hidden_feats: Per-layer GCN output sizes; ``None`` uses dgllife's default ``[64, 64]``.
        activation: Forwarded to dgllife's ``GCN``.
    """

    def __init__(self, in_feats, dim_embedding=128, padding=True, hidden_feats=None, activation=None):
        super().__init__()
        if hidden_feats is None:
            # Resolve dgllife GCN's own default here so ``output_feats`` is defined;
            # indexing ``None`` below used to raise TypeError.
            hidden_feats = [64, 64]
        self.init_transform = nn.Linear(in_feats, dim_embedding, bias=False)
        if padding:
            # NOTE(review): ``weight`` has shape (dim_embedding, in_feats), so ``weight[-1]``
            # zeroes the last *output* row, not the input column that reads the
            # virtual-node bit appended by ``drug_featurizer``. If ignoring that bit was
            # the intent, ``weight[:, -1]`` would be needed. Kept as-is for reproducibility.
            with torch.no_grad():
                self.init_transform.weight[-1].fill_(0)
        self.gnn = GCN(in_feats=dim_embedding, hidden_feats=hidden_feats, activation=activation)
        self.output_feats = hidden_feats[-1]

    def forward(self, batch_graph):
        # Read (rather than pop) ``'h'`` so the caller's graph is not mutated and a
        # second forward pass over the same batch does not raise KeyError.
        node_feats = self.init_transform(batch_graph.ndata['h'])
        node_feats = self.gnn(batch_graph, node_feats)
        # Assumes every graph in the batch has the same node count (drug_featurizer
        # pads to ``max_drug_nodes``); otherwise rows would mix across graphs or the
        # view would fail.
        return node_feats.view(batch_graph.batch_size, -1, self.output_feats)


class ProteinCNN(nn.Module):
    """Protein encoder: token embedding followed by three Conv1d -> ReLU -> BatchNorm blocks.

    Args:
        embedding_dim: Width of the token embedding (26-entry vocabulary).
        num_filters: Output channels of the three conv layers (exactly 3 are used).
        kernel_size: Kernel sizes of the three conv layers (exactly 3 are used).
        padding: If true, token id 0 is the embedding ``padding_idx``.

    The convolutions use no padding, so the sequence length shrinks by
    ``kernel_size[i] - 1`` at each layer.
    """

    def __init__(self, embedding_dim, num_filters, kernel_size, padding=True):
        super().__init__()
        if padding:
            self.embedding = nn.Embedding(26, embedding_dim, padding_idx=0)
        else:
            self.embedding = nn.Embedding(26, embedding_dim)
        # Unpacking accepts any sequence (list, tuple, config list), not just ``list``.
        in_ch = [embedding_dim, *num_filters]
        self.in_ch = in_ch[-1]
        kernels = kernel_size
        self.conv1 = nn.Conv1d(in_channels=in_ch[0], out_channels=in_ch[1], kernel_size=kernels[0])
        self.bn1 = nn.BatchNorm1d(in_ch[1])
        self.conv2 = nn.Conv1d(in_channels=in_ch[1], out_channels=in_ch[2], kernel_size=kernels[1])
        self.bn2 = nn.BatchNorm1d(in_ch[2])
        self.conv3 = nn.Conv1d(in_channels=in_ch[2], out_channels=in_ch[3], kernel_size=kernels[2])
        self.bn3 = nn.BatchNorm1d(in_ch[3])

    def forward(self, v):
        v = self.embedding(v.long())  # (batch, length, embedding_dim)
        v = v.transpose(2, 1)  # Conv1d expects (batch, channels, length)
        v = self.bn1(F.relu(self.conv1(v)))
        v = self.bn2(F.relu(self.conv2(v)))
        v = self.bn3(F.relu(self.conv3(v)))
        # NOTE(review): ``view`` does not transpose. It reinterprets the
        # (batch, channels, length) memory as (batch, length, channels), so each
        # output row mixes values from several channels/positions.
        # ``v.transpose(1, 2)`` is probably what was intended, but changing it would
        # alter the outputs of already-trained checkpoints, so it is kept as-is.
        v = v.view(v.size(0), v.size(2), -1)
        return v


class MLPDecoder(nn.Module):
    """Three Linear -> ReLU -> BatchNorm1d blocks.

    The final block is also followed by ReLU and BatchNorm, so the returned
    values are post-normalisation activations, not raw linear outputs.
    """

    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, out_dim)
        self.bn3 = nn.BatchNorm1d(out_dim)

    def forward(self, x):
        x = self.bn1(F.relu(self.fc1(x)))
        x = self.bn2(F.relu(self.fc2(x)))
        x = self.bn3(F.relu(self.fc3(x)))
        return x


# noinspection PyTypeChecker
class SimpleClassifier(nn.Module):
    """Two weight-normalised linear layers with ReLU and dropout in between."""

    def __init__(self, in_dim, hid_dim, out_dim, dropout):
        super().__init__()
        layers = [
            weight_norm(nn.Linear(in_dim, hid_dim), dim=None),
            nn.ReLU(),
            # Not in-place: ReLU saves its output for backward, and overwriting it
            # in place makes autograd fail during training. The forward result is unchanged.
            nn.Dropout(dropout),
            weight_norm(nn.Linear(hid_dim, out_dim), dim=None)
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.main(x)
        return logits


class RandomLayer(nn.Module):
    """Randomised multilinear map over several inputs.

    Each input ``i`` (2-D, width ``input_dim_list[i]``) is multiplied by a fixed
    Gaussian matrix of shape ``(input_dim_list[i], output_dim)``. The projections
    are then multiplied element-wise, with the first one scaled by
    ``output_dim ** (1 / n_inputs)``.

    The random matrices are plain tensors in a list, not parameters or buffers.
    They are therefore not saved in ``state_dict`` and are only moved by the
    ``cuda`` override below (``.to(device)`` does not move them).
    """

    def __init__(self, input_dim_list, output_dim=256):
        super().__init__()
        self.input_num = len(input_dim_list)
        self.output_dim = output_dim
        self.random_matrix = [torch.randn(dim, output_dim) for dim in input_dim_list]

    def forward(self, input_list):
        return_list = [torch.mm(x, m) for x, m in zip(input_list[:self.input_num], self.random_matrix)]
        return_tensor = return_list[0] / math.pow(float(self.output_dim), 1.0 / len(return_list))
        for single in return_list[1:]:
            return_tensor = torch.mul(return_tensor, single)
        return return_tensor

    def cuda(self, device=None):
        """Move the module and its random matrices to CUDA; returns ``self`` like ``nn.Module.cuda``."""
        super().cuda(device)
        self.random_matrix = [m.cuda(device) for m in self.random_matrix]
        return self


# noinspection PyTypeChecker
class BANLayer(nn.Module):
    """Bilinear attention layer fusing a set of ``v`` vectors with a set of ``q`` vectors.

    Both inputs are projected to width ``h_dim * k`` by :class:`FCNet`. For each
    of ``h_out`` heads, an attention map over every (v, q) pair is built. It uses
    the low-rank ``h_mat``/``h_bias`` parameters when ``h_out <= 32``, otherwise a
    weight-normalised linear ``h_net`` over the element-wise product.

    Each map pools the projected features into one vector. The vectors are
    sum-pooled over groups of ``k`` channels, summed across heads, and
    batch-normalised to width ``h_dim``.

    ``forward`` returns ``(logits, att_maps)`` with ``att_maps`` shaped
    ``(batch, h_out, v_num, q_num)``.
    """

    def __init__(self, v_dim, q_dim, h_dim, h_out, act='ReLU', dropout=0.2, k=3):
        super().__init__()
        self.c = 32  # threshold between the low-rank (h_mat) and h_net attention variants
        self.k = k
        self.v_dim = v_dim
        self.q_dim = q_dim
        self.h_dim = h_dim
        self.h_out = h_out

        self.v_net = FCNet([v_dim, h_dim * self.k], act=act, dropout=dropout)
        self.q_net = FCNet([q_dim, h_dim * self.k], act=act, dropout=dropout)
        if 1 < k:
            self.p_net = nn.AvgPool1d(self.k, stride=self.k)

        if h_out <= self.c:
            self.h_mat = nn.Parameter(torch.Tensor(1, h_out, 1, h_dim * self.k).normal_())
            self.h_bias = nn.Parameter(torch.Tensor(1, h_out, 1, 1).normal_())
        else:
            self.h_net = weight_norm(nn.Linear(h_dim * self.k, h_out), dim=None)

        self.bn = nn.BatchNorm1d(h_dim)

    def attention_pooling(self, v, q, att_map):
        """Pool ``v`` (b x v x d) and ``q`` (b x q x d) under ``att_map`` (b x v x q) into b x d'."""
        fusion_logits = torch.einsum('bvk,bvq,bqk->bk', (v, att_map, q))
        if 1 < self.k:
            fusion_logits = fusion_logits.unsqueeze(1)  # b x 1 x d
            fusion_logits = self.p_net(fusion_logits).squeeze(1) * self.k  # sum-pooling
        return fusion_logits

    def forward(self, v, q, softmax=False):
        v_num = v.size(1)
        q_num = q.size(1)
        v_ = self.v_net(v)  # b x v x (h_dim * k)
        q_ = self.q_net(q)  # b x q x (h_dim * k)
        if self.h_out <= self.c:
            att_maps = torch.einsum('xhyk,bvk,bqk->bhvq', (self.h_mat, v_, q_)) + self.h_bias
        else:
            # Use transposed copies so ``v_``/``q_`` stay 3-D for attention_pooling;
            # previously they were overwritten with 4-D tensors and this branch crashed.
            d_ = torch.matmul(v_.transpose(1, 2).unsqueeze(3), q_.transpose(1, 2).unsqueeze(2))  # b x hk x v x q
            att_maps = self.h_net(d_.transpose(1, 2).transpose(2, 3))  # b x v x q x h_out
            att_maps = att_maps.transpose(2, 3).transpose(1, 2)  # b x h_out x v x q
        if softmax:
            # reshape (not view): att_maps is non-contiguous in the h_net branch.
            p = nn.functional.softmax(att_maps.reshape(-1, self.h_out, v_num * q_num), 2)
            att_maps = p.view(-1, self.h_out, v_num, q_num)
        logits = self.attention_pooling(v_, q_, att_maps[:, 0, :, :])
        for i in range(1, self.h_out):
            logits_i = self.attention_pooling(v_, q_, att_maps[:, i, :, :])
            logits += logits_i
        logits = self.bn(logits)
        return logits, att_maps


# noinspection PyTypeChecker
class FCNet(nn.Module):
    """Stack of weight-normalised linear layers with optional dropout and activation.

    For ``dims = [d0, d1, ..., dn]`` each layer is ``[Dropout] -> Linear(d_i, d_{i+1}) -> [act]``.
    Dropout is added only when ``dropout > 0``. ``act`` names an ``torch.nn``
    activation class; an empty string or ``None`` disables it.

    Modified from https://github.com/jnhwkim/ban-vqa/blob/master/fc.py
    """

    def __init__(self, dims, act='ReLU', dropout=0.0):
        super().__init__()
        layers = []
        for in_dim, out_dim in zip(dims[:-2], dims[1:-1]):
            if 0 < dropout:
                layers.append(nn.Dropout(dropout))
            layers.append(weight_norm(nn.Linear(in_dim, out_dim), dim=None))
            if act:
                layers.append(getattr(nn, act)())
        if 0 < dropout:
            layers.append(nn.Dropout(dropout))
        layers.append(weight_norm(nn.Linear(dims[-2], dims[-1]), dim=None))
        if act:
            layers.append(getattr(nn, act)())
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        return self.main(x)


class BCNet(nn.Module):
    """Bilinear connect network producing (v, q) interaction logits.

    ``forward`` returns:

    * ``h_out is None``: per-channel outer products, b x v x q x (h_dim * k);
      no dropout is applied in this branch.
    * ``h_out <= 32``: low-rank bilinear logits, b x h_out x v x q.
    * otherwise: ``h_net`` applied to the batch outer product, b x h_out x v x q.

    ``forward_with_weights`` pools with a given weight map ``w`` (b x v x q),
    like ``BANLayer.attention_pooling``.

    Modified from https://github.com/jnhwkim/ban-vqa/blob/master/bc.py
    """

    # noinspection PyTypeChecker
    def __init__(self, v_dim, q_dim, h_dim, h_out, act='ReLU', dropout=(0.2, 0.5), k=3):
        super().__init__()
        self.c = 32
        self.k = k
        self.v_dim = v_dim
        self.q_dim = q_dim
        self.h_dim = h_dim
        self.h_out = h_out

        self.v_net = FCNet([v_dim, h_dim * self.k], act=act, dropout=dropout[0])
        self.q_net = FCNet([q_dim, h_dim * self.k], act=act, dropout=dropout[0])
        self.dropout = nn.Dropout(dropout[1])  # attention
        if 1 < k:
            self.p_net = nn.AvgPool1d(self.k, stride=self.k)

        if h_out is None:
            pass
        elif h_out <= self.c:
            self.h_mat = nn.Parameter(torch.Tensor(1, h_out, 1, h_dim * self.k).normal_())
            self.h_bias = nn.Parameter(torch.Tensor(1, h_out, 1, 1).normal_())
        else:
            self.h_net = weight_norm(nn.Linear(h_dim * self.k, h_out), dim=None)

    def forward(self, v, q):
        if self.h_out is None:
            v_ = self.v_net(v)
            q_ = self.q_net(q)
            logits = torch.einsum('bvk,bqk->bvqk', (v_, q_))
            return logits

        # low-rank bilinear pooling using einsum
        elif self.h_out <= self.c:
            v_ = self.dropout(self.v_net(v))
            q_ = self.q_net(q)
            logits = torch.einsum('xhyk,bvk,bqk->bhvq', (self.h_mat, v_, q_)) + self.h_bias
            return logits  # b x h_out x v x q

        # batch outer product, linear projection
        # memory efficient but slow computation
        else:
            v_ = self.dropout(self.v_net(v)).transpose(1, 2).unsqueeze(3)
            q_ = self.q_net(q).transpose(1, 2).unsqueeze(2)
            d_ = torch.matmul(v_, q_)  # b x h_dim x v x q
            logits = self.h_net(d_.transpose(1, 2).transpose(2, 3))  # b x v x q x h_out
            return logits.transpose(2, 3).transpose(1, 2)  # b x h_out x v x q

    def forward_with_weights(self, v, q, w):
        v_ = self.v_net(v)  # b x v x d
        q_ = self.q_net(q)  # b x q x d
        logits = torch.einsum('bvk,bvq,bqk->bk', (v_, w, q_))
        if 1 < self.k:
            logits = logits.unsqueeze(1)  # b x 1 x d
            logits = self.p_net(logits).squeeze(1) * self.k  # sum-pooling
        return logits


def drug_featurizer(smiles, max_drug_nodes=290):
    """Featurize a SMILES string into a DGL graph padded to ``max_drug_nodes`` nodes.

    Atoms are featurized with dgllife's ``CanonicalAtomFeaturizer`` and bonds with
    ``CanonicalBondFeaturizer(self_loop=True)``. One extra node-feature column is
    appended: 0 for real atoms, 1 for the virtual padding nodes. The padding nodes
    make every graph exactly ``max_drug_nodes`` nodes, because
    ``MolecularGCN.forward`` reshapes batched node features assuming equal-sized graphs.

    Args:
        smiles: SMILES string of the molecule.
        max_drug_nodes: Total node count after padding.

    Returns:
        The padded graph, or ``None`` in any of these cases (a warning is logged
        for the last two):

        * ``smiles_to_bigraph`` returns ``None``;
        * the molecule has more than ``max_drug_nodes`` atoms;
        * featurization raises.
    """
    from dgllife.utils import smiles_to_bigraph, CanonicalAtomFeaturizer, CanonicalBondFeaturizer
    from deepscreen.utils import get_logger
    log = get_logger(__name__)
    try:
        v_d = smiles_to_bigraph(smiles=smiles,
                                node_featurizer=CanonicalAtomFeaturizer(),
                                edge_featurizer=CanonicalBondFeaturizer(self_loop=True),
                                add_self_loop=True)
        if v_d is None:
            return None
        actual_node_feats = v_d.ndata.pop('h')
        num_actual_nodes, atom_feat_dim = actual_node_feats.shape
        num_virtual_nodes = max_drug_nodes - num_actual_nodes
        if num_virtual_nodes < 0:
            # Reject explicitly instead of relying on torch.zeros failing on a negative size.
            log.warning(f"Failed to featurize SMILES ({smiles}) to graph due to "
                        f"{num_actual_nodes} atoms exceeding max_drug_nodes={max_drug_nodes}")
            return None
        # Real atoms get a trailing 0 in the virtual-node indicator column.
        virtual_node_bit = torch.zeros([num_actual_nodes, 1])
        v_d.ndata['h'] = torch.cat((actual_node_feats, virtual_node_bit), 1)
        # Virtual nodes: all-zero atom features (same width as the real ones) plus indicator 1.
        virtual_node_feat = torch.cat(
            (torch.zeros(num_virtual_nodes, atom_feat_dim), torch.ones(num_virtual_nodes, 1)), 1
        )
        v_d.add_nodes(num_virtual_nodes, {"h": virtual_node_feat})
        # NOTE(review): smiles_to_bigraph(add_self_loop=True) already gave each real
        # atom a self-loop. This call appears to add another to every node (virtual
        # nodes get their only one here). Kept for compatibility with trained models.
        v_d = v_d.add_self_loop()
        return v_d
    except Exception as e:
        log.warning(f"Failed to featurize SMILES ({smiles}) to graph due to {str(e)}")
        return None