from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from rdkit import Chem
from rdkit.Chem import AllChem


class BACPI(nn.Module):
    def __init__(
        self,
        n_atom,
        n_amino,
        comp_dim,
        prot_dim,
        gat_dim,
        num_head,
        dropout,
        alpha,
        window,
        layer_cnn,
        latent_dim,
    ):
        super().__init__()
        self.embedding_layer_atom = nn.Embedding(n_atom + 1, comp_dim)
        self.embedding_layer_amino = nn.Embedding(n_amino + 1, prot_dim)
        self.dropout = dropout
        self.alpha = alpha
        self.layer_cnn = layer_cnn

        # Multi-head GAT over the molecular graph, followed by a single output GAT layer.
        self.gat_layers = [GATLayer(comp_dim, gat_dim, dropout=dropout, alpha=alpha, concat=True)
                           for _ in range(num_head)]
        for i, layer in enumerate(self.gat_layers):
            self.add_module('gat_layer_{}'.format(i), layer)
        self.gat_out = GATLayer(gat_dim * num_head, comp_dim, dropout=dropout, alpha=alpha, concat=False)
        self.W_comp = nn.Linear(comp_dim, latent_dim)

        # Stack of same-padding convolutions over the embedded protein sequence.
        self.conv_layers = nn.ModuleList([nn.Conv2d(in_channels=1, out_channels=1,
                                                    kernel_size=2 * window + 1,
                                                    stride=1, padding=window)
                                          for _ in range(layer_cnn)])
        self.W_prot = nn.Linear(prot_dim, latent_dim)

        # Projection of the 1024-bit Morgan fingerprint into the latent space.
        self.fp0 = nn.Parameter(torch.empty(size=(1024, latent_dim)))
        nn.init.xavier_uniform_(self.fp0, gain=1.414)
        self.fp1 = nn.Parameter(torch.empty(size=(latent_dim, latent_dim)))
        nn.init.xavier_uniform_(self.fp1, gain=1.414)

        # Bidirectional compound-protein attention blocks.
        self.bidat_num = 4
        self.U = nn.ParameterList([
            nn.Parameter(torch.empty(size=(latent_dim, latent_dim))) for _ in range(self.bidat_num)
        ])
        for i in range(self.bidat_num):
            nn.init.xavier_uniform_(self.U[i], gain=1.414)

        self.transform_c2p = nn.ModuleList([nn.Linear(latent_dim, latent_dim) for _ in range(self.bidat_num)])
        self.transform_p2c = nn.ModuleList([nn.Linear(latent_dim, latent_dim) for _ in range(self.bidat_num)])
        self.bihidden_c = nn.ModuleList([nn.Linear(latent_dim, latent_dim) for _ in range(self.bidat_num)])
        self.bihidden_p = nn.ModuleList([nn.Linear(latent_dim, latent_dim) for _ in range(self.bidat_num)])
        self.biatt_c = nn.ModuleList([nn.Linear(latent_dim * 2, 1) for _ in range(self.bidat_num)])
        self.biatt_p = nn.ModuleList([nn.Linear(latent_dim * 2, 1) for _ in range(self.bidat_num)])

        self.comb_c = nn.Linear(latent_dim * self.bidat_num, latent_dim)
        self.comb_p = nn.Linear(latent_dim * self.bidat_num, latent_dim)

    def comp_gat(self, atoms, atoms_mask, adj):
        atoms_vector = self.embedding_layer_atom(atoms)
        atoms_multi_head = torch.cat([gat(atoms_vector, adj) for gat in self.gat_layers], dim=2)
        atoms_vector = F.elu(self.gat_out(atoms_multi_head, adj))
        atoms_vector = F.leaky_relu(self.W_comp(atoms_vector), self.alpha)
        return atoms_vector

    def prot_cnn(self, amino, amino_mask):
        amino_vector = self.embedding_layer_amino(amino)
        amino_vector = torch.unsqueeze(amino_vector, 1)
        for i in range(self.layer_cnn):
            amino_vector = F.leaky_relu(self.conv_layers[i](amino_vector), self.alpha)
        amino_vector = torch.squeeze(amino_vector, 1)
        amino_vector = F.leaky_relu(self.W_prot(amino_vector), self.alpha)
        return amino_vector

    def mask_softmax(self, a, mask, dim=-1):
        # Softmax restricted to positions where mask == 1.
        a_max = torch.max(a, dim, keepdim=True)[0]
        a_exp = torch.exp(a - a_max)
        a_exp = a_exp * mask
        a_softmax = a_exp / (torch.sum(a_exp, dim, keepdim=True) + 1e-6)
        return a_softmax

    def bidirectional_attention_prediction(self, atoms_vector, atoms_mask, fps, amino_vector, amino_mask):
        b = atoms_vector.shape[0]
        for i in range(self.bidat_num):
            # Pairwise atom-residue affinity matrix, masked to valid positions.
            A = torch.tanh(torch.matmul(torch.matmul(atoms_vector, self.U[i]), amino_vector.transpose(1, 2)))
            A = A * torch.matmul(atoms_mask.view(b, -1, 1).float(), amino_mask.view(b, 1, -1).float())

            atoms_trans = torch.matmul(A, torch.tanh(self.transform_p2c[i](amino_vector)))
            amino_trans = torch.matmul(A.transpose(1, 2), torch.tanh(self.transform_c2p[i](atoms_vector)))

            atoms_tmp = torch.cat([torch.tanh(self.bihidden_c[i](atoms_vector)), atoms_trans], dim=2)
            amino_tmp = torch.cat([torch.tanh(self.bihidden_p[i](amino_vector)), amino_trans], dim=2)

            atoms_att = self.mask_softmax(self.biatt_c[i](atoms_tmp).view(b, -1), atoms_mask.view(b, -1).float())
            amino_att = self.mask_softmax(self.biatt_p[i](amino_tmp).view(b, -1), amino_mask.view(b, -1).float())

            # Attention-weighted pooling of atom and residue features.
            cf = torch.sum(atoms_vector * atoms_att.view(b, -1, 1), dim=1)
            pf = torch.sum(amino_vector * amino_att.view(b, -1, 1), dim=1)

            if i == 0:
                cat_cf = cf
                cat_pf = pf
            else:
                cat_cf = torch.cat([cat_cf.view(b, -1), cf.view(b, -1)], dim=1)
                cat_pf = torch.cat([cat_pf.view(b, -1), pf.view(b, -1)], dim=1)

        cf_final = torch.cat([self.comb_c(cat_cf).view(b, -1), fps.view(b, -1)], dim=1)
        pf_final = self.comb_p(cat_pf)
        cf_pf = F.leaky_relu(
            torch.matmul(
                cf_final.view(b, -1, 1),
                pf_final.view(b, 1, -1)
            ).view(b, -1),
            0.1
        )
        return cf_pf

    def forward(self, compound, protein):
        atom, adj, fp = compound
        atom, atom_lengths = atom
        adj, _ = adj
        fp, _ = fp
        amino, amino_lengths = protein

        # Masks are True for valid (non-padded) positions and False for padding.
        atom_mask = torch.arange(atom.size(1), device=atom.device) < atom_lengths.unsqueeze(1)
        amino_mask = torch.arange(amino.size(1), device=amino.device) < amino_lengths.unsqueeze(1)

        atoms_vector = self.comp_gat(atom, atom_mask, adj)
        amino_vector = self.prot_cnn(amino, amino_mask)

        super_feature = F.leaky_relu(torch.matmul(fp.float(), self.fp0), 0.1)
        super_feature = F.leaky_relu(torch.matmul(super_feature, self.fp1), 0.1)

        prediction = self.bidirectional_attention_prediction(
            atoms_vector, atom_mask, super_feature, amino_vector, amino_mask)
        return prediction
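
# Minimal usage sketch with arbitrary placeholder hyperparameters and random dummy
# tensors. It only illustrates the nested input structure that BACPI.forward expects,
# ((atom_ids, lengths), (adjacency, _), (fingerprints, _)) for the compound and
# (amino_ids, lengths) for the protein, plus the size of the joint feature returned.
def _bacpi_forward_example():
    torch.manual_seed(0)
    batch, n_atoms, seq_len, latent_dim = 2, 6, 40, 32
    model = BACPI(n_atom=100, n_amino=25, comp_dim=32, prot_dim=32, gat_dim=16,
                  num_head=4, dropout=0.1, alpha=0.1, window=3, layer_cnn=2,
                  latent_dim=latent_dim)

    atom = torch.randint(0, 100, (batch, n_atoms))       # padded atom id sequences
    atom_lengths = torch.tensor([n_atoms, 4])            # true number of atoms per compound
    adj = torch.eye(n_atoms).repeat(batch, 1, 1)         # toy adjacency (self-loops only)
    fp = torch.randint(0, 2, (batch, 1024)).float()      # 1024-bit Morgan fingerprints
    amino = torch.randint(0, 25, (batch, seq_len))       # padded amino-acid n-gram ids
    amino_lengths = torch.tensor([seq_len, 30])          # true sequence lengths

    compound = ((atom, atom_lengths), (adj, None), (fp, None))
    protein = (amino, amino_lengths)
    out = model(compound, protein)                       # shape: (batch, 2 * latent_dim ** 2)
    return out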
class GATLayer(nn.Module):
    def __init__(self, in_features, out_features, dropout=0.5, alpha=0.2, concat=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.dropout = dropout
        self.alpha = alpha
        self.concat = concat
        self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        self.a = nn.Parameter(torch.empty(size=(2 * out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

    def forward(self, h, adj):
        Wh = torch.matmul(h, self.W)
        a_input = self._prepare_attentional_mechanism_input(Wh)
        e = F.leaky_relu(torch.matmul(a_input, self.a).squeeze(3), self.alpha)
        zero_vec = -9e15 * torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)
        attention = F.softmax(attention, dim=2)
        # attention = F.dropout(attention, self.dropout, training=self.training)
        h_prime = torch.bmm(attention, Wh)
        return F.elu(h_prime) if self.concat else h_prime

    def _prepare_attentional_mechanism_input(self, Wh):
        b = Wh.size()[0]
        N = Wh.size()[1]
        Wh_repeated_in_chunks = Wh.repeat_interleave(N, dim=1)
        Wh_repeated_alternating = Wh.repeat_interleave(N, dim=0).view(b, N * N, self.out_features)
        all_combinations_matrix = torch.cat([Wh_repeated_in_chunks, Wh_repeated_alternating], dim=2)
        return all_combinations_matrix.view(b, N, N, 2 * self.out_features)
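
# Toy sketch of a single GATLayer on a 4-node path graph with self-loops; attention
# weights are only distributed over pairs whose adjacency entry is non-zero. Sizes
# and the graph itself are arbitrary placeholder values.
def _gat_layer_example():
    layer = GATLayer(in_features=8, out_features=8)
    h = torch.randn(1, 4, 8)                             # (batch, nodes, features)
    adj = torch.tensor([[1, 1, 0, 0],
                        [1, 1, 1, 0],
                        [0, 1, 1, 1],
                        [0, 0, 1, 1]]).unsqueeze(0)      # path graph 0-1-2-3 plus self-loops
    return layer(h, adj)                                 # (1, 4, 8) updated node features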
atom_dict = defaultdict(lambda: len(atom_dict))
bond_dict = defaultdict(lambda: len(bond_dict))
fingerprint_dict = defaultdict(lambda: len(fingerprint_dict))
edge_dict = defaultdict(lambda: len(edge_dict))
word_dict = defaultdict(lambda: len(word_dict))


def create_atoms(mol):
    """Map each atom symbol (tagged as aromatic where applicable) to an integer id."""
    atoms = [a.GetSymbol() for a in mol.GetAtoms()]
    for a in mol.GetAromaticAtoms():
        i = a.GetIdx()
        atoms[i] = (atoms[i], 'aromatic')
    atoms = [atom_dict[a] for a in atoms]
    return np.array(atoms)


def create_ijbonddict(mol):
    """Build a dict mapping each atom index to its (neighbor index, bond id) pairs."""
    i_jbond_dict = defaultdict(lambda: [])
    for b in mol.GetBonds():
        i, j = b.GetBeginAtomIdx(), b.GetEndAtomIdx()
        bond = bond_dict[str(b.GetBondType())]
        i_jbond_dict[i].append((j, bond))
        i_jbond_dict[j].append((i, bond))

    # Give isolated atoms a dummy self-edge so every atom has at least one neighbor entry.
    atoms_set = set(range(mol.GetNumAtoms()))
    isolate_atoms = atoms_set - set(i_jbond_dict.keys())
    bond = bond_dict['nan']
    for a in isolate_atoms:
        i_jbond_dict[a].append((a, bond))

    return i_jbond_dict


def atom_features(atoms, i_jbond_dict, radius):
    """Assign each atom an r-radius subgraph fingerprint id (Weisfeiler-Lehman-style relabeling)."""
    if (len(atoms) == 1) or (radius == 0):
        fingerprints = [fingerprint_dict[a] for a in atoms]
    else:
        nodes = atoms
        i_jedge_dict = i_jbond_dict
        for _ in range(radius):
            # Update each node id from its own id and the sorted ids of its neighbors.
            fingerprints = []
            for i, j_edge in i_jedge_dict.items():
                neighbors = [(nodes[j], edge) for j, edge in j_edge]
                fingerprint = (nodes[i], tuple(sorted(neighbors)))
                fingerprints.append(fingerprint_dict[fingerprint])
            nodes = fingerprints

            # Update each edge id from the ids of its two endpoints.
            _i_jedge_dict = defaultdict(lambda: [])
            for i, j_edge in i_jedge_dict.items():
                for j, edge in j_edge:
                    both_side = tuple(sorted((nodes[i], nodes[j])))
                    edge = edge_dict[(both_side, edge)]
                    _i_jedge_dict[i].append((j, edge))
            i_jedge_dict = _i_jedge_dict

    return np.array(fingerprints)


def create_adjacency(mol):
    adjacency = Chem.GetAdjacencyMatrix(mol)
    adjacency = np.array(adjacency)
    adjacency += np.eye(adjacency.shape[0], dtype=int)  # add self-loops
    return adjacency


def get_fingerprints(mol):
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=1024, useChirality=True)
    return np.array(fp)


def split_sequence(sequence, ngram=3):
    sequence = '-' + sequence + '='  # start/end markers
    words = [word_dict[sequence[i:i + ngram]] for i in range(len(sequence) - ngram + 1)]
    return np.array(words)


def drug_featurizer(smiles, radius=2):
    from deepscreen.utils import get_logger
    log = get_logger(__name__)
    try:
        mol = Chem.MolFromSmiles(smiles)
        if not mol:
            return None
        mol = Chem.AddHs(mol)
        atoms = create_atoms(mol)
        i_jbond_dict = create_ijbonddict(mol)
        compound = atom_features(atoms, i_jbond_dict, radius)
        adjacency = create_adjacency(mol)
        fp = get_fingerprints(mol)
        return compound, adjacency, fp
    except Exception as e:
        log.warning(f"Failed to featurize SMILES ({smiles}) to graph due to {str(e)}")
        return None
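
# Usage sketch for the featurizers, mirroring drug_featurizer but calling the helper
# functions directly so it does not depend on the deepscreen logger. The SMILES
# (aspirin) and the protein fragment are arbitrary placeholder inputs.
if __name__ == '__main__':
    mol = Chem.AddHs(Chem.MolFromSmiles('CC(=O)OC1=CC=CC=C1C(=O)O'))
    atoms = create_atoms(mol)
    i_jbond_dict = create_ijbonddict(mol)
    compound = atom_features(atoms, i_jbond_dict, radius=2)
    adjacency = create_adjacency(mol)
    fp = get_fingerprints(mol)
    words = split_sequence('MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ', ngram=3)
    print('atom fingerprints:', compound.shape)          # one id per atom (Hs included)
    print('adjacency:', adjacency.shape)                 # (n_atoms, n_atoms), with self-loops
    print('Morgan bits:', fp.shape)                      # (1024,)
    print('protein n-grams:', words.shape)               # one id per 3-gram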