import torch
from torch import nn

import esm  # fair-esm; provides the pretrained ESM-1/ESM-2 encoders used by FastaESM

# NOTE: several names referenced below (load_residue_addon, load_complex_decoder,
# initialization, store_representations, store_elemwise_representations,
# store_decoder_representations, save_attention_maps) are project-internal helpers
# defined elsewhere in the ArkDTA codebase and must be importable at module scope.
class ArkDTA(nn.Module):
def __init__(self, args):
        super(ArkDTA, self).__init__()
self.layer = nn.ModuleDict()
analysis_mode = args.analysis_mode
h = args.arkdta_hidden_dim
d = args.hp_dropout_rate
esm = args.arkdta_esm_model
esm_freeze = args.arkdta_esm_freeze
E = args.arkdta_ecfpvec_dim
L = args.arkdta_sab_depth
A = args.arkdta_attention_option
K = args.arkdta_num_heads
assert 'ARKMAB' in args.arkdta_residue_addon
self.layer['prot_encoder'] = FastaESM(h, esm, esm_freeze, analysis_mode)
self.layer['comp_encoder'] = EcfpConverter(h, L, E, analysis_mode)
self.layer['intg_arkmab'] = load_residue_addon(args)
self.layer['intg_pooling'] = load_complex_decoder(args)
self.layer['ba_predictor'] = AffinityMLP(h)
self.layer['dt_predictor'] = InteractionMLP(h)
def load_auxiliary_materials(self, **kwargs):
return_batch = kwargs['return_batch']
b = kwargs['atomresi_adj'].size(0)
x, y, z = kwargs['encoder_attention'].size()
        # Average attention weights over heads, then split them into logits for the
        # actual compound substructures and for the inactive-site pseudo substructure.
        attention = kwargs['encoder_attention'].view(x // b, b, y, z).mean(0)
        logits0 = attention[:, :, :-1].sum(2).unsqueeze(2)  # actual compound substructures
        logits1 = attention[:, :, -1].unsqueeze(2)          # inactive site (pseudo substructure)
return_batch['task/es_pred'] = torch.cat([logits1, logits0], 2)
return_batch['task/es_true'] = (kwargs['atomresi_adj'].sum(1) > 0.).long().squeeze(1)
return_batch['mask/es_resi'] = (kwargs['atomresi_masks'].sum(1) > 0.).float().squeeze(1)
return return_batch
def forward(self, batch):
return_batch = dict()
residue_features, residue_masks, residue_fastas = batch[0], batch[1], batch[2]
ecfp_words, ecfp_masks = batch[3], batch[4]
atomresi_adj, atomresi_masks = batch[5], batch[6]
bav, dti, cids = batch[7], batch[8], batch[-1]
# Protein Encoder Module
        residue_features, residue_masks, residue_temps, protein_features = self.layer['prot_encoder'](
            X=residue_features, fastas=residue_fastas, masks=residue_masks)
        # Multiplied by zero: the LM-side terms are carried through but contribute nothing downstream.
        return_batch['temp/lm_related'] = residue_temps * 0.
# Ligand Encoder Module
        cstruct_features, cstruct_masks = self.layer['comp_encoder'](ecfp_words=ecfp_words,
                                                                     ecfp_masks=ecfp_masks)
# Protein-Ligand Integration Module (ARK-MAB)
residue_results = self.layer['intg_arkmab'](residue_features=residue_features, residue_masks=residue_masks,
ligelem_features=cstruct_features, ligelem_masks=cstruct_masks)
residue_features, residue_masks, attention_weights = residue_results
        del residue_results
        torch.cuda.empty_cache()
# Protein-Ligand Integration Module (Pooling Layer)
complex_results = self.layer['intg_pooling'](residue_features=residue_features,
residue_masks=residue_masks,
attention_weights=attention_weights,
protein_features=protein_features)
binding_complex, _, _, _ = complex_results
        del complex_results
        torch.cuda.empty_cache()
# Drug-Target Outcome Predictor
bav_predicted = self.layer['ba_predictor'](binding_complex=binding_complex)
dti_predicted = self.layer['dt_predictor'](binding_complex=binding_complex)
return_batch['task/ba_pred'] = bav_predicted.view(-1)
return_batch['task/dt_pred'] = dti_predicted.view(-1)
return_batch['task/ba_true'] = bav.view(-1)
return_batch['task/dt_true'] = dti.view(-1)
return_batch['meta/cid'] = cids
# Additional Materials for Calculating Auxiliary Loss
return_batch = self.load_auxiliary_materials(return_batch=return_batch,
atomresi_adj=atomresi_adj,
atomresi_masks=atomresi_masks,
encoder_attention=attention_weights)
return return_batch
@torch.no_grad()
def infer(self, batch):
return_batch = dict()
residue_features, residue_masks, residue_fastas = batch[0], batch[1], batch[2]
ecfp_words, ecfp_masks = batch[3], batch[4]
bav, dti, cids = batch[7], batch[8], batch[-1]
# Protein Encoder Module
        residue_features, residue_masks, residue_temps, protein_features = self.layer['prot_encoder'](
            X=residue_features, fastas=residue_fastas, masks=residue_masks)
        # Multiplied by zero: the LM-side terms are carried through but contribute nothing downstream.
        return_batch['temp/lm_related'] = residue_temps * 0.
# Ligand Encoder Module
        cstruct_features, cstruct_masks = self.layer['comp_encoder'](ecfp_words=ecfp_words,
                                                                     ecfp_masks=ecfp_masks)
# Protein-Ligand Integration Module (ARK-MAB)
residue_results = self.layer['intg_arkmab'](residue_features=residue_features, residue_masks=residue_masks,
ligelem_features=cstruct_features, ligelem_masks=cstruct_masks)
residue_features, residue_masks, attention_weights = residue_results
        del residue_results
        torch.cuda.empty_cache()
# Protein-Ligand Integration Module (Pooling Layer)
complex_results = self.layer['intg_pooling'](residue_features=residue_features,
residue_masks=residue_masks,
attention_weights=attention_weights,
protein_features=protein_features)
binding_complex, _, _, _ = complex_results
        del complex_results
        torch.cuda.empty_cache()
# Drug-Target Outcome Predictor
bav_predicted = self.layer['ba_predictor'](binding_complex=binding_complex)
dti_predicted = self.layer['dt_predictor'](binding_complex=binding_complex)
return_batch['task/ba_pred'] = bav_predicted.view(-1)
return_batch['task/dt_pred'] = dti_predicted.view(-1)
return_batch['task/ba_true'] = bav.view(-1)
return_batch['task/dt_true'] = dti.view(-1)
return_batch['meta/cid'] = cids
return return_batch
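# --- Batch layout used by ArkDTA.forward / ArkDTA.infer (descriptive note,
# inferred from the unpacking above; not part of the original file) ----------
# batch[0]  residue features          batch[1]  residue masks
# batch[2]  residue FASTA pairs       batch[3]  ECFP "words" (substructure indices)
# batch[4]  ECFP masks                batch[5]  atom-residue adjacency
# batch[6]  atom-residue masks        batch[7]  binding affinity values (bav)
# batch[8]  DTI labels (dti)          batch[-1] compound identifiers (cids)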
class GraphDenseSequential(nn.Sequential):
def __init__(self, *args):
super(GraphDenseSequential, self).__init__(*args)
    def forward(self, X, adj, mask):
        for module in self._modules.values():
            try:
                # Graph-aware sub-modules take (X, adj, mask); plain layers only take X.
                X = module(X, adj, mask)
            except TypeError:
                X = module(X)
        return X
class MaskedGlobalPooling(nn.Module):
def __init__(self, pooling='max'):
super(MaskedGlobalPooling, self).__init__()
self.pooling = pooling
def forward(self, x, adj, masks):
if x.dim() == 2:
x = x.unsqueeze(0)
# print(x, adj, masks)
masks = masks.unsqueeze(2).repeat(1, 1, x.size(2))
        if self.pooling == 'max':
            x[masks == 0] = -99999.99999
            x = x.max(1)[0]
        elif self.pooling == 'add':
            x = x.sum(1)
        else:
            raise NotImplementedError(f"Unsupported pooling option: {self.pooling}")
        return x
class MaskedMean(nn.Module):
def __init__(self):
super(MaskedMean, self).__init__()
def forward(self, X, m):
if isinstance(m, torch.Tensor):
X = X * m.unsqueeze(2)
return X.mean(1)
class MaskedMax(nn.Module):
def __init__(self):
super(MaskedMax, self).__init__()
def forward(self, X, m):
if isinstance(m, torch.Tensor):
X = X * m.unsqueeze(2)
return torch.max(X, 1)[0]
class MaskedSum(nn.Module):
def __init__(self):
super(MaskedSum, self).__init__()
def forward(self, X, m):
if isinstance(m, torch.Tensor):
X = X * m.unsqueeze(2)
return X.sum(1)
class MaskedScaledAverage(nn.Module):
def __init__(self):
super(MaskedScaledAverage, self).__init__()
def forward(self, X, m):
if isinstance(m, torch.Tensor):
X = X * m.unsqueeze(2)
return X.sum(1) / (m.sum(1) ** 0.5).unsqueeze(1)
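# --- Illustrative sketch (not part of the original model) -------------------
# The masked reducers above share one pattern: zero out padded positions with
# the mask, then reduce over the token dimension. The hypothetical helper
# below only prints the resulting shapes on toy tensors.
def _demo_masked_pooling():
    X = torch.randn(2, 5, 8)                   # batch of 2, 5 tokens, 8 features
    m = torch.tensor([[1., 1., 1., 0., 0.],
                      [1., 1., 1., 1., 1.]])   # 1 = real token, 0 = padding
    for pool in (MaskedMean(), MaskedMax(), MaskedSum(), MaskedScaledAverage()):
        print(pool.__class__.__name__, pool(X, m).shape)  # each reduces to (2, 8)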
class Decoder(nn.Module):
    def __init__(self, analysis_mode):
        super(Decoder, self).__init__()
        self.representations = []  # read by show() and flush() below
        self.output_representations = []
        self.query_representations = []
        self.kvpair_representations = []
        self.attention_weights = []
        if analysis_mode: self.register_forward_hook(store_decoder_representations)
def show(self):
print("Number of Saved Numpy Arrays: ", len(self.representations))
for i, representation in enumerate(self.representations):
print(f"Shape of {i}th Numpy Array: ", representation.shape)
return self.representations
def flush(self):
del self.representations
self.representations = []
def release_qk(self):
return None
def forward(self, **kwargs):
return kwargs['X'], kwargs['X'], kwargs['residue_features'], None
class DecoderPMA_Residue(Decoder):
def __init__(self, h: int, num_heads: int, num_seeds: int, attn_option: str, analysis_mode: bool):
super(DecoderPMA_Residue, self).__init__(analysis_mode)
# Aggregate the Residues into Residue Regions
pma_args = (h, num_seeds, num_heads, RFF(h), attn_option, False, analysis_mode, False)
self.decoder = PoolingMultiheadAttention(*pma_args)
# Model Region-Region Interaction through Set Attention
sab_depth = 0 if num_seeds < 4 else int((num_seeds // 2) ** 0.5)
sab_args = (h, num_heads, RFF(h), attn_option, False, analysis_mode, True)
self.pairwise = nn.ModuleList([SetAttentionBlock(*sab_args) for _ in range(sab_depth)])
# Concat, then reduce into h-dimensional Set Representation
self.aggregate = nn.Linear(h * num_seeds, h)
self.apply(initialization)
def forward(self, **kwargs):
residue_features = kwargs['residue_features']
residue_masks = kwargs['residue_masks']
output, attention = self.decoder(residue_features, residue_masks)
for sab in self.pairwise: output, _ = sab(output)
b, n, d = output.shape
output = self.aggregate(output.view(b, n * d))
return output, None, residue_features, attention
class AffinityMLP(nn.Module):
def __init__(self, h: int):
super(AffinityMLP, self).__init__()
self.mlp = nn.Sequential(nn.Linear(h, h), nn.Dropout(0.1), nn.LeakyReLU(), nn.Linear(h, 1))
self.apply(initialization)
def forward(self, **kwargs):
'''
X: batch size x 1 x H
'''
X = kwargs['binding_complex']
X = X.squeeze(1) if X.dim() == 3 else X
yhat = self.mlp(X)
return yhat
class InteractionMLP(nn.Module):
def __init__(self, h: int):
super(InteractionMLP, self).__init__()
self.mlp = nn.Sequential(nn.Linear(h, h), nn.Dropout(0.1), nn.LeakyReLU(), nn.Linear(h, 1), nn.Sigmoid())
self.apply(initialization)
def forward(self, **kwargs):
'''
X: batch size x 1 x H
'''
X = kwargs['binding_complex']
X = X.squeeze(1) if X.dim() == 3 else X
yhat = self.mlp(X)
return yhat
class LigelemEncoder(nn.Module):
def __init__(self):
super(LigelemEncoder, self).__init__()
self.representations = []
def show(self):
print("Number of Saved Numpy Arrays: ", len(self.representations))
for i, representation in enumerate(self.representations):
print(f"Shape of {i}th Numpy Array: ", representation.shape)
return self.representations
def flush(self):
del self.representations
self.representations = []
class EcfpConverter(LigelemEncoder):
def __init__(self, h: int, sab_depth: int, ecfp_dim: int, analysis_mode: bool):
super(EcfpConverter, self).__init__()
K = 4 # number of attention heads
self.ecfp_embeddings = nn.Embedding(ecfp_dim + 1, h, padding_idx=ecfp_dim)
self.encoder = nn.ModuleList([])
sab_args = (h, K, RFF(h), 'general_dot', False, analysis_mode, True)
self.encoder = nn.ModuleList([SetAttentionBlock(*sab_args) for _ in range(sab_depth)])
self.representations = []
if analysis_mode: self.register_forward_hook(store_elemwise_representations)
self.apply(initialization)
def forward(self, **kwargs):
'''
X : (b x d)
'''
ecfp_words = kwargs['ecfp_words']
ecfp_masks = kwargs['ecfp_masks']
ecfp_words = self.ecfp_embeddings(ecfp_words)
for sab in self.encoder:
ecfp_words, _ = sab(ecfp_words, ecfp_masks)
return [ecfp_words, ecfp_masks]
class ResidueAddOn(nn.Module):
def __init__(self):
super(ResidueAddOn, self).__init__()
self.representations = []
def show(self):
print("Number of Saved Numpy Arrays: ", len(self.representations))
for i, representation in enumerate(self.representations):
print(f"Shape of {i}th Numpy Array: ", representation.shape)
return self.representations
def flush(self):
del self.representations
self.representations = []
def forward(self, **kwargs):
X, Xm = kwargs['X'], kwargs['Xm']
return X, Xm
class ARKMAB(ResidueAddOn):
def __init__(self, h: int, num_heads: int, attn_option: str, analysis_mode: bool):
super(ARKMAB, self).__init__()
pmx_args = (h, num_heads, RFF(h), attn_option, False, analysis_mode, False)
self.pmx = PoolingMultiheadCrossAttention(*pmx_args)
self.inactive = nn.Parameter(torch.randn(1, 1, h))
self.fillmask = nn.Parameter(torch.ones(1, 1), requires_grad=False)
self.representations = []
if analysis_mode: pass
self.apply(initialization)
def forward(self, **kwargs):
'''
X: batch size x residues x H
Xm: batch size x residues x H
Y: batch size x ecfpsubs x H
Ym: batch size x ecfpsubs x H
'''
X, Xm = kwargs['residue_features'], kwargs['residue_masks']
Y, Ym = kwargs['ligelem_features'], kwargs['ligelem_masks']
pseudo_substructure = self.inactive.repeat(X.size(0), 1, 1)
pseudo_masks = self.fillmask.repeat(X.size(0), 1)
Y = torch.cat([Y, pseudo_substructure], 1)
Ym = torch.cat([Ym, pseudo_masks], 1)
X, attention = self.pmx(Y=Y, Ym=Ym, X=X, Xm=Xm)
return X, Xm, attention
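# --- Illustrative sketch (not part of the original model) -------------------
# ARKMAB appends one learned "inactive" pseudo substructure (plus a mask column
# of ones) to every ligand before residue-ligand cross-attention, so a residue
# is never forced to attend only to real substructures. The toy tensors below
# just illustrate the concatenation shapes.
def _demo_pseudo_substructure():
    b, n_sub, h = 2, 6, 16
    Y = torch.randn(b, n_sub, h)                       # ligand substructure features
    Ym = torch.ones(b, n_sub)                          # ligand substructure masks
    inactive = torch.randn(1, 1, h)                    # stands in for the learned parameter
    Y = torch.cat([Y, inactive.repeat(b, 1, 1)], 1)    # (b, n_sub + 1, h)
    Ym = torch.cat([Ym, torch.ones(b, 1)], 1)          # (b, n_sub + 1)
    print(Y.shape, Ym.shape)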
class ResidueEncoder(nn.Module):
def __init__(self):
super(ResidueEncoder, self).__init__()
self.representations = []
def show(self):
print("Number of Saved Numpy Arrays: ", len(self.representations))
for i, representation in enumerate(self.representations):
print(f"Shape of {i}th Numpy Array: ", representation.shape)
return self.representations
def flush(self):
del self.representations
self.representations = []
class AminoAcidSeqCNN(ResidueEncoder):
def __init__(self, h: int, d: float, cnn_depth: int, kernel_size: int, analysis_mode: bool):
super(AminoAcidSeqCNN, self).__init__()
        self.encoder = nn.ModuleList([nn.Sequential(nn.Linear(21, h),  # 21-dim amino-acid input encoding
nn.Dropout(d),
nn.LeakyReLU(),
nn.Linear(h, h))])
for _ in range(cnn_depth):
self.encoder.append(nn.Conv1d(h, h, kernel_size, 1, (kernel_size - 1) // 2))
self.representations = []
if analysis_mode: self.register_forward_hook(store_representations)
self.apply(initialization)
def forward(self, **kwargs):
X = kwargs['aaseqs']
for i, module in enumerate(self.encoder):
if i == 1: X = X.transpose(1, 2)
X = module(X)
X = X.transpose(1, 2)
return X
class FastaESM(ResidueEncoder):
def __init__(self, h: int, esm_model: str, esm_freeze: bool, analysis_mode: bool):
super(FastaESM, self).__init__()
        self.esm_version = 2 if 'esm2' in esm_model else 1
        if esm_model == 'esm1b_t33_650M_UR50S':
            self.esm, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
            self.layer_idx, self.emb_dim = 33, 1280
        elif esm_model == 'esm1_t12_85M_UR50S':
            self.esm, alphabet = esm.pretrained.esm1_t12_85M_UR50S()
            self.layer_idx, self.emb_dim = 12, 768
        elif esm_model == 'esm2_t6_8M_UR50D':
            self.esm, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
            self.layer_idx, self.emb_dim = 6, 320
        elif esm_model == 'esm2_t12_35M_UR50D':
            self.esm, alphabet = esm.pretrained.esm2_t12_35M_UR50D()
            self.layer_idx, self.emb_dim = 12, 480
        elif esm_model == 'esm2_t30_150M_UR50D':
            self.esm, alphabet = esm.pretrained.esm2_t30_150M_UR50D()
            self.layer_idx, self.emb_dim = 30, 640
        else:
            raise ValueError(f"Unsupported ESM model: {esm_model}")
self.batch_converter = alphabet.get_batch_converter()
        if esm_freeze in (True, 'True'):
            for p in self.esm.parameters():
                p.requires_grad = False
assert h == self.emb_dim, f"The hidden dimension should be set to {self.emb_dim}, not {h}"
self.representations = []
if analysis_mode: self.register_forward_hook(store_elemwise_representations)
def esm1_pooling(self, embeddings, masks):
return embeddings[:, 1:, :].sum(1) / masks[:, 1:].sum(1).view(-1, 1)
def esm2_pooling(self, embeddings, masks):
return embeddings[:, 1:-1, :].sum(1) / masks[:, 1:-1].sum(1).view(-1, 1)
def forward(self, **kwargs):
fastas = kwargs['fastas']
_, _, tokenized = self.batch_converter(fastas)
tokenized = tokenized.cuda()
if self.esm_version == 2:
masks = torch.where(tokenized > 1, 1, 0).float()
else:
masks = torch.where((tokenized > 1) & (tokenized != 32), 1, 0).float()
embeddings = self.esm(tokenized, repr_layers=[self.layer_idx], return_contacts=True)
logits = embeddings["logits"].sum()
contacts = embeddings["contacts"].sum()
attentions = embeddings["attentions"].sum()
embeddings = embeddings["representations"][self.layer_idx]
        assert masks.size(0) == embeddings.size(0), \
            f"Batch sizes of masks ({masks.size(0)}) and embeddings ({embeddings.size(0)}) do not match."
        assert masks.size(1) == embeddings.size(1), \
            f"Lengths of masks ({masks.size(1)}) and embeddings ({embeddings.size(1)}) do not match."
if self.esm_version == 2:
return [embeddings[:, 1:-1, :], masks[:, 1:-1], logits + contacts + attentions,
self.esm2_pooling(embeddings, masks)]
else:
return [embeddings[:, 1:, :], masks[:, 1:], logits + contacts + attentions,
self.esm1_pooling(embeddings, masks)]
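# --- Usage note (assumption, not from the original file) --------------------
# FastaESM.forward expects `fastas` in the format consumed by fair-esm's batch
# converter: a list of (label, sequence) pairs, e.g.
#     fastas = [("prot_a", "MKTAYIAKQR"), ("prot_b", "GAVLIPFW")]
# The converter returns (labels, sequences, token tensor); only the token
# tensor is used above, and it is moved to the GPU with .cuda(), so this
# encoder assumes a CUDA device is available.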
class DotProduct(nn.Module):
def __init__(self):
super().__init__()
def forward(self, queries, keys):
return torch.bmm(queries, keys.transpose(1, 2))
class ScaledDotProduct(nn.Module):
def __init__(self):
super().__init__()
def forward(self, queries, keys):
return torch.bmm(queries, keys.transpose(1, 2)) / (queries.size(2) ** 0.5)
class GeneralDotProduct(nn.Module):
def __init__(self, hidden_dim):
super().__init__()
self.W = nn.Parameter(torch.randn(hidden_dim, hidden_dim))
torch.nn.init.orthogonal_(self.W)
def forward(self, queries, keys):
return torch.bmm(queries @ self.W, keys.transpose(1, 2))
class ConcatDotProduct(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        raise NotImplementedError("ConcatDotProduct similarity is not implemented.")
    def forward(self, queries, keys):
        raise NotImplementedError
class Additive(nn.Module):
def __init__(self, hidden_dim):
super().__init__()
self.U = nn.Parameter(torch.randn(hidden_dim, hidden_dim))
self.T = nn.Parameter(torch.randn(hidden_dim, hidden_dim))
self.b = nn.Parameter(torch.rand(hidden_dim).uniform_(-0.1, 0.1))
self.W = nn.Sequential(nn.Tanh(), nn.Linear(hidden_dim, 1))
torch.nn.init.orthogonal_(self.U)
torch.nn.init.orthogonal_(self.T)
def forward(self, queries, keys):
return self.W(queries.unsqueeze(1) @ self.U + keys.unsqueeze(2) @ self.T + self.b).squeeze(-1).transpose(1, 2)
class Attention(nn.Module):
def __init__(self, similarity, hidden_dim=1024, store_qk=False):
super().__init__()
self.softmax = nn.Softmax(dim=2)
self.attention_maps = []
self.store_qk = store_qk
self.query_vectors, self.key_vectors = None, None
assert similarity in ['dot', 'scaled_dot', 'general_dot', 'concat_dot', 'additive']
if similarity == 'dot':
self.similarity = DotProduct()
elif similarity == 'scaled_dot':
self.similarity = ScaledDotProduct()
elif similarity == 'general_dot':
self.similarity = GeneralDotProduct(hidden_dim)
elif similarity == 'concat_dot':
self.similarity = ConcatDotProduct(hidden_dim)
elif similarity == 'additive':
self.similarity = Additive(hidden_dim)
        else:
            raise ValueError(f"Unknown similarity option: {similarity}")
def release_qk(self):
Q, K = self.query_vectors, self.key_vectors
self.query_vectors, self.key_vectors = None, None
torch.cuda.empty_cache()
return Q, K
def forward(self, queries, keys, qmasks=None, kmasks=None):
if self.store_qk:
self.query_vectors = queries
self.key_vectors = keys
if torch.is_tensor(qmasks) and not torch.is_tensor(kmasks):
dim0, dim1 = qmasks.size(0), keys.size(1)
kmasks = torch.ones(dim0, dim1).cuda()
elif not torch.is_tensor(qmasks) and torch.is_tensor(kmasks):
dim0, dim1 = kmasks.size(0), queries.size(1)
qmasks = torch.ones(dim0, dim1).cuda()
else:
pass
attention = self.similarity(queries, keys)
if torch.is_tensor(qmasks) and torch.is_tensor(kmasks):
qmasks = qmasks.repeat(queries.size(0) // qmasks.size(0), 1).unsqueeze(2)
kmasks = kmasks.repeat(keys.size(0) // kmasks.size(0), 1).unsqueeze(2)
attnmasks = torch.bmm(qmasks, kmasks.transpose(1, 2))
attention = torch.clip(attention, min=-10, max=10)
attention = attention.exp()
attention = attention * attnmasks
attention = attention / (attention.sum(2).unsqueeze(2) + 1e-5)
else:
attention = self.softmax(attention)
return attention
@torch.no_grad()
def save_attention_maps(self, input, output):
self.attention_maps.append(output.data.detach().cpu().numpy())
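# --- Illustrative sketch (not part of the original model) -------------------
# When both query and key masks are provided, Attention.forward builds a masked
# softmax by hand: clip the similarity scores, exponentiate, zero out masked
# query/key pairs, and renormalise each row. The hypothetical check below
# reproduces that arithmetic on toy tensors.
def _demo_masked_softmax():
    scores = torch.randn(1, 3, 4)                        # (batch, queries, keys)
    qm = torch.tensor([[1., 1., 0.]]).unsqueeze(2)       # (1, 3, 1) query mask
    km = torch.tensor([[1., 1., 1., 0.]]).unsqueeze(2)   # (1, 4, 1) key mask
    attnmask = torch.bmm(qm, km.transpose(1, 2))         # (1, 3, 4) pairwise mask
    a = torch.clip(scores, min=-10, max=10).exp() * attnmask
    a = a / (a.sum(2, keepdim=True) + 1e-5)
    print(a)  # masked-query rows are ~0; masked keys receive ~0 weight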
class MultiheadAttention(nn.Module):
def __init__(self, d, h, sim='dot', analysis=False, store_qk=False):
super().__init__()
assert d % h == 0, f"{d} dimension, {h} heads"
self.h = h
p = d // h
self.project_queries = nn.Linear(d, d)
self.project_keys = nn.Linear(d, d)
self.project_values = nn.Linear(d, d)
self.concatenation = nn.Linear(d, d)
self.attention = Attention(sim, p, store_qk)
if analysis:
self.attention.register_forward_hook(save_attention_maps)
def release_qk(self):
Q, K = self.attention.release_qk()
Qb = Q.size(0) // self.h
Qn, Qd = Q.size(1), Q.size(2)
Kb = K.size(0) // self.h
Kn, Kd = K.size(1), K.size(2)
Q = Q.view(self.h, Qb, Qn, Qd)
K = K.view(self.h, Kb, Kn, Kd)
Q = Q.permute(1, 2, 0, 3).contiguous().view(Qb, Qn, Qd * self.h)
K = K.permute(1, 2, 0, 3).contiguous().view(Kb, Kn, Kd * self.h)
return Q, K
def forward(self, queries, keys, values, qmasks=None, kmasks=None):
h = self.h
b, n, d = queries.size()
_, m, _ = keys.size()
p = d // h
queries = self.project_queries(queries) # shape [b, n, d]
keys = self.project_keys(keys) # shape [b, m, d]
values = self.project_values(values) # shape [b, m, d]
queries = queries.view(b, n, h, p)
keys = keys.view(b, m, h, p)
values = values.view(b, m, h, p)
queries = queries.permute(2, 0, 1, 3).contiguous().view(h * b, n, p)
keys = keys.permute(2, 0, 1, 3).contiguous().view(h * b, m, p)
values = values.permute(2, 0, 1, 3).contiguous().view(h * b, m, p)
        attn_w = self.attention(queries, keys, qmasks, kmasks)  # shape [h * b, n, m]
output = torch.bmm(attn_w, values)
output = output.view(h, b, n, p)
output = output.permute(1, 2, 0, 3).contiguous().view(b, n, d)
output = self.concatenation(output) # shape [b, n, d]
return output, attn_w
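# --- Illustrative sketch (not part of the original model) -------------------
# MultiheadAttention folds the h heads into the batch dimension before calling
# Attention and unfolds them afterwards. The round trip below shows that the
# fold/unfold reshaping is lossless.
def _demo_head_folding():
    b, n, d, h = 2, 5, 8, 4
    p = d // h
    q = torch.randn(b, n, d)
    q_heads = q.view(b, n, h, p).permute(2, 0, 1, 3).contiguous().view(h * b, n, p)
    q_back = q_heads.view(h, b, n, p).permute(1, 2, 0, 3).contiguous().view(b, n, d)
    print(torch.allclose(q, q_back))  # True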
class MultiheadAttentionExpanded(nn.Module):
def __init__(self, d, h, sim='dot', analysis=False):
super().__init__()
self.project_queries = nn.ModuleList([nn.Linear(d, d) for _ in range(h)])
self.project_keys = nn.ModuleList([nn.Linear(d, d) for _ in range(h)])
self.project_values = nn.ModuleList([nn.Linear(d, d) for _ in range(h)])
self.concatenation = nn.Linear(h * d, d)
self.attention = Attention(sim, d)
if analysis:
self.attention.register_forward_hook(save_attention_maps)
    def forward(self, queries, keys, values, qmasks=None, kmasks=None):
        output = []
        for Wq, Wk, Wv in zip(self.project_queries, self.project_keys, self.project_values):
            Pq, Pk, Pv = Wq(queries), Wk(keys), Wv(values)
            output.append(torch.bmm(self.attention(Pq, Pk, qmasks, kmasks), Pv))
        # Concatenate per-head outputs along the feature dimension: (b, n, h * d).
        output = self.concatenation(torch.cat(output, 2))
        return output
class EmptyModule(nn.Module):
def __init__(self, args):
super().__init__()
def forward(self, x):
return 0.
class RFF(nn.Module):
def __init__(self, h):
super().__init__()
self.rff = nn.Sequential(nn.Linear(h, h), nn.ReLU(), nn.Linear(h, h), nn.ReLU(), nn.Linear(h, h), nn.ReLU())
def forward(self, x):
return self.rff(x)
class MultiheadAttentionBlock(nn.Module):
def __init__(self, d, h, rff, similarity='dot', full_head=False, analysis=False, store_qk=False):
super().__init__()
        if not full_head:
            self.multihead = MultiheadAttention(d, h, similarity, analysis, store_qk)
        else:
            self.multihead = MultiheadAttentionExpanded(d, h, similarity, analysis)
self.layer_norm1 = nn.LayerNorm(d)
self.layer_norm2 = nn.LayerNorm(d)
self.rff = rff
def release_qk(self):
Q, K = self.multihead.release_qk()
return Q, K
def forward(self, x, y, xm=None, ym=None, layer_norm=True):
h, a = self.multihead(x, y, y, xm, ym)
if layer_norm:
h = self.layer_norm1(x + h)
return self.layer_norm2(h + self.rff(h)), a
else:
h = x + h
return h + self.rff(h), a
class SetAttentionBlock(nn.Module):
def __init__(self, d, h, rff, similarity='dot', full_head=False, analysis=False, store_qk=False):
super().__init__()
self.mab = MultiheadAttentionBlock(d, h, rff, similarity, full_head, analysis, store_qk)
def release_qk(self):
Q, K = self.mab.release_qk()
return Q, K
def forward(self, x, m=None, ln=True):
return self.mab(x, x, m, m, ln)
class InducedSetAttentionBlock(nn.Module):
def __init__(self, d, m, h, rff1, rff2, similarity='dot', full_head=False, analysis=False, store_qk=False):
super().__init__()
self.mab1 = MultiheadAttentionBlock(d, h, rff1, similarity, full_head, analysis, store_qk)
self.mab2 = MultiheadAttentionBlock(d, h, rff2, similarity, full_head, analysis, store_qk)
self.inducing_points = nn.Parameter(torch.randn(1, m, d))
def release_qk(self):
        raise NotImplementedError
def forward(self, x, m=None, ln=True):
b = x.size(0)
p = self.inducing_points
p = p.repeat([b, 1, 1]) # shape [b, m, d]
        h, _ = self.mab1(p, x, None, m, ln)  # shape [b, m, d]
        return self.mab2(x, h, m, None, ln)
class PoolingMultiheadAttention(nn.Module):
def __init__(self, d, k, h, rff, similarity='dot', full_head=False, analysis=False, store_qk=False):
super().__init__()
self.mab = MultiheadAttentionBlock(d, h, rff, similarity, full_head, analysis, store_qk)
self.seed_vectors = nn.Parameter(torch.randn(1, k, d))
torch.nn.init.xavier_uniform_(self.seed_vectors)
def release_qk(self):
Q, K = self.mab.release_qk()
return Q, K
def forward(self, z, m=None, ln=True):
b = z.size(0)
s = self.seed_vectors
s = s.repeat([b, 1, 1]) # random seed vector: shape [b, k, d]
return self.mab(s, z, None, m, ln)
class PoolingMultiheadCrossAttention(nn.Module):
def __init__(self, d, h, rff, similarity='dot', full_head=False, analysis=False, store_qk=False):
super().__init__()
self.mab = MultiheadAttentionBlock(d, h, rff, similarity, full_head, analysis, store_qk)
def release_qk(self):
Q, K = self.mab.release_qk()
return Q, K
def forward(self, X, Y, Xm=None, Ym=None, ln=True):
return self.mab(X, Y, Xm, Ym, ln)
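# --- Minimal smoke test (illustrative only; not part of the original file) ---
# Only the set-attention building blocks can be exercised without ESM weights
# or the project-internal loader/initialization helpers, so the sketch below is
# limited to SetAttentionBlock and PoolingMultiheadAttention on toy tensors.
if __name__ == "__main__":
    d, heads, tokens, seeds = 16, 4, 10, 3
    x = torch.randn(2, tokens, d)
    m = torch.ones(2, tokens)
    sab = SetAttentionBlock(d, heads, RFF(d), similarity='general_dot')
    pma = PoolingMultiheadAttention(d, seeds, heads, RFF(d), similarity='general_dot')
    x, _ = sab(x, m)           # (2, tokens, d), masked set attention
    pooled, attn = pma(x)      # (2, seeds, d), attention over the token set
    print(pooled.shape, attn.shape)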