import torch
from torch import nn
import esm  # fair-esm package; required by FastaESM below

# NOTE: load_residue_addon, load_complex_decoder, initialization and the
# store_*_representations hooks are referenced here but defined elsewhere in
# the ArkDTA codebase.


class ArkDTA(nn.Module):
    def __init__(self, args):
        super(ArkDTA, self).__init__()
        self.layer = nn.ModuleDict()

        analysis_mode = args.analysis_mode
        h = args.arkdta_hidden_dim
        d = args.hp_dropout_rate
        esm_model = args.arkdta_esm_model
        esm_freeze = args.arkdta_esm_freeze
        E = args.arkdta_ecfpvec_dim
        L = args.arkdta_sab_depth
        A = args.arkdta_attention_option
        K = args.arkdta_num_heads
        assert 'ARKMAB' in args.arkdta_residue_addon

        self.layer['prot_encoder'] = FastaESM(h, esm_model, esm_freeze, analysis_mode)
        self.layer['comp_encoder'] = EcfpConverter(h, L, E, analysis_mode)
        self.layer['intg_arkmab'] = load_residue_addon(args)
        self.layer['intg_pooling'] = load_complex_decoder(args)
        self.layer['ba_predictor'] = AffinityMLP(h)
        self.layer['dt_predictor'] = InteractionMLP(h)

    def load_auxiliary_materials(self, **kwargs):
        return_batch = kwargs['return_batch']
        b = kwargs['atomresi_adj'].size(0)
        x, y, z = kwargs['encoder_attention'].size()
        # Average the attention weights over heads, then split them into
        # compound-substructure logits and inactive-site logits.
        logits0 = kwargs['encoder_attention'].view(x // b, b, y, z).mean(0)[:, :, :-1].sum(2).unsqueeze(2)  # actual compound substructures
        logits1 = kwargs['encoder_attention'].view(x // b, b, y, z).mean(0)[:, :, -1].unsqueeze(2)          # inactive site
        return_batch['task/es_pred'] = torch.cat([logits1, logits0], 2)
        return_batch['task/es_true'] = (kwargs['atomresi_adj'].sum(1) > 0.).long().squeeze(1)
        return_batch['mask/es_resi'] = (kwargs['atomresi_masks'].sum(1) > 0.).float().squeeze(1)

        return return_batch

    def forward(self, batch):
        return_batch = dict()
        residue_features, residue_masks, residue_fastas = batch[0], batch[1], batch[2]
        ecfp_words, ecfp_masks = batch[3], batch[4]
        atomresi_adj, atomresi_masks = batch[5], batch[6]
        bav, dti, cids = batch[7], batch[8], batch[-1]

        # Protein Encoder Module
        residue_features = self.layer['prot_encoder'](X=residue_features,
                                                      fastas=residue_fastas,
                                                      masks=residue_masks)
        residue_masks = residue_features[1]
        residue_temps = residue_features[2]
        protein_features = residue_features[3]
        residue_features = residue_features[0]
        return_batch['temp/lm_related'] = residue_temps * 0.
        # Ligand Encoder Module
        cstruct_features = self.layer['comp_encoder'](ecfp_words=ecfp_words, ecfp_masks=ecfp_masks)
        cstruct_masks = cstruct_features[1]
        cstruct_features = cstruct_features[0]

        # Protein-Ligand Integration Module (ARK-MAB)
        residue_results = self.layer['intg_arkmab'](residue_features=residue_features,
                                                    residue_masks=residue_masks,
                                                    ligelem_features=cstruct_features,
                                                    ligelem_masks=cstruct_masks)
        residue_features, residue_masks, attention_weights = residue_results
        del residue_results
        torch.cuda.empty_cache()

        # Protein-Ligand Integration Module (Pooling Layer)
        complex_results = self.layer['intg_pooling'](residue_features=residue_features,
                                                     residue_masks=residue_masks,
                                                     attention_weights=attention_weights,
                                                     protein_features=protein_features)
        binding_complex, _, _, _ = complex_results
        del complex_results
        torch.cuda.empty_cache()

        # Drug-Target Outcome Predictor
        bav_predicted = self.layer['ba_predictor'](binding_complex=binding_complex)
        dti_predicted = self.layer['dt_predictor'](binding_complex=binding_complex)

        return_batch['task/ba_pred'] = bav_predicted.view(-1)
        return_batch['task/dt_pred'] = dti_predicted.view(-1)
        return_batch['task/ba_true'] = bav.view(-1)
        return_batch['task/dt_true'] = dti.view(-1)
        return_batch['meta/cid'] = cids

        # Additional Materials for Calculating Auxiliary Loss
        return_batch = self.load_auxiliary_materials(return_batch=return_batch,
                                                     atomresi_adj=atomresi_adj,
                                                     atomresi_masks=atomresi_masks,
                                                     encoder_attention=attention_weights)

        return return_batch

    @torch.no_grad()
    def infer(self, batch):
        return_batch = dict()
        residue_features, residue_masks, residue_fastas = batch[0], batch[1], batch[2]
        ecfp_words, ecfp_masks = batch[3], batch[4]
        bav, dti, cids = batch[7], batch[8], batch[-1]

        # Protein Encoder Module
        residue_features = self.layer['prot_encoder'](X=residue_features,
                                                      fastas=residue_fastas,
                                                      masks=residue_masks)
        residue_masks = residue_features[1]
        residue_temps = residue_features[2]
        protein_features = residue_features[3]
        residue_features = residue_features[0]
        return_batch['temp/lm_related'] = residue_temps * 0.
        # Ligand Encoder Module
        cstruct_features = self.layer['comp_encoder'](ecfp_words=ecfp_words, ecfp_masks=ecfp_masks)
        cstruct_masks = cstruct_features[1]
        cstruct_features = cstruct_features[0]

        # Protein-Ligand Integration Module (ARK-MAB)
        residue_results = self.layer['intg_arkmab'](residue_features=residue_features,
                                                    residue_masks=residue_masks,
                                                    ligelem_features=cstruct_features,
                                                    ligelem_masks=cstruct_masks)
        residue_features, residue_masks, attention_weights = residue_results
        del residue_results
        torch.cuda.empty_cache()

        # Protein-Ligand Integration Module (Pooling Layer)
        complex_results = self.layer['intg_pooling'](residue_features=residue_features,
                                                     residue_masks=residue_masks,
                                                     attention_weights=attention_weights,
                                                     protein_features=protein_features)
        binding_complex, _, _, _ = complex_results
        del complex_results
        torch.cuda.empty_cache()

        # Drug-Target Outcome Predictor
        bav_predicted = self.layer['ba_predictor'](binding_complex=binding_complex)
        dti_predicted = self.layer['dt_predictor'](binding_complex=binding_complex)

        return_batch['task/ba_pred'] = bav_predicted.view(-1)
        return_batch['task/dt_pred'] = dti_predicted.view(-1)
        return_batch['task/ba_true'] = bav.view(-1)
        return_batch['task/dt_true'] = dti.view(-1)
        return_batch['meta/cid'] = cids

        return return_batch


class GraphDenseSequential(nn.Sequential):
    def __init__(self, *args):
        super(GraphDenseSequential, self).__init__(*args)

    def forward(self, X, adj, mask):
        for module in self._modules.values():
            try:
                X = module(X, adj, mask)
            except BaseException:
                X = module(X)

        return X


class MaskedGlobalPooling(nn.Module):
    def __init__(self, pooling='max'):
        super(MaskedGlobalPooling, self).__init__()
        self.pooling = pooling

    def forward(self, x, adj, masks):
        if x.dim() == 2:
            x = x.unsqueeze(0)
        masks = masks.unsqueeze(2).repeat(1, 1, x.size(2))
        if self.pooling == 'max':
            x[masks == 0] = -99999.99999
            x = x.max(1)[0]
        elif self.pooling == 'add':
            x = x.sum(1)
        else:
            print('Not Implemented')

        return x


class MaskedMean(nn.Module):
    def __init__(self):
        super(MaskedMean, self).__init__()

    def forward(self, X, m):
        if isinstance(m, torch.Tensor):
            X = X * m.unsqueeze(2)

        return X.mean(1)


class MaskedMax(nn.Module):
    def __init__(self):
        super(MaskedMax, self).__init__()

    def forward(self, X, m):
        if isinstance(m, torch.Tensor):
            X = X * m.unsqueeze(2)

        return torch.max(X, 1)[0]


class MaskedSum(nn.Module):
    def __init__(self):
        super(MaskedSum, self).__init__()

    def forward(self, X, m):
        if isinstance(m, torch.Tensor):
            X = X * m.unsqueeze(2)

        return X.sum(1)


class MaskedScaledAverage(nn.Module):
    def __init__(self):
        super(MaskedScaledAverage, self).__init__()

    def forward(self, X, m):
        if isinstance(m, torch.Tensor):
            X = X * m.unsqueeze(2)

        return X.sum(1) / (m.sum(1) ** 0.5).unsqueeze(1)
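

# Illustrative sketch (not part of the original ArkDTA code): how the masked
# pooling modules above behave on a padded batch. The helper name and toy
# shapes below are hypothetical.
def _demo_masked_pooling():
    X = torch.randn(2, 5, 8)                       # 2 sets, 5 elements, 8 features
    m = torch.tensor([[1., 1., 1., 0., 0.],
                      [1., 1., 1., 1., 1.]])       # 1 = real element, 0 = padding
    pooled_sum = MaskedSum()(X, m)                 # -> (2, 8), padded rows zeroed before summing
    pooled_scaled = MaskedScaledAverage()(X, m)    # -> (2, 8), sum / sqrt(#valid elements)
    return pooled_sum.shape, pooled_scaled.shape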


class Decoder(nn.Module):
    def __init__(self, analysis_mode):
        super(Decoder, self).__init__()
        self.output_representations = []
        self.query_representations = []
        self.kvpair_representations = []
        self.attention_weights = []
        self.representations = []  # show()/flush() below read this list
        if analysis_mode:
            self.register_forward_hook(store_decoder_representations)

    def show(self):
        print("Number of Saved Numpy Arrays: ", len(self.representations))
        for i, representation in enumerate(self.representations):
            print(f"Shape of {i}th Numpy Array: ", representation.shape)

        return self.representations

    def flush(self):
        del self.representations
        self.representations = []

    def release_qk(self):
        return None

    def forward(self, **kwargs):
        return kwargs['X'], kwargs['X'], kwargs['residue_features'], None


class DecoderPMA_Residue(Decoder):
    def __init__(self, h: int, num_heads: int, num_seeds: int, attn_option: str, analysis_mode: bool):
        super(DecoderPMA_Residue, self).__init__(analysis_mode)
        # Aggregate the Residues into Residue Regions
        pma_args = (h, num_seeds, num_heads, RFF(h), attn_option, False, analysis_mode, False)
        self.decoder = PoolingMultiheadAttention(*pma_args)
        # Model Region-Region Interaction through Set Attention
        sab_depth = 0 if num_seeds < 4 else int((num_seeds // 2) ** 0.5)
        sab_args = (h, num_heads, RFF(h), attn_option, False, analysis_mode, True)
        self.pairwise = nn.ModuleList([SetAttentionBlock(*sab_args) for _ in range(sab_depth)])
        # Concat, then reduce into h-dimensional Set Representation
        self.aggregate = nn.Linear(h * num_seeds, h)
        self.apply(initialization)

    def forward(self, **kwargs):
        residue_features = kwargs['residue_features']
        residue_masks = kwargs['residue_masks']

        output, attention = self.decoder(residue_features, residue_masks)
        for sab in self.pairwise:
            output, _ = sab(output)
        b, n, d = output.shape
        output = self.aggregate(output.view(b, n * d))

        return output, None, residue_features, attention


class AffinityMLP(nn.Module):
    def __init__(self, h: int):
        super(AffinityMLP, self).__init__()
        self.mlp = nn.Sequential(nn.Linear(h, h), nn.Dropout(0.1), nn.LeakyReLU(), nn.Linear(h, 1))
        self.apply(initialization)

    def forward(self, **kwargs):
        '''
            X: batch size x 1 x H
        '''
        X = kwargs['binding_complex']
        X = X.squeeze(1) if X.dim() == 3 else X
        yhat = self.mlp(X)

        return yhat


class InteractionMLP(nn.Module):
    def __init__(self, h: int):
        super(InteractionMLP, self).__init__()
        self.mlp = nn.Sequential(nn.Linear(h, h), nn.Dropout(0.1), nn.LeakyReLU(), nn.Linear(h, 1), nn.Sigmoid())
        self.apply(initialization)

    def forward(self, **kwargs):
        '''
            X: batch size x 1 x H
        '''
        X = kwargs['binding_complex']
        X = X.squeeze(1) if X.dim() == 3 else X
        yhat = self.mlp(X)

        return yhat


class LigelemEncoder(nn.Module):
    def __init__(self):
        super(LigelemEncoder, self).__init__()
        self.representations = []

    def show(self):
        print("Number of Saved Numpy Arrays: ", len(self.representations))
        for i, representation in enumerate(self.representations):
            print(f"Shape of {i}th Numpy Array: ", representation.shape)

        return self.representations

    def flush(self):
        del self.representations
        self.representations = []


class EcfpConverter(LigelemEncoder):
    def __init__(self, h: int, sab_depth: int, ecfp_dim: int, analysis_mode: bool):
        super(EcfpConverter, self).__init__()
        K = 4  # number of attention heads
        self.ecfp_embeddings = nn.Embedding(ecfp_dim + 1, h, padding_idx=ecfp_dim)
        sab_args = (h, K, RFF(h), 'general_dot', False, analysis_mode, True)
        self.encoder = nn.ModuleList([SetAttentionBlock(*sab_args) for _ in range(sab_depth)])

        self.representations = []
        if analysis_mode:
            self.register_forward_hook(store_elemwise_representations)
        self.apply(initialization)

    def forward(self, **kwargs):
        '''
            X: (b x d)
        '''
        ecfp_words = kwargs['ecfp_words']
        ecfp_masks = kwargs['ecfp_masks']
        ecfp_words = self.ecfp_embeddings(ecfp_words)

        for sab in self.encoder:
            ecfp_words, _ = sab(ecfp_words, ecfp_masks)

        return [ecfp_words, ecfp_masks]


class ResidueAddOn(nn.Module):
    def __init__(self):
        super(ResidueAddOn, self).__init__()
        self.representations = []

    def show(self):
        print("Number of Saved Numpy Arrays: ", len(self.representations))
        for i, representation in enumerate(self.representations):
            print(f"Shape of {i}th Numpy Array: ", representation.shape)

        return self.representations

    def flush(self):
        del self.representations
        self.representations = []

    def forward(self, **kwargs):
        X, Xm = kwargs['X'], kwargs['Xm']

        return X, Xm


class ARKMAB(ResidueAddOn):
    def __init__(self, h: int, num_heads: int, attn_option: str, analysis_mode: bool):
        super(ARKMAB, self).__init__()
        pmx_args = (h, num_heads, RFF(h), attn_option, False, analysis_mode, False)
        self.pmx = PoolingMultiheadCrossAttention(*pmx_args)
        self.inactive = nn.Parameter(torch.randn(1, 1, h))
        self.fillmask = nn.Parameter(torch.ones(1, 1), requires_grad=False)

        self.representations = []
        if analysis_mode:
            pass
        self.apply(initialization)

    def forward(self, **kwargs):
        '''
            X:  batch size x residues x H
            Xm: batch size x residues x H
            Y:  batch size x ecfpsubs x H
            Ym: batch size x ecfpsubs x H
        '''
        X, Xm = kwargs['residue_features'], kwargs['residue_masks']
        Y, Ym = kwargs['ligelem_features'], kwargs['ligelem_masks']

        # Append the learned "inactive" pseudo-substructure so that residues not
        # matching any ligand substructure still have a slot to attend to.
        pseudo_substructure = self.inactive.repeat(X.size(0), 1, 1)
        pseudo_masks = self.fillmask.repeat(X.size(0), 1)
        Y = torch.cat([Y, pseudo_substructure], 1)
        Ym = torch.cat([Ym, pseudo_masks], 1)

        X, attention = self.pmx(Y=Y, Ym=Ym, X=X, Xm=Xm)

        return X, Xm, attention


class ResidueEncoder(nn.Module):
    def __init__(self):
        super(ResidueEncoder, self).__init__()
        self.representations = []

    def show(self):
        print("Number of Saved Numpy Arrays: ", len(self.representations))
        for i, representation in enumerate(self.representations):
            print(f"Shape of {i}th Numpy Array: ", representation.shape)

        return self.representations

    def flush(self):
        del self.representations
        self.representations = []


class AminoAcidSeqCNN(ResidueEncoder):
    def __init__(self, h: int, d: float, cnn_depth: int, kernel_size: int, analysis_mode: bool):
        super(AminoAcidSeqCNN, self).__init__()
        self.encoder = nn.ModuleList([nn.Sequential(nn.Linear(21, h),  # Warning
                                                    nn.Dropout(d),
                                                    nn.LeakyReLU(),
                                                    nn.Linear(h, h))])
        for _ in range(cnn_depth):
            self.encoder.append(nn.Conv1d(h, h, kernel_size, 1, (kernel_size - 1) // 2))

        self.representations = []
        if analysis_mode:
            self.register_forward_hook(store_representations)
        self.apply(initialization)

    def forward(self, **kwargs):
        X = kwargs['aaseqs']
        for i, module in enumerate(self.encoder):
            if i == 1:
                X = X.transpose(1, 2)
            X = module(X)
        X = X.transpose(1, 2)

        return X


class FastaESM(ResidueEncoder):
    def __init__(self, h: int, esm_model: str, esm_freeze: bool, analysis_mode: bool):
        super(FastaESM, self).__init__()
        self.esm_version = 2 if 'esm2' in esm_model else 1
        if esm_model == 'esm1b_t33_650M_UR505':
            self.esm, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
            self.layer_idx, self.emb_dim = 33, 1024
        elif esm_model == 'esm1_t12_85M_UR505':
            self.esm, alphabet = esm.pretrained.esm1_t12_85M_UR50S()
            self.layer_idx, self.emb_dim = 12, 768
        elif esm_model == 'esm2_t6_8M_UR50D':
            self.esm, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
            self.layer_idx, self.emb_dim = 6, 320
        elif esm_model == 'esm2_t12_35M_UR50D':
            self.esm, alphabet = esm.pretrained.esm2_t12_35M_UR50D()
            self.layer_idx, self.emb_dim = 12, 480
        elif esm_model == 'esm2_t30_150M_UR50D':
            self.esm, alphabet = esm.pretrained.esm2_t30_150M_UR50D()
            self.layer_idx, self.emb_dim = 30, 640
        else:
            raise NotImplementedError(f"Unsupported ESM model: {esm_model}")
        self.batch_converter = alphabet.get_batch_converter()
        if esm_freeze == 'True':
            for p in self.esm.parameters():
                p.requires_grad = False
        assert h == self.emb_dim, f"The hidden dimension should be set to {self.emb_dim}, not {h}"

        self.representations = []
        if analysis_mode:
            self.register_forward_hook(store_elemwise_representations)

    def esm1_pooling(self, embeddings, masks):
        # Mean over residue positions, skipping the leading BOS token.
        return embeddings[:, 1:, :].sum(1) / masks[:, 1:].sum(1).view(-1, 1)

    def esm2_pooling(self, embeddings, masks):
        # Mean over residue positions, skipping the leading BOS and trailing EOS tokens.
        return embeddings[:, 1:-1, :].sum(1) / masks[:, 1:-1].sum(1).view(-1, 1)

    def forward(self, **kwargs):
        fastas = kwargs['fastas']
        _, _, tokenized = self.batch_converter(fastas)
        tokenized = tokenized.cuda()
        if self.esm_version == 2:
            masks = torch.where(tokenized > 1, 1, 0).float()
        else:
            masks = torch.where((tokenized > 1) & (tokenized != 32), 1, 0).float()

        embeddings = self.esm(tokenized, repr_layers=[self.layer_idx], return_contacts=True)
        # Reduced to scalars; returned as the third output and zeroed out by the
        # caller (see 'temp/lm_related' in ArkDTA.forward).
        logits = embeddings["logits"].sum()
        contacts = embeddings["contacts"].sum()
        attentions = embeddings["attentions"].sum()
        embeddings = embeddings["representations"][self.layer_idx]

        assert masks.size(0) == embeddings.size(0), f"Batch sizes of masks {masks.size(0)} and {embeddings.size(0)} do not match."
        assert masks.size(1) == embeddings.size(1), f"Lengths of masks {masks.size(1)} and {embeddings.size(1)} do not match."

        if self.esm_version == 2:
            return [embeddings[:, 1:-1, :], masks[:, 1:-1], logits + contacts + attentions,
                    self.esm2_pooling(embeddings, masks)]
        else:
            return [embeddings[:, 1:, :], masks[:, 1:], logits + contacts + attentions,
                    self.esm1_pooling(embeddings, masks)]


class DotProduct(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, queries, keys):
        return torch.bmm(queries, keys.transpose(1, 2))


class ScaledDotProduct(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, queries, keys):
        return torch.bmm(queries, keys.transpose(1, 2)) / (queries.size(2) ** 0.5)


class GeneralDotProduct(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.W = nn.Parameter(torch.randn(hidden_dim, hidden_dim))
        torch.nn.init.orthogonal_(self.W)

    def forward(self, queries, keys):
        return torch.bmm(queries @ self.W, keys.transpose(1, 2))


class ConcatDotProduct(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        raise NotImplementedError

    def forward(self, queries, keys):
        return


class Additive(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.U = nn.Parameter(torch.randn(hidden_dim, hidden_dim))
        self.T = nn.Parameter(torch.randn(hidden_dim, hidden_dim))
        self.b = nn.Parameter(torch.rand(hidden_dim).uniform_(-0.1, 0.1))
        self.W = nn.Sequential(nn.Tanh(), nn.Linear(hidden_dim, 1))
        torch.nn.init.orthogonal_(self.U)
        torch.nn.init.orthogonal_(self.T)

    def forward(self, queries, keys):
        return self.W(queries.unsqueeze(1) @ self.U + keys.unsqueeze(2) @ self.T + self.b).squeeze(-1).transpose(1, 2)
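

# Illustrative sketch (not part of the original ArkDTA code): the similarity
# modules above all map (b, n, d) queries and (b, m, d) keys to a (b, n, m)
# score matrix; GeneralDotProduct adds a learned bilinear weight and Additive
# uses a small feed-forward scorer. The helper name and toy shapes are hypothetical.
def _demo_similarity_scores():
    queries, keys = torch.randn(2, 4, 16), torch.randn(2, 6, 16)
    for sim in (DotProduct(), ScaledDotProduct(), GeneralDotProduct(16), Additive(16)):
        assert sim(queries, keys).shape == (2, 4, 6)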


class Attention(nn.Module):
    def __init__(self, similarity, hidden_dim=1024, store_qk=False):
        super().__init__()
        self.softmax = nn.Softmax(dim=2)
        self.attention_maps = []
        self.store_qk = store_qk
        self.query_vectors, self.key_vectors = None, None

        assert similarity in ['dot', 'scaled_dot', 'general_dot', 'concat_dot', 'additive']
        if similarity == 'dot':
            self.similarity = DotProduct()
        elif similarity == 'scaled_dot':
            self.similarity = ScaledDotProduct()
        elif similarity == 'general_dot':
            self.similarity = GeneralDotProduct(hidden_dim)
        elif similarity == 'concat_dot':
            self.similarity = ConcatDotProduct(hidden_dim)
        elif similarity == 'additive':
            self.similarity = Additive(hidden_dim)
        else:
            raise NotImplementedError(f"Unsupported similarity: {similarity}")

    def release_qk(self):
        Q, K = self.query_vectors, self.key_vectors
        self.query_vectors, self.key_vectors = None, None
        torch.cuda.empty_cache()

        return Q, K

    def forward(self, queries, keys, qmasks=None, kmasks=None):
        if self.store_qk:
            self.query_vectors = queries
            self.key_vectors = keys

        # If only one side is masked, treat the other side as fully valid.
        if torch.is_tensor(qmasks) and not torch.is_tensor(kmasks):
            dim0, dim1 = qmasks.size(0), keys.size(1)
            kmasks = torch.ones(dim0, dim1).cuda()
        elif not torch.is_tensor(qmasks) and torch.is_tensor(kmasks):
            dim0, dim1 = kmasks.size(0), queries.size(1)
            qmasks = torch.ones(dim0, dim1).cuda()
        else:
            pass

        attention = self.similarity(queries, keys)
        if torch.is_tensor(qmasks) and torch.is_tensor(kmasks):
            qmasks = qmasks.repeat(queries.size(0) // qmasks.size(0), 1).unsqueeze(2)
            kmasks = kmasks.repeat(keys.size(0) // kmasks.size(0), 1).unsqueeze(2)
            attnmasks = torch.bmm(qmasks, kmasks.transpose(1, 2))
            # Masked softmax: clip for numerical stability, zero out padded pairs, renormalize.
            attention = torch.clip(attention, min=-10, max=10)
            attention = attention.exp()
            attention = attention * attnmasks
            attention = attention / (attention.sum(2).unsqueeze(2) + 1e-5)
        else:
            attention = self.softmax(attention)

        return attention


@torch.no_grad()
def save_attention_maps(self, input, output):
    # Forward hook registered on Attention modules in analysis mode;
    # `self` is the hooked Attention instance.
    self.attention_maps.append(output.data.detach().cpu().numpy())


class MultiheadAttention(nn.Module):
    def __init__(self, d, h, sim='dot', analysis=False, store_qk=False):
        super().__init__()
        assert d % h == 0, f"{d} dimension, {h} heads"
        self.h = h
        p = d // h

        self.project_queries = nn.Linear(d, d)
        self.project_keys = nn.Linear(d, d)
        self.project_values = nn.Linear(d, d)
        self.concatenation = nn.Linear(d, d)
        self.attention = Attention(sim, p, store_qk)

        if analysis:
            self.attention.register_forward_hook(save_attention_maps)

    def release_qk(self):
        Q, K = self.attention.release_qk()
        Qb = Q.size(0) // self.h
        Qn, Qd = Q.size(1), Q.size(2)
        Kb = K.size(0) // self.h
        Kn, Kd = K.size(1), K.size(2)

        Q = Q.view(self.h, Qb, Qn, Qd)
        K = K.view(self.h, Kb, Kn, Kd)
        Q = Q.permute(1, 2, 0, 3).contiguous().view(Qb, Qn, Qd * self.h)
        K = K.permute(1, 2, 0, 3).contiguous().view(Kb, Kn, Kd * self.h)

        return Q, K

    def forward(self, queries, keys, values, qmasks=None, kmasks=None):
        h = self.h
        b, n, d = queries.size()
        _, m, _ = keys.size()
        p = d // h

        queries = self.project_queries(queries)  # shape [b, n, d]
        keys = self.project_keys(keys)           # shape [b, m, d]
        values = self.project_values(values)     # shape [b, m, d]

        queries = queries.view(b, n, h, p)
        keys = keys.view(b, m, h, p)
        values = values.view(b, m, h, p)

        queries = queries.permute(2, 0, 1, 3).contiguous().view(h * b, n, p)
        keys = keys.permute(2, 0, 1, 3).contiguous().view(h * b, m, p)
        values = values.permute(2, 0, 1, 3).contiguous().view(h * b, m, p)

        attn_w = self.attention(queries, keys, qmasks, kmasks)  # shape [h * b, n, m]
        output = torch.bmm(attn_w, values)                      # shape [h * b, n, p]
        output = output.view(h, b, n, p)
        output = output.permute(1, 2, 0, 3).contiguous().view(b, n, d)
        output = self.concatenation(output)                     # shape [b, n, d]

        return output, attn_w


class MultiheadAttentionExpanded(nn.Module):
    def __init__(self, d, h, sim='dot', analysis=False):
        super().__init__()
        self.project_queries = nn.ModuleList([nn.Linear(d, d) for _ in range(h)])
        self.project_keys = nn.ModuleList([nn.Linear(d, d) for _ in range(h)])
        self.project_values = nn.ModuleList([nn.Linear(d, d) for _ in range(h)])
        self.concatenation = nn.Linear(h * d, d)
        self.attention = Attention(sim, d)

        if analysis:
            self.attention.register_forward_hook(save_attention_maps)

    def forward(self, queries, keys, values, qmasks=None, kmasks=None):
        output = []
        for Wq, Wk, Wv in zip(self.project_queries, self.project_keys, self.project_values):
            Pq, Pk, Pv = Wq(queries), Wk(keys), Wv(values)
            output.append(torch.bmm(self.attention(Pq, Pk, qmasks, kmasks), Pv))
        # Concatenate per-head outputs along the feature dimension to match nn.Linear(h * d, d).
        output = self.concatenation(torch.cat(output, 2))

        return output


class EmptyModule(nn.Module):
    def __init__(self, args):
        super().__init__()

    def forward(self, x):
        return 0.
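

# Illustrative sketch (not part of the original ArkDTA code): masked multi-head
# attention over a padded batch. With both masks supplied, the masked-softmax
# path is used and the call runs on CPU. The helper name and toy shapes are
# hypothetical.
def _demo_multihead_attention():
    mha = MultiheadAttention(d=32, h=4, sim='scaled_dot')
    queries, keys = torch.randn(2, 5, 32), torch.randn(2, 7, 32)
    qmasks, kmasks = torch.ones(2, 5), torch.ones(2, 7)
    kmasks[:, -2:] = 0.                            # last two key positions are padding
    output, attn_w = mha(queries, keys, keys, qmasks, kmasks)
    assert output.shape == (2, 5, 32) and attn_w.shape == (8, 5, 7)  # 8 = heads * batch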


class RFF(nn.Module):
    def __init__(self, h):
        super().__init__()
        self.rff = nn.Sequential(nn.Linear(h, h), nn.ReLU(),
                                 nn.Linear(h, h), nn.ReLU(),
                                 nn.Linear(h, h), nn.ReLU())

    def forward(self, x):
        return self.rff(x)


class MultiheadAttentionBlock(nn.Module):
    def __init__(self, d, h, rff, similarity='dot', full_head=False, analysis=False, store_qk=False):
        super().__init__()
        self.multihead = MultiheadAttention(d, h, similarity, analysis, store_qk) if not full_head \
            else MultiheadAttentionExpanded(d, h, similarity, analysis)
        self.layer_norm1 = nn.LayerNorm(d)
        self.layer_norm2 = nn.LayerNorm(d)
        self.rff = rff

    def release_qk(self):
        Q, K = self.multihead.release_qk()

        return Q, K

    def forward(self, x, y, xm=None, ym=None, layer_norm=True):
        h, a = self.multihead(x, y, y, xm, ym)
        if layer_norm:
            h = self.layer_norm1(x + h)
            return self.layer_norm2(h + self.rff(h)), a
        else:
            h = x + h
            return h + self.rff(h), a


class SetAttentionBlock(nn.Module):
    def __init__(self, d, h, rff, similarity='dot', full_head=False, analysis=False, store_qk=False):
        super().__init__()
        self.mab = MultiheadAttentionBlock(d, h, rff, similarity, full_head, analysis, store_qk)

    def release_qk(self):
        Q, K = self.mab.release_qk()

        return Q, K

    def forward(self, x, m=None, ln=True):
        return self.mab(x, x, m, m, ln)


class InducedSetAttentionBlock(nn.Module):
    def __init__(self, d, m, h, rff1, rff2, similarity='dot', full_head=False, analysis=False, store_qk=False):
        super().__init__()
        self.mab1 = MultiheadAttentionBlock(d, h, rff1, similarity, full_head, analysis, store_qk)
        self.mab2 = MultiheadAttentionBlock(d, h, rff2, similarity, full_head, analysis, store_qk)
        self.inducing_points = nn.Parameter(torch.randn(1, m, d))

    def release_qk(self):
        raise NotImplementedError

    def forward(self, x, m=None, ln=True):
        b = x.size(0)
        p = self.inducing_points
        p = p.repeat([b, 1, 1])              # shape [b, m, d]
        h, _ = self.mab1(p, x, None, m, ln)  # shape [b, m, d]

        return self.mab2(x, h, m, None, ln)


class PoolingMultiheadAttention(nn.Module):
    def __init__(self, d, k, h, rff, similarity='dot', full_head=False, analysis=False, store_qk=False):
        super().__init__()
        self.mab = MultiheadAttentionBlock(d, h, rff, similarity, full_head, analysis, store_qk)
        self.seed_vectors = nn.Parameter(torch.randn(1, k, d))
        torch.nn.init.xavier_uniform_(self.seed_vectors)

    def release_qk(self):
        Q, K = self.mab.release_qk()

        return Q, K

    def forward(self, z, m=None, ln=True):
        b = z.size(0)
        s = self.seed_vectors
        s = s.repeat([b, 1, 1])  # learned seed vectors: shape [b, k, d]

        return self.mab(s, z, None, m, ln)


class PoolingMultiheadCrossAttention(nn.Module):
    def __init__(self, d, h, rff, similarity='dot', full_head=False, analysis=False, store_qk=False):
        super().__init__()
        self.mab = MultiheadAttentionBlock(d, h, rff, similarity, full_head, analysis, store_qk)

    def release_qk(self):
        Q, K = self.mab.release_qk()

        return Q, K

    def forward(self, X, Y, Xm=None, Ym=None, ln=True):
        return self.mab(X, Y, Xm, Ym, ln)
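

# Illustrative sketch (not part of the original ArkDTA code): a SetAttentionBlock
# encodes a padded set of embeddings and PoolingMultiheadAttention pools it into
# k seed-query summaries, mirroring how EcfpConverter and DecoderPMA_Residue
# compose these blocks. The helper name and toy shapes are hypothetical.
def _demo_set_attention_pipeline():
    d, heads, k = 32, 4, 2
    sab = SetAttentionBlock(d, heads, RFF(d), similarity='scaled_dot')
    pma = PoolingMultiheadAttention(d, k, heads, RFF(d), similarity='scaled_dot')

    x = torch.randn(3, 10, d)                      # 3 sets, 10 elements each
    m = torch.ones(3, 10)
    m[:, 8:] = 0.                                  # last two elements of every set are padding

    encoded, _ = sab(x, m)                         # -> (3, 10, 32)
    # DecoderPMA_Residue also passes the mask to the pooling block, but that path
    # fills the missing query mask with a .cuda() tensor; the mask is omitted here
    # so the sketch stays CPU-only.
    pooled, attention = pma(encoded)               # -> (3, 2, 32), attention over the 10 elements

    return pooled.shape, attention.shape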