from collections import defaultdict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from rdkit import Chem
from rdkit.Chem import AllChem
class BACPI(nn.Module):
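    """Bi-directional attention model for compound-protein interaction.

    A multi-head GAT encodes the compound's atom graph, stacked same-padding
    convolutions encode the protein's n-gram embeddings, and a Morgan-fingerprint
    "super feature" is fused with the compound side before bi-directional
    attention combines the two views.
    """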
def __init__(
self,
n_atom,
n_amino,
comp_dim,
prot_dim,
gat_dim,
num_head,
dropout,
alpha,
window,
layer_cnn,
latent_dim,
):
super().__init__()
self.embedding_layer_atom = nn.Embedding(n_atom + 1, comp_dim)
self.embedding_layer_amino = nn.Embedding(n_amino + 1, prot_dim)
self.dropout = dropout
self.alpha = alpha
self.layer_cnn = layer_cnn
        # Registering the attention heads in a ModuleList tracks their parameters
        # without a manual add_module loop
        self.gat_layers = nn.ModuleList([
            GATLayer(comp_dim, gat_dim, dropout=dropout, alpha=alpha, concat=True)
            for _ in range(num_head)
        ])
self.gat_out = GATLayer(gat_dim * num_head, comp_dim, dropout=dropout, alpha=alpha, concat=False)
self.W_comp = nn.Linear(comp_dim, latent_dim)
self.conv_layers = nn.ModuleList([nn.Conv2d(in_channels=1, out_channels=1, kernel_size=2 * window + 1,
stride=1, padding=window) for _ in range(layer_cnn)])
self.W_prot = nn.Linear(prot_dim, latent_dim)
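        # Two-layer projection of the 1024-bit Morgan fingerprint into the latent space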
self.fp0 = nn.Parameter(torch.empty(size=(1024, latent_dim)))
nn.init.xavier_uniform_(self.fp0, gain=1.414)
self.fp1 = nn.Parameter(torch.empty(size=(latent_dim, latent_dim)))
nn.init.xavier_uniform_(self.fp1, gain=1.414)
self.bidat_num = 4
self.U = nn.ParameterList([
nn.Parameter(torch.empty(size=(latent_dim, latent_dim))) for _ in range(self.bidat_num)
])
for i in range(self.bidat_num):
nn.init.xavier_uniform_(self.U[i], gain=1.414)
self.transform_c2p = nn.ModuleList([nn.Linear(latent_dim, latent_dim) for _ in range(self.bidat_num)])
self.transform_p2c = nn.ModuleList([nn.Linear(latent_dim, latent_dim) for _ in range(self.bidat_num)])
self.bihidden_c = nn.ModuleList([nn.Linear(latent_dim, latent_dim) for _ in range(self.bidat_num)])
self.bihidden_p = nn.ModuleList([nn.Linear(latent_dim, latent_dim) for _ in range(self.bidat_num)])
self.biatt_c = nn.ModuleList([nn.Linear(latent_dim * 2, 1) for _ in range(self.bidat_num)])
self.biatt_p = nn.ModuleList([nn.Linear(latent_dim * 2, 1) for _ in range(self.bidat_num)])
self.comb_c = nn.Linear(latent_dim * self.bidat_num, latent_dim)
self.comb_p = nn.Linear(latent_dim * self.bidat_num, latent_dim)
def comp_gat(self, atoms, atoms_mask, adj):
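        """Encode the atom graph with multi-head GAT and project to the latent space."""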
atoms_vector = self.embedding_layer_atom(atoms)
atoms_multi_head = torch.cat([gat(atoms_vector, adj) for gat in self.gat_layers], dim=2)
atoms_vector = F.elu(self.gat_out(atoms_multi_head, adj))
atoms_vector = F.leaky_relu(self.W_comp(atoms_vector), self.alpha)
return atoms_vector
def prot_cnn(self, amino, amino_mask):
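        """Encode amino-acid embeddings with stacked same-padding convolutions."""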
amino_vector = self.embedding_layer_amino(amino)
amino_vector = torch.unsqueeze(amino_vector, 1)
for i in range(self.layer_cnn):
amino_vector = F.leaky_relu(self.conv_layers[i](amino_vector), self.alpha)
amino_vector = torch.squeeze(amino_vector, 1)
amino_vector = F.leaky_relu(self.W_prot(amino_vector), self.alpha)
return amino_vector
def mask_softmax(self, a, mask, dim=-1):
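        """Numerically stable softmax over `dim` that ignores positions where `mask` is 0."""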
a_max = torch.max(a, dim, keepdim=True)[0]
a_exp = torch.exp(a - a_max)
a_exp = a_exp * mask
a_softmax = a_exp / (torch.sum(a_exp, dim, keepdim=True) + 1e-6)
return a_softmax
def bidirectional_attention_prediction(self, atoms_vector, atoms_mask, fps, amino_vector, amino_mask):
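        """Run `bidat_num` rounds of compound<->protein cross-attention, pool each
        side with its attention weights, and return the flattened outer product of
        the combined compound feature (with fingerprint) and the protein feature."""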
b = atoms_vector.shape[0]
for i in range(self.bidat_num):
A = torch.tanh(torch.matmul(torch.matmul(atoms_vector, self.U[i]), amino_vector.transpose(1, 2)))
A = A * torch.matmul(atoms_mask.view(b, -1, 1).float(), amino_mask.view(b, 1, -1).float())
atoms_trans = torch.matmul(A, torch.tanh(self.transform_p2c[i](amino_vector)))
amino_trans = torch.matmul(A.transpose(1, 2), torch.tanh(self.transform_c2p[i](atoms_vector)))
atoms_tmp = torch.cat([torch.tanh(self.bihidden_c[i](atoms_vector)), atoms_trans], dim=2)
amino_tmp = torch.cat([torch.tanh(self.bihidden_p[i](amino_vector)), amino_trans], dim=2)
atoms_att = self.mask_softmax(self.biatt_c[i](atoms_tmp).view(b, -1), atoms_mask.view(b, -1).float())
amino_att = self.mask_softmax(self.biatt_p[i](amino_tmp).view(b, -1), amino_mask.view(b, -1).float())
cf = torch.sum(atoms_vector * atoms_att.view(b, -1, 1), dim=1)
pf = torch.sum(amino_vector * amino_att.view(b, -1, 1), dim=1)
if i == 0:
cat_cf = cf
cat_pf = pf
else:
cat_cf = torch.cat([cat_cf.view(b, -1), cf.view(b, -1)], dim=1)
cat_pf = torch.cat([cat_pf.view(b, -1), pf.view(b, -1)], dim=1)
cf_final = torch.cat([self.comb_c(cat_cf).view(b, -1), fps.view(b, -1)], dim=1)
pf_final = self.comb_p(cat_pf)
cf_pf = F.leaky_relu(
torch.matmul(
cf_final.view(b, -1, 1), pf_final.view(b, 1, -1)
).view(b, -1), 0.1
)
return cf_pf
def forward(self, compound, protein):
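        """compound: ((atom_ids, lengths), (adjacency, _), (fingerprint, _));
        protein: (amino_acid_ids, lengths). All tensors are padded batches."""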
atom, adj, fp = compound
atom, atom_lengths = atom
adj, _ = adj
fp, _ = fp
amino, amino_lengths = protein
        # Masks are True at valid (non-padding) positions: they are used
        # multiplicatively downstream (1 keeps a position, 0 masks it out),
        # so the comparison must be `<` length, not `>=`
        atom_mask = torch.arange(atom.size(1), device=atom.device) < atom_lengths.unsqueeze(1)
        amino_mask = torch.arange(amino.size(1), device=amino.device) < amino_lengths.unsqueeze(1)
atoms_vector = self.comp_gat(atom, atom_mask, adj)
amino_vector = self.prot_cnn(amino, amino_mask)
super_feature = F.leaky_relu(torch.matmul(fp.float(), self.fp0), 0.1)
super_feature = F.leaky_relu(torch.matmul(super_feature, self.fp1), 0.1)
prediction = self.bidirectional_attention_prediction(
atoms_vector, atom_mask, super_feature, amino_vector, amino_mask)
return prediction
class GATLayer(nn.Module):
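    """Dense (adjacency-matrix) graph attention layer.

    Attention logits come from a shared linear map over concatenated node-pair
    features; entries without an edge in `adj` are set to a large negative
    value before the softmax.
    """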
def __init__(self, in_features, out_features, dropout=0.5, alpha=0.2, concat=True):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.dropout = dropout
self.alpha = alpha
self.concat = concat
self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
nn.init.xavier_uniform_(self.W.data, gain=1.414)
self.a = nn.Parameter(torch.empty(size=(2 * out_features, 1)))
nn.init.xavier_uniform_(self.a.data, gain=1.414)
def forward(self, h, adj):
Wh = torch.matmul(h, self.W)
a_input = self._prepare_attentional_mechanism_input(Wh)
e = F.leaky_relu(torch.matmul(a_input, self.a).squeeze(3), self.alpha)
zero_vec = -9e15 * torch.ones_like(e)
attention = torch.where(adj > 0, e, zero_vec)
attention = F.softmax(attention, dim=2)
# attention = F.dropout(attention, self.dropout, training=self.training)
h_prime = torch.bmm(attention, Wh)
return F.elu(h_prime) if self.concat else h_prime
def _prepare_attentional_mechanism_input(self, Wh):
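        """Build all [Wh_i || Wh_j] pairs as a (b, N, N, 2 * out_features) tensor (O(N^2) memory)."""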
b = Wh.size()[0]
N = Wh.size()[1]
Wh_repeated_in_chunks = Wh.repeat_interleave(N, dim=1)
Wh_repeated_alternating = Wh.repeat_interleave(N, dim=0).view(b, N * N, self.out_features)
all_combinations_matrix = torch.cat([Wh_repeated_in_chunks, Wh_repeated_alternating], dim=2)
return all_combinations_matrix.view(b, N, N, 2 * self.out_features)
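# Module-level vocabularies that grow on first access: each unseen atom symbol,
# bond type, substructure fingerprint, edge label, or sequence n-gram is
# assigned the next integer id.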
atom_dict = defaultdict(lambda: len(atom_dict))
bond_dict = defaultdict(lambda: len(bond_dict))
fingerprint_dict = defaultdict(lambda: len(fingerprint_dict))
edge_dict = defaultdict(lambda: len(edge_dict))
word_dict = defaultdict(lambda: len(word_dict))
def create_atoms(mol):
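    """Map each atom symbol to an integer id, tagging aromatic atoms separately."""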
atoms = [a.GetSymbol() for a in mol.GetAtoms()]
for a in mol.GetAromaticAtoms():
i = a.GetIdx()
atoms[i] = (atoms[i], 'aromatic')
atoms = [atom_dict[a] for a in atoms]
return np.array(atoms)
def create_ijbonddict(mol):
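    """Map each atom index to its list of (neighbour index, bond-type id) pairs."""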
i_jbond_dict = defaultdict(lambda: [])
for b in mol.GetBonds():
i, j = b.GetBeginAtomIdx(), b.GetEndAtomIdx()
bond = bond_dict[str(b.GetBondType())]
i_jbond_dict[i].append((j, bond))
i_jbond_dict[j].append((i, bond))
    # Atoms with no bonds get a self-loop with a dummy 'nan' bond type
    atoms_set = set(range(mol.GetNumAtoms()))
    isolated_atoms = atoms_set - set(i_jbond_dict.keys())
    bond = bond_dict['nan']
    for a in isolated_atoms:
        i_jbond_dict[a].append((a, bond))
return i_jbond_dict
def atom_features(atoms, i_jbond_dict, radius):
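    """Iteratively relabel atoms by their neighbourhoods (Weisfeiler-Lehman style)
    for `radius` rounds and return the resulting substructure ids."""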
if (len(atoms) == 1) or (radius == 0):
fingerprints = [fingerprint_dict[a] for a in atoms]
else:
nodes = atoms
i_jedge_dict = i_jbond_dict
for _ in range(radius):
fingerprints = []
for i, j_edge in i_jedge_dict.items():
neighbors = [(nodes[j], edge) for j, edge in j_edge]
fingerprint = (nodes[i], tuple(sorted(neighbors)))
fingerprints.append(fingerprint_dict[fingerprint])
nodes = fingerprints
_i_jedge_dict = defaultdict(lambda: [])
for i, j_edge in i_jedge_dict.items():
for j, edge in j_edge:
both_side = tuple(sorted((nodes[i], nodes[j])))
edge = edge_dict[(both_side, edge)]
_i_jedge_dict[i].append((j, edge))
i_jedge_dict = _i_jedge_dict
return np.array(fingerprints)
def create_adjacency(mol):
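    """Return the molecule's adjacency matrix with self-loops on the diagonal."""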
    adjacency = np.array(Chem.GetAdjacencyMatrix(mol))
adjacency += np.eye(adjacency.shape[0], dtype=int)
return adjacency
def get_fingerprints(mol):
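    """Return the 1024-bit Morgan fingerprint (radius 2, chirality-aware) as an array."""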
fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=1024, useChirality=True)
return np.array(fp)
def split_sequence(sequence, ngram=3):
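    """Encode a protein sequence as overlapping n-gram ids; '-' and '=' mark the termini."""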
sequence = '-' + sequence + '='
words = [word_dict[sequence[i:i + ngram]]
for i in range(len(sequence) - ngram + 1)]
return np.array(words)
def drug_featurizer(smiles, radius=2):
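    """Convert a SMILES string into (substructure ids, adjacency, Morgan fingerprint);
    returns None when parsing or featurization fails."""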
from deepscreen.utils import get_logger
log = get_logger(__name__)
try:
mol = Chem.MolFromSmiles(smiles)
if not mol:
return None
mol = Chem.AddHs(mol)
atoms = create_atoms(mol)
i_jbond_dict = create_ijbonddict(mol)
compound = atom_features(atoms, i_jbond_dict, radius)
adjacency = create_adjacency(mol)
fp = get_fingerprints(mol)
return compound, adjacency, fp
except Exception as e:
log.warning(f"Failed to featurize SMILES ({smiles}) to graph due to {str(e)}")
return None
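

if __name__ == '__main__':
    # Minimal smoke-test sketch (illustrative only, not part of the original
    # module). The aspirin SMILES, the dummy protein sequence, and the BACPI
    # hyperparameters below are arbitrary placeholder choices.
    mol = Chem.AddHs(Chem.MolFromSmiles('CC(=O)Oc1ccccc1C(=O)O'))  # aspirin
    atoms = create_atoms(mol)
    i_jbond_dict = create_ijbonddict(mol)
    substructures = atom_features(atoms, i_jbond_dict, radius=2)
    adjacency = create_adjacency(mol)
    fingerprint = get_fingerprints(mol)
    words = split_sequence('MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ')
    print(substructures.shape, adjacency.shape, fingerprint.shape, words.shape)

    # Dummy forward pass with random ids, full adjacency, and padded lengths.
    model = BACPI(n_atom=100, n_amino=8000, comp_dim=32, prot_dim=32, gat_dim=16,
                  num_head=4, dropout=0.1, alpha=0.2, window=3, layer_cnn=2,
                  latent_dim=32)
    b, n, l = 2, 20, 50
    compound_batch = ((torch.randint(0, 100, (b, n)), torch.tensor([20, 15])),
                      (torch.ones(b, n, n), None),
                      (torch.randint(0, 2, (b, 1024)), None))
    protein_batch = (torch.randint(0, 8000, (b, l)), torch.tensor([50, 40]))
    print(model(compound_batch, protein_batch).shape)  # (b, 2 * latent_dim ** 2)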