from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from rdkit import Chem
from rdkit.Chem import AllChem


class BACPI(nn.Module):
    """Bi-directional attention model for compound-protein interaction prediction."""

    def __init__(
        self,
        n_atom,
        n_amino,
        comp_dim,
        prot_dim,
        gat_dim,
        num_head,
        dropout,
        alpha,
        window,
        layer_cnn,
        latent_dim,
    ):
        super().__init__()
        # Embeddings for atom (substructure fingerprint) IDs and amino-acid n-gram IDs.
        self.embedding_layer_atom = nn.Embedding(n_atom + 1, comp_dim)
        self.embedding_layer_amino = nn.Embedding(n_amino + 1, prot_dim)
        self.dropout = dropout
        self.alpha = alpha
        self.layer_cnn = layer_cnn

        # Multi-head graph attention over the molecular graph, followed by an
        # output GAT layer that merges the heads back to comp_dim.
        self.gat_layers = [GATLayer(comp_dim, gat_dim, dropout=dropout, alpha=alpha, concat=True)
                           for _ in range(num_head)]
        for i, layer in enumerate(self.gat_layers):
            self.add_module('gat_layer_{}'.format(i), layer)
        self.gat_out = GATLayer(gat_dim * num_head, comp_dim, dropout=dropout, alpha=alpha, concat=False)
        self.W_comp = nn.Linear(comp_dim, latent_dim)

        # Single-channel convolutions over the protein n-gram embeddings.
        self.conv_layers = nn.ModuleList([nn.Conv2d(in_channels=1, out_channels=1, kernel_size=2 * window + 1,
                                                    stride=1, padding=window) for _ in range(layer_cnn)])
        self.W_prot = nn.Linear(prot_dim, latent_dim)

        # Projection of the 1024-bit Morgan fingerprint into the latent space.
        self.fp0 = nn.Parameter(torch.empty(size=(1024, latent_dim)))
        nn.init.xavier_uniform_(self.fp0, gain=1.414)
        self.fp1 = nn.Parameter(torch.empty(size=(latent_dim, latent_dim)))
        nn.init.xavier_uniform_(self.fp1, gain=1.414)

        # Bi-directional attention blocks between atom and residue representations.
        self.bidat_num = 4
        self.U = nn.ParameterList([
            nn.Parameter(torch.empty(size=(latent_dim, latent_dim))) for _ in range(self.bidat_num)
        ])
        for i in range(self.bidat_num):
            nn.init.xavier_uniform_(self.U[i], gain=1.414)

        self.transform_c2p = nn.ModuleList([nn.Linear(latent_dim, latent_dim) for _ in range(self.bidat_num)])
        self.transform_p2c = nn.ModuleList([nn.Linear(latent_dim, latent_dim) for _ in range(self.bidat_num)])
        self.bihidden_c = nn.ModuleList([nn.Linear(latent_dim, latent_dim) for _ in range(self.bidat_num)])
        self.bihidden_p = nn.ModuleList([nn.Linear(latent_dim, latent_dim) for _ in range(self.bidat_num)])
        self.biatt_c = nn.ModuleList([nn.Linear(latent_dim * 2, 1) for _ in range(self.bidat_num)])
        self.biatt_p = nn.ModuleList([nn.Linear(latent_dim * 2, 1) for _ in range(self.bidat_num)])

        self.comb_c = nn.Linear(latent_dim * self.bidat_num, latent_dim)
        self.comb_p = nn.Linear(latent_dim * self.bidat_num, latent_dim)

    def comp_gat(self, atoms, atoms_mask, adj):
        # Multi-head GAT over the molecular graph -> per-atom latent vectors.
        atoms_vector = self.embedding_layer_atom(atoms)
        atoms_multi_head = torch.cat([gat(atoms_vector, adj) for gat in self.gat_layers], dim=2)
        atoms_vector = F.elu(self.gat_out(atoms_multi_head, adj))
        atoms_vector = F.leaky_relu(self.W_comp(atoms_vector), self.alpha)
        return atoms_vector

    def prot_cnn(self, amino, amino_mask):
        # Stacked convolutions over the protein n-gram embeddings -> per-residue latent vectors.
        amino_vector = self.embedding_layer_amino(amino)
        amino_vector = torch.unsqueeze(amino_vector, 1)
        for i in range(self.layer_cnn):
            amino_vector = F.leaky_relu(self.conv_layers[i](amino_vector), self.alpha)
        amino_vector = torch.squeeze(amino_vector, 1)
        amino_vector = F.leaky_relu(self.W_prot(amino_vector), self.alpha)
        return amino_vector

    def mask_softmax(self, a, mask, dim=-1):
        # Numerically stable softmax that ignores padded positions (mask == 0).
        a_max = torch.max(a, dim, keepdim=True)[0]
        a_exp = torch.exp(a - a_max)
        a_exp = a_exp * mask
        a_softmax = a_exp / (torch.sum(a_exp, dim, keepdim=True) + 1e-6)
        return a_softmax

    def bidirectional_attention_prediction(self, atoms_vector, atoms_mask, fps, amino_vector, amino_mask):
        b = atoms_vector.shape[0]
        for i in range(self.bidat_num):
            # Pairwise atom-residue affinity matrix, restricted to valid positions.
            A = torch.tanh(torch.matmul(torch.matmul(atoms_vector, self.U[i]), amino_vector.transpose(1, 2)))
            A = A * torch.matmul(atoms_mask.view(b, -1, 1).float(), amino_mask.view(b, 1, -1).float())

            # Cross-attended representations in both directions.
            atoms_trans = torch.matmul(A, torch.tanh(self.transform_p2c[i](amino_vector)))
            amino_trans = torch.matmul(A.transpose(1, 2), torch.tanh(self.transform_c2p[i](atoms_vector)))

            atoms_tmp = torch.cat([torch.tanh(self.bihidden_c[i](atoms_vector)), atoms_trans], dim=2)
            amino_tmp = torch.cat([torch.tanh(self.bihidden_p[i](amino_vector)), amino_trans], dim=2)

            # Attention-pooled compound (cf) and protein (pf) features for this block.
            atoms_att = self.mask_softmax(self.biatt_c[i](atoms_tmp).view(b, -1), atoms_mask.view(b, -1).float())
            amino_att = self.mask_softmax(self.biatt_p[i](amino_tmp).view(b, -1), amino_mask.view(b, -1).float())
            cf = torch.sum(atoms_vector * atoms_att.view(b, -1, 1), dim=1)
            pf = torch.sum(amino_vector * amino_att.view(b, -1, 1), dim=1)

            if i == 0:
                cat_cf = cf
                cat_pf = pf
            else:
                cat_cf = torch.cat([cat_cf.view(b, -1), cf.view(b, -1)], dim=1)
                cat_pf = torch.cat([cat_pf.view(b, -1), pf.view(b, -1)], dim=1)

        # Combine the per-block features, append the fingerprint feature, and take the
        # flattened outer product of the compound and protein representations.
        cf_final = torch.cat([self.comb_c(cat_cf).view(b, -1), fps.view(b, -1)], dim=1)
        pf_final = self.comb_p(cat_pf)
        cf_pf = F.leaky_relu(
            torch.matmul(
                cf_final.view(b, -1, 1), pf_final.view(b, 1, -1)
            ).view(b, -1), 0.1
        )
        return cf_pf

    def forward(self, compound, protein):
        # `compound` is ((atom_ids, atom_lengths), (adjacency, _), (fingerprint, _));
        # `protein` is (amino_ngram_ids, amino_lengths).
        atom, adj, fp = compound
        atom, atom_lengths = atom
        adj, _ = adj
        fp, _ = fp
        amino, amino_lengths = protein
        # Masks are 1 for real positions and 0 for padding (hence the `<` comparison),
        # since downstream attention scores are multiplied by the mask.
        atom_mask = torch.arange(atom.size(1), device=atom.device) < atom_lengths.unsqueeze(1)
        amino_mask = torch.arange(amino.size(1), device=amino.device) < amino_lengths.unsqueeze(1)
        atoms_vector = self.comp_gat(atom, atom_mask, adj)
        amino_vector = self.prot_cnn(amino, amino_mask)
        super_feature = F.leaky_relu(torch.matmul(fp.float(), self.fp0), 0.1)
        super_feature = F.leaky_relu(torch.matmul(super_feature, self.fp1), 0.1)
        prediction = self.bidirectional_attention_prediction(
            atoms_vector, atom_mask, super_feature, amino_vector, amino_mask)
        return prediction


class GATLayer(nn.Module):
    """Single graph attention layer operating on batched dense adjacency matrices."""

    def __init__(self, in_features, out_features, dropout=0.5, alpha=0.2, concat=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.dropout = dropout
        self.alpha = alpha
        self.concat = concat
        self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        self.a = nn.Parameter(torch.empty(size=(2 * out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

    def forward(self, h, adj):
        Wh = torch.matmul(h, self.W)
        a_input = self._prepare_attentional_mechanism_input(Wh)
        e = F.leaky_relu(torch.matmul(a_input, self.a).squeeze(3), self.alpha)
        # Mask out non-edges with a large negative value before the softmax.
        zero_vec = -9e15 * torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)
        attention = F.softmax(attention, dim=2)
        # attention = F.dropout(attention, self.dropout, training=self.training)
        h_prime = torch.bmm(attention, Wh)
        return F.elu(h_prime) if self.concat else h_prime

    def _prepare_attentional_mechanism_input(self, Wh):
        # Build all (node_i, node_j) feature pairs: output shape (batch, N, N, 2 * out_features).
        b = Wh.size()[0]
        N = Wh.size()[1]
        Wh_repeated_in_chunks = Wh.repeat_interleave(N, dim=1)
        Wh_repeated_alternating = Wh.repeat_interleave(N, dim=0).view(b, N * N, self.out_features)
        all_combinations_matrix = torch.cat([Wh_repeated_in_chunks, Wh_repeated_alternating], dim=2)
        return all_combinations_matrix.view(b, N, N, 2 * self.out_features)
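

# Illustrative shape check for GATLayer (a sketch, not part of the model; the sizes
# below are arbitrary placeholders). With concat=True the layer returns ELU-activated
# features of shape (batch, n_atoms, out_features); BACPI concatenates such outputs
# across heads before passing them to gat_out.
def _gat_layer_shape_example():
    layer = GATLayer(in_features=34, out_features=50, dropout=0.1, alpha=0.1, concat=True)
    h = torch.randn(2, 30, 34)   # (batch, n_atoms, in_features)
    adj = torch.ones(2, 30, 30)  # dense adjacency with self-loops
    out = layer(h, adj)          # -> (2, 30, 50)
    return out.shape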


# Global vocabularies that grow as new atoms, bonds, substructure fingerprints,
# edges and protein n-grams are encountered during featurization.
atom_dict = defaultdict(lambda: len(atom_dict))
bond_dict = defaultdict(lambda: len(bond_dict))
fingerprint_dict = defaultdict(lambda: len(fingerprint_dict))
edge_dict = defaultdict(lambda: len(edge_dict))
word_dict = defaultdict(lambda: len(word_dict))


def create_atoms(mol):
    """Map each atom (with an aromaticity tag) to an integer ID from atom_dict."""
    atoms = [a.GetSymbol() for a in mol.GetAtoms()]
    for a in mol.GetAromaticAtoms():
        i = a.GetIdx()
        atoms[i] = (atoms[i], 'aromatic')
    atoms = [atom_dict[a] for a in atoms]
    return np.array(atoms)


def create_ijbonddict(mol):
    """Build a neighbor dictionary: atom index -> list of (neighbor index, bond ID)."""
    i_jbond_dict = defaultdict(lambda: [])
    for b in mol.GetBonds():
        i, j = b.GetBeginAtomIdx(), b.GetEndAtomIdx()
        bond = bond_dict[str(b.GetBondType())]
        i_jbond_dict[i].append((j, bond))
        i_jbond_dict[j].append((i, bond))
    # Give isolated atoms a self-loop with a dummy 'nan' bond so every atom has neighbors.
    atoms_set = set(range(mol.GetNumAtoms()))
    isolate_atoms = atoms_set - set(i_jbond_dict.keys())
    bond = bond_dict['nan']
    for a in isolate_atoms:
        i_jbond_dict[a].append((a, bond))
    return i_jbond_dict


def atom_features(atoms, i_jbond_dict, radius):
    """Assign Weisfeiler-Lehman-style substructure fingerprint IDs to each atom."""
    if (len(atoms) == 1) or (radius == 0):
        fingerprints = [fingerprint_dict[a] for a in atoms]
    else:
        nodes = atoms
        i_jedge_dict = i_jbond_dict
        for _ in range(radius):
            # Update each node ID from its own ID plus the sorted (neighbor ID, edge ID) multiset.
            fingerprints = []
            for i, j_edge in i_jedge_dict.items():
                neighbors = [(nodes[j], edge) for j, edge in j_edge]
                fingerprint = (nodes[i], tuple(sorted(neighbors)))
                fingerprints.append(fingerprint_dict[fingerprint])
            nodes = fingerprints

            # Update each edge ID from the IDs of its two endpoints plus its previous edge ID.
            _i_jedge_dict = defaultdict(lambda: [])
            for i, j_edge in i_jedge_dict.items():
                for j, edge in j_edge:
                    both_side = tuple(sorted((nodes[i], nodes[j])))
                    edge = edge_dict[(both_side, edge)]
                    _i_jedge_dict[i].append((j, edge))
            i_jedge_dict = _i_jedge_dict
    return np.array(fingerprints)


def create_adjacency(mol):
    """Dense adjacency matrix with self-loops added on the diagonal."""
    adjacency = Chem.GetAdjacencyMatrix(mol)
    adjacency = np.array(adjacency)
    adjacency += np.eye(adjacency.shape[0], dtype=int)
    return adjacency


def get_fingerprints(mol):
    """1024-bit Morgan fingerprint (radius 2, with chirality)."""
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=1024, useChirality=True)
    return np.array(fp)


def split_sequence(sequence, ngram=3):
    """Split a protein sequence (with '-' and '=' terminal markers) into overlapping n-gram IDs."""
    sequence = '-' + sequence + '='
    words = [word_dict[sequence[i:i + ngram]]
             for i in range(len(sequence) - ngram + 1)]
    return np.array(words)
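

# A minimal collate sketch (an assumption about the data pipeline, not the project's
# actual collate function): it pads variable-length 1-D index arrays into a LongTensor
# plus a lengths tensor, matching the (tensor, lengths) pairs BACPI.forward expects
# for `atom` and `amino`. The name `pad_sequences_sketch` is hypothetical.
def pad_sequences_sketch(seqs, pad_value=0):
    lengths = torch.LongTensor([len(s) for s in seqs])
    batch = torch.full((len(seqs), int(lengths.max())), pad_value, dtype=torch.long)
    for i, s in enumerate(seqs):
        batch[i, :len(s)] = torch.as_tensor(s, dtype=torch.long)
    return batch, lengths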


def drug_featurizer(smiles, radius=2):
    """Featurize a SMILES string into (substructure IDs, adjacency matrix, Morgan fingerprint)."""
    from deepscreen.utils import get_logger
    log = get_logger(__name__)
    try:
        mol = Chem.MolFromSmiles(smiles)
        if not mol:
            return None
        mol = Chem.AddHs(mol)
        atoms = create_atoms(mol)
        i_jbond_dict = create_ijbonddict(mol)
        compound = atom_features(atoms, i_jbond_dict, radius)
        adjacency = create_adjacency(mol)
        fp = get_fingerprints(mol)
        return compound, adjacency, fp
    except Exception as e:
        log.warning(f"Failed to featurize SMILES ({smiles}) to graph due to {str(e)}")
        return None
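

# Minimal end-to-end sketch (illustrative only). The SMILES, protein sequence and all
# hyperparameters below are placeholders chosen for this example, not values used by
# the project; batch size is 1, so no padding is needed.
if __name__ == "__main__":
    smiles = "CCO"
    sequence = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"

    # Compound and protein featurization via the helpers above.
    mol = Chem.AddHs(Chem.MolFromSmiles(smiles))
    atoms = create_atoms(mol)
    i_jbond_dict = create_ijbonddict(mol)
    compound = atom_features(atoms, i_jbond_dict, radius=2)
    adjacency = create_adjacency(mol)
    fp = get_fingerprints(mol)
    words = split_sequence(sequence, ngram=3)

    # Batch of one, so the lengths equal the full tensor sizes.
    atom_t = torch.tensor(compound, dtype=torch.long).unsqueeze(0)
    adj_t = torch.tensor(adjacency, dtype=torch.float).unsqueeze(0)
    fp_t = torch.tensor(fp, dtype=torch.float).unsqueeze(0)
    amino_t = torch.tensor(words, dtype=torch.long).unsqueeze(0)
    atom_len = torch.tensor([atom_t.size(1)])
    amino_len = torch.tensor([amino_t.size(1)])

    # Placeholder hyperparameters; vocabulary sizes are taken from the grown dicts.
    model = BACPI(
        n_atom=len(fingerprint_dict) + 1, n_amino=len(word_dict) + 1,
        comp_dim=80, prot_dim=80, gat_dim=50, num_head=3, dropout=0.1,
        alpha=0.1, window=5, layer_cnn=3, latent_dim=80,
    )
    out = model(((atom_t, atom_len), (adj_t, None), (fp_t, None)),
                (amino_t, amino_len))
    print(out.shape)  # torch.Size([1, 12800]) == (1, 2 * latent_dim * latent_dim)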