"""
MGraphDTA: Deep Multiscale Graph Neural Network for Explainable Drug-target binding affinity Prediction
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from rdkit import Chem
from torch.nn.modules.batchnorm import _BatchNorm
import torch_geometric.nn as gnn
from torch import Tensor
from collections import OrderedDict
from deepscreen.data.featurizers.categorical import one_of_k_encoding, one_of_k_encoding_unk
class MGraphDTA(nn.Module):
    """Drug-target affinity model: a multiscale CNN encodes the protein token sequence, a
    dense graph network encodes the ligand atom graph, and an MLP fuses both into a
    256-dimensional joint representation."""
def __init__(self, block_num, vocab_protein_size, embedding_size=128, filter_num=32):
super().__init__()
self.protein_encoder = TargetRepresentation(block_num, vocab_protein_size, embedding_size)
self.ligand_encoder = GraphDenseNet(num_input_features=87,
out_dim=filter_num * 3,
block_config=[8, 8, 8],
bn_sizes=[2, 2, 2])
self.classifier = nn.Sequential(
nn.Linear(filter_num * 3 * 2, 1024),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(1024, 1024),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(1024, 256),
nn.ReLU(),
nn.Dropout(0.1)
)
    def forward(self, emb_drug, emb_protein):
        # emb_drug: torch_geometric (batched) graph with 87-dim atom features in .x;
        # emb_protein: LongTensor of protein token indices, shape [batch, seq_len].
protein_x = self.protein_encoder(emb_protein)
ligand_x = self.ligand_encoder(emb_drug)
x = torch.cat([protein_x, ligand_x], dim=-1)
x = self.classifier(x)
return x
class Conv1dReLU(nn.Module):
    """1-D convolution followed by ReLU. Length-preserving settings at stride=1:
    kernel_size=3 -> padding=1, kernel_size=5 -> padding=2, kernel_size=7 -> padding=3.
    """
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
super().__init__()
self.inc = nn.Sequential(
nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
padding=padding),
nn.ReLU()
)
def forward(self, x):
return self.inc(x)
class LinearReLU(nn.Module):
def __init__(self, in_features, out_features, bias=True):
super().__init__()
self.inc = nn.Sequential(
nn.Linear(in_features=in_features, out_features=out_features, bias=bias),
nn.ReLU()
)
def forward(self, x):
return self.inc(x)
class StackCNN(nn.Module):
def __init__(self, layer_num, in_channels, out_channels, kernel_size, stride=1, padding=0):
super().__init__()
self.inc = nn.Sequential(OrderedDict([('conv_layer0',
Conv1dReLU(in_channels, out_channels, kernel_size=kernel_size,
stride=stride, padding=padding))]))
for layer_idx in range(layer_num - 1):
self.inc.add_module('conv_layer%d' % (layer_idx + 1),
Conv1dReLU(out_channels, out_channels, kernel_size=kernel_size, stride=stride,
padding=padding))
self.inc.add_module('pool_layer', nn.AdaptiveMaxPool1d(1))
def forward(self, x):
return self.inc(x).squeeze(-1)
class TargetRepresentation(nn.Module):
def __init__(self, block_num, vocab_size, embedding_num):
super().__init__()
self.embed = nn.Embedding(vocab_size, embedding_num, padding_idx=0)
self.block_list = nn.ModuleList()
for block_idx in range(block_num):
self.block_list.append(
StackCNN(block_idx + 1, embedding_num, 96, 3)
)
self.linear = nn.Linear(block_num * 96, 96)
def forward(self, x):
x = self.embed(x).permute(0, 2, 1)
feats = [block(x) for block in self.block_list]
x = torch.cat(feats, -1)
x = self.linear(x)
return x
class NodeLevelBatchNorm(_BatchNorm):
r"""
Applies Batch Normalization over a batch of graph data.
Shape:
- Input: [batch_nodes_dim, node_feature_dim]
- Output: [batch_nodes_dim, node_feature_dim]
batch_nodes_dim: all nodes of a batch graph
"""
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
track_running_stats=True):
super(NodeLevelBatchNorm, self).__init__(
num_features, eps, momentum, affine, track_running_stats)
    def _check_input_dim(self, input):
        if input.dim() != 2:
            raise ValueError(f'expected 2D input (got {input.dim()}D input)')
def forward(self, input):
self._check_input_dim(input)
if self.momentum is None:
exponential_average_factor = 0.0
else:
exponential_average_factor = self.momentum
        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # exponential moving average
                    exponential_average_factor = self.momentum
        return F.batch_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            self.training or not self.track_running_stats,
            exponential_average_factor, self.eps)
def extra_repr(self):
return 'num_features={num_features}, eps={eps}, ' \
'affine={affine}'.format(**self.__dict__)
class GraphConvBn(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = gnn.GraphConv(in_channels, out_channels)
self.norm = NodeLevelBatchNorm(out_channels)
    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        data.x = F.relu(self.norm(self.conv(x, edge_index)))
        return data
class DenseLayer(nn.Module):
def __init__(self, num_input_features, growth_rate=32, bn_size=4):
super().__init__()
self.conv1 = GraphConvBn(num_input_features, int(growth_rate * bn_size))
self.conv2 = GraphConvBn(int(growth_rate * bn_size), growth_rate)
    def bn_function(self, data):
        # DenseNet-style bottleneck: concatenate all previously produced feature maps
        # before the first graph convolution.
        data.x = torch.cat(data.x, 1)
        data = self.conv1(data)
        return data
def forward(self, data):
if isinstance(data.x, Tensor):
data.x = [data.x]
data = self.bn_function(data)
data = self.conv2(data)
return data
class DenseBlock(nn.ModuleDict):
def __init__(self, num_layers, num_input_features, growth_rate=32, bn_size=4):
super().__init__()
for i in range(num_layers):
layer = DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size)
self.add_module('layer%d' % (i + 1), layer)
    def forward(self, data):
        features = [data.x]
        for layer in self.values():
            data = layer(data)
            features.append(data.x)
            # Dense connectivity: the next layer sees the list of all feature maps so far.
            data.x = features
        data.x = torch.cat(data.x, 1)
        return data
class GraphDenseNet(nn.Module):
def __init__(self, num_input_features, out_dim, growth_rate=32, block_config=(3, 3, 3, 3), bn_sizes=(2, 3, 4, 4)):
super().__init__()
self.features = nn.Sequential(OrderedDict([('conv0', GraphConvBn(num_input_features, 32))]))
num_input_features = 32
for i, num_layers in enumerate(block_config):
block = DenseBlock(
num_layers, num_input_features, growth_rate=growth_rate, bn_size=bn_sizes[i]
)
self.features.add_module('block%d' % (i + 1), block)
num_input_features += int(num_layers * growth_rate)
trans = GraphConvBn(num_input_features, num_input_features // 2)
self.features.add_module("transition%d" % (i + 1), trans)
num_input_features = num_input_features // 2
self.classifier = nn.Linear(num_input_features, out_dim)
def forward(self, data):
data = self.features(data)
x = gnn.global_mean_pool(data.x, data.batch)
x = self.classifier(x)
return x
def atom_features(atom):
    """Return an 87-dimensional feature vector for an RDKit atom: element symbol, degree,
    total hydrogen count, implicit valence, hybridization, aromaticity, and chirality."""
encoding = one_of_k_encoding_unk(atom.GetSymbol(),
['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'As',
'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti',
'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr', 'Cr', 'Pt', 'Hg',
'Pb', 'Unknown'])
encoding += one_of_k_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + one_of_k_encoding_unk(
atom.GetTotalNumHs(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
encoding += one_of_k_encoding_unk(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
encoding += one_of_k_encoding_unk(atom.GetHybridization(), [
Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D,
Chem.rdchem.HybridizationType.SP3D2, 'other'])
encoding += [atom.GetIsAromatic()]
    try:
        encoding += one_of_k_encoding_unk(
            atom.GetProp('_CIPCode'),
            ['R', 'S']) + [atom.HasProp('_ChiralityPossible')]
    except KeyError:
        # No CIP code assigned to this atom; pad the R/S slots with zeros.
        encoding += [0, 0] + [atom.HasProp('_ChiralityPossible')]
return np.array(encoding)
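

# Minimal smoke-test sketch (illustrative only; the SMILES string, vocabulary size of 26
# tokens, and sequence length below are assumptions, not part of the original module).
# It featurizes one molecule with atom_features, wraps it in a torch_geometric batch, and
# runs a forward pass alongside a dummy protein token sequence.
if __name__ == "__main__":
    from torch_geometric.data import Batch, Data

    mol = Chem.MolFromSmiles("CCO")  # ethanol, used purely as an example input
    x = torch.tensor(np.array([atom_features(a) for a in mol.GetAtoms()]), dtype=torch.float)
    bonds = [(b.GetBeginAtomIdx(), b.GetEndAtomIdx()) for b in mol.GetBonds()]
    # Represent each undirected bond as two directed edges.
    edge_index = torch.tensor(bonds + [(j, i) for i, j in bonds], dtype=torch.long).t().contiguous()
    graph_batch = Batch.from_data_list([Data(x=x, edge_index=edge_index)])

    protein_tokens = torch.randint(1, 26, (1, 1000))  # assumed vocab: 25 residue tokens + padding 0

    model = MGraphDTA(block_num=3, vocab_protein_size=26).eval()
    with torch.no_grad():
        out = model(graph_batch, protein_tokens)
    print(out.shape)  # expected: torch.Size([1, 256])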