""" MGraphDTA: Deep Multiscale Graph Neural Network for Explainable Drug-target binding affinity Prediction """ import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from rdkit import Chem from torch.nn.modules.batchnorm import _BatchNorm import torch_geometric.nn as gnn from torch import Tensor from collections import OrderedDict from deepscreen.data.featurizers.categorical import one_of_k_encoding, one_of_k_encoding_unk class MGraphDTA(nn.Module): def __init__(self, block_num, vocab_protein_size, embedding_size=128, filter_num=32): super().__init__() self.protein_encoder = TargetRepresentation(block_num, vocab_protein_size, embedding_size) self.ligand_encoder = GraphDenseNet(num_input_features=87, out_dim=filter_num * 3, block_config=[8, 8, 8], bn_sizes=[2, 2, 2]) self.classifier = nn.Sequential( nn.Linear(filter_num * 3 * 2, 1024), nn.ReLU(), nn.Dropout(0.1), nn.Linear(1024, 1024), nn.ReLU(), nn.Dropout(0.1), nn.Linear(1024, 256), nn.ReLU(), nn.Dropout(0.1) ) def forward(self, emb_drug, emb_protein): protein_x = self.protein_encoder(emb_protein) ligand_x = self.ligand_encoder(emb_drug) x = torch.cat([protein_x, ligand_x], dim=-1) x = self.classifier(x) return x class Conv1dReLU(nn.Module): """ kernel_size=3, stride=1, padding=1 kernel_size=5, stride=1, padding=2 kernel_size=7, stride=1, padding=3 """ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0): super().__init__() self.inc = nn.Sequential( nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding), nn.ReLU() ) def forward(self, x): return self.inc(x) class LinearReLU(nn.Module): def __init__(self, in_features, out_features, bias=True): super().__init__() self.inc = nn.Sequential( nn.Linear(in_features=in_features, out_features=out_features, bias=bias), nn.ReLU() ) def forward(self, x): return self.inc(x) class StackCNN(nn.Module): def __init__(self, layer_num, in_channels, out_channels, kernel_size, stride=1, padding=0): super().__init__() self.inc = nn.Sequential(OrderedDict([('conv_layer0', Conv1dReLU(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding))])) for layer_idx in range(layer_num - 1): self.inc.add_module('conv_layer%d' % (layer_idx + 1), Conv1dReLU(out_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)) self.inc.add_module('pool_layer', nn.AdaptiveMaxPool1d(1)) def forward(self, x): return self.inc(x).squeeze(-1) class TargetRepresentation(nn.Module): def __init__(self, block_num, vocab_size, embedding_num): super().__init__() self.embed = nn.Embedding(vocab_size, embedding_num, padding_idx=0) self.block_list = nn.ModuleList() for block_idx in range(block_num): self.block_list.append( StackCNN(block_idx + 1, embedding_num, 96, 3) ) self.linear = nn.Linear(block_num * 96, 96) def forward(self, x): x = self.embed(x).permute(0, 2, 1) feats = [block(x) for block in self.block_list] x = torch.cat(feats, -1) x = self.linear(x) return x class NodeLevelBatchNorm(_BatchNorm): r""" Applies Batch Normalization over a batch of graph data. 
class NodeLevelBatchNorm(_BatchNorm):
    r"""
    Applies Batch Normalization over a batch of graph data.

    Shape:
        - Input: [batch_nodes_dim, node_feature_dim]
        - Output: [batch_nodes_dim, node_feature_dim]

    batch_nodes_dim: all nodes of a batch graph
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(NodeLevelBatchNorm, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)

    def _check_input_dim(self, input):
        if input.dim() != 2:
            raise ValueError('expected 2D input (got {}D input)'.format(input.dim()))

    def forward(self, input):
        self._check_input_dim(input)
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum
        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked = self.num_batches_tracked + 1
                if self.momentum is None:
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:
                    exponential_average_factor = self.momentum

        return F.batch_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            self.training or not self.track_running_stats,
            exponential_average_factor, self.eps)

    def extra_repr(self):
        return 'num_features={num_features}, eps={eps}, ' \
               'affine={affine}'.format(**self.__dict__)


class GraphConvBn(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = gnn.GraphConv(in_channels, out_channels)
        self.norm = NodeLevelBatchNorm(out_channels)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        data.x = F.relu(self.norm(self.conv(x, edge_index)))

        return data


class DenseLayer(nn.Module):
    def __init__(self, num_input_features, growth_rate=32, bn_size=4):
        super().__init__()
        self.conv1 = GraphConvBn(num_input_features, int(growth_rate * bn_size))
        self.conv2 = GraphConvBn(int(growth_rate * bn_size), growth_rate)

    def bn_function(self, data):
        # DenseNet-style bottleneck: concatenate all incoming feature maps, then reduce
        concated_features = torch.cat(data.x, 1)
        data.x = concated_features
        data = self.conv1(data)

        return data

    def forward(self, data):
        if isinstance(data.x, Tensor):
            data.x = [data.x]
        data = self.bn_function(data)
        data = self.conv2(data)

        return data


class DenseBlock(nn.ModuleDict):
    def __init__(self, num_layers, num_input_features, growth_rate=32, bn_size=4):
        super().__init__()
        for i in range(num_layers):
            layer = DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size)
            self.add_module('layer%d' % (i + 1), layer)

    def forward(self, data):
        features = [data.x]
        for name, layer in self.items():
            data = layer(data)
            features.append(data.x)
            # Dense connectivity: the next layer sees all previously produced feature maps
            data.x = features

        data.x = torch.cat(data.x, 1)

        return data


class GraphDenseNet(nn.Module):
    def __init__(self, num_input_features, out_dim, growth_rate=32,
                 block_config=(3, 3, 3, 3), bn_sizes=(2, 3, 4, 4)):
        super().__init__()
        self.features = nn.Sequential(OrderedDict([
            ('conv0', GraphConvBn(num_input_features, 32))
        ]))
        num_input_features = 32

        for i, num_layers in enumerate(block_config):
            block = DenseBlock(
                num_layers, num_input_features, growth_rate=growth_rate, bn_size=bn_sizes[i]
            )
            self.features.add_module('block%d' % (i + 1), block)
            num_input_features += int(num_layers * growth_rate)

            # Transition layer halves the channel count between dense blocks
            trans = GraphConvBn(num_input_features, num_input_features // 2)
            self.features.add_module("transition%d" % (i + 1), trans)
            num_input_features = num_input_features // 2

        self.classifier = nn.Linear(num_input_features, out_dim)

    def forward(self, data):
        data = self.features(data)
        x = gnn.global_mean_pool(data.x, data.batch)
        x = self.classifier(x)

        return x
def atom_features(atom):
    # atom symbol one-hot with an 'Unknown' fallback (44)
    encoding = one_of_k_encoding_unk(atom.GetSymbol(),
                                     ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
                                      'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn', 'Ag',
                                      'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni',
                                      'Cd', 'In', 'Mn', 'Zr', 'Cr', 'Pt', 'Hg', 'Pb', 'Unknown'])
    # degree (11) + total number of hydrogens (11)
    encoding += one_of_k_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + \
        one_of_k_encoding_unk(atom.GetTotalNumHs(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    # implicit valence (11)
    encoding += one_of_k_encoding_unk(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    # hybridization (6)
    encoding += one_of_k_encoding_unk(atom.GetHybridization(), [
        Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D,
        Chem.rdchem.HybridizationType.SP3D2, 'other'])
    # aromaticity (1)
    encoding += [atom.GetIsAromatic()]

    # chirality: CIP code (2) + chirality-possible flag (1);
    # GetProp raises KeyError when '_CIPCode' has not been assigned
    try:
        encoding += one_of_k_encoding_unk(atom.GetProp('_CIPCode'), ['R', 'S']) + \
            [atom.HasProp('_ChiralityPossible')]
    except KeyError:
        encoding += [0, 0] + [atom.HasProp('_ChiralityPossible')]

    # total: 44 + 11 + 11 + 11 + 6 + 1 + 2 + 1 = 87 features per atom
    return np.array(encoding)
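

if __name__ == '__main__':
    # Minimal smoke test (illustrative sketch only): featurize a toy molecule with
    # atom_features, wrap it in a torch_geometric Batch, pair it with a dummy protein
    # token sequence, and run one forward pass. The SMILES string, vocabulary size (26),
    # sequence length (200) and block_num (3) are placeholder assumptions, not the
    # settings used for training in the original MGraphDTA work.
    from torch_geometric.data import Batch, Data

    mol = Chem.MolFromSmiles('CCO')
    x = torch.tensor(np.array([atom_features(atom) for atom in mol.GetAtoms()]),
                     dtype=torch.float)  # (num_atoms, 87)
    edges = []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        edges += [[i, j], [j, i]]  # encode each bond as two directed edges
    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    drug_batch = Batch.from_data_list([Data(x=x, edge_index=edge_index)])

    protein_tokens = torch.randint(1, 26, (1, 200))  # (batch, seq_len); 0 is the padding index

    model = MGraphDTA(block_num=3, vocab_protein_size=26)
    out = model(drug_batch, protein_tokens)
    print(out.shape)  # expected: torch.Size([1, 256])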