import torch
from torch import nn
import esm  # fair-esm package, required by FastaESM below

# NOTE: RFF, SetAttentionBlock, PoolingMultiheadAttention and PoolingMultiheadCrossAttention are
# defined further down in this file; `initialization`, `load_residue_addon`, `load_complex_decoder`
# and the `store_*` / `save_attention_maps` hooks are expected to be imported from elsewhere in the repository.
class ArkDTA(nn.Module):
    def __init__(self, args):
        super(ArkDTA, self).__init__()
        self.layer = nn.ModuleDict()
        analysis_mode = args.analysis_mode
        h = args.arkdta_hidden_dim
        d = args.hp_dropout_rate
        esm = args.arkdta_esm_model
        esm_freeze = args.arkdta_esm_freeze
        E = args.arkdta_ecfpvec_dim
        L = args.arkdta_sab_depth
        A = args.arkdta_attention_option
        K = args.arkdta_num_heads
        assert 'ARKMAB' in args.arkdta_residue_addon

        self.layer['prot_encoder'] = FastaESM(h, esm, esm_freeze, analysis_mode)
        self.layer['comp_encoder'] = EcfpConverter(h, L, E, analysis_mode)
        self.layer['intg_arkmab'] = load_residue_addon(args)
        self.layer['intg_pooling'] = load_complex_decoder(args)
        self.layer['ba_predictor'] = AffinityMLP(h)
        self.layer['dt_predictor'] = InteractionMLP(h)
    def load_auxiliary_materials(self, **kwargs):
        # Builds the extra prediction/target tensors for the auxiliary
        # (atom-residue interaction) loss from the ARK-MAB attention weights.
        return_batch = kwargs['return_batch']
        b = kwargs['atomresi_adj'].size(0)
        x, y, z = kwargs['encoder_attention'].size()
        logits0 = kwargs['encoder_attention'].view(x // b, b, y, z).mean(0)[:, :, :-1].sum(2).unsqueeze(2)  # actual compsub
        logits1 = kwargs['encoder_attention'].view(x // b, b, y, z).mean(0)[:, :, -1].unsqueeze(2)          # inactive site
        return_batch['task/es_pred'] = torch.cat([logits1, logits0], 2)
        return_batch['task/es_true'] = (kwargs['atomresi_adj'].sum(1) > 0.).long().squeeze(1)
        return_batch['mask/es_resi'] = (kwargs['atomresi_masks'].sum(1) > 0.).float().squeeze(1)
        return return_batch
    def forward(self, batch):
        return_batch = dict()
        residue_features, residue_masks, residue_fastas = batch[0], batch[1], batch[2]
        ecfp_words, ecfp_masks = batch[3], batch[4]
        atomresi_adj, atomresi_masks = batch[5], batch[6]
        bav, dti, cids = batch[7], batch[8], batch[-1]

        # Protein Encoder Module
        residue_features = self.layer['prot_encoder'](X=residue_features,
                                                      fastas=residue_fastas,
                                                      masks=residue_masks)
        residue_masks = residue_features[1]
        residue_temps = residue_features[2]
        protein_features = residue_features[3]
        residue_features = residue_features[0]
        return_batch['temp/lm_related'] = residue_temps * 0.

        # Ligand Encoder Module
        cstruct_features = self.layer['comp_encoder'](ecfp_words=ecfp_words,
                                                      ecfp_masks=ecfp_masks)
        cstruct_masks = cstruct_features[1]
        cstruct_features = cstruct_features[0]

        # Protein-Ligand Integration Module (ARK-MAB)
        residue_results = self.layer['intg_arkmab'](residue_features=residue_features, residue_masks=residue_masks,
                                                    ligelem_features=cstruct_features, ligelem_masks=cstruct_masks)
        residue_features, residue_masks, attention_weights = residue_results
        del residue_results
        torch.cuda.empty_cache()

        # Protein-Ligand Integration Module (Pooling Layer)
        complex_results = self.layer['intg_pooling'](residue_features=residue_features,
                                                     residue_masks=residue_masks,
                                                     attention_weights=attention_weights,
                                                     protein_features=protein_features)
        binding_complex, _, _, _ = complex_results
        del complex_results
        torch.cuda.empty_cache()

        # Drug-Target Outcome Predictor
        bav_predicted = self.layer['ba_predictor'](binding_complex=binding_complex)
        dti_predicted = self.layer['dt_predictor'](binding_complex=binding_complex)
        return_batch['task/ba_pred'] = bav_predicted.view(-1)
        return_batch['task/dt_pred'] = dti_predicted.view(-1)
        return_batch['task/ba_true'] = bav.view(-1)
        return_batch['task/dt_true'] = dti.view(-1)
        return_batch['meta/cid'] = cids

        # Additional Materials for Calculating Auxiliary Loss
        return_batch = self.load_auxiliary_materials(return_batch=return_batch,
                                                     atomresi_adj=atomresi_adj,
                                                     atomresi_masks=atomresi_masks,
                                                     encoder_attention=attention_weights)
        return return_batch
    def infer(self, batch):
        return_batch = dict()
        residue_features, residue_masks, residue_fastas = batch[0], batch[1], batch[2]
        ecfp_words, ecfp_masks = batch[3], batch[4]
        bav, dti, cids = batch[7], batch[8], batch[-1]

        # Protein Encoder Module
        residue_features = self.layer['prot_encoder'](X=residue_features,
                                                      fastas=residue_fastas,
                                                      masks=residue_masks)
        residue_masks = residue_features[1]
        residue_temps = residue_features[2]
        protein_features = residue_features[3]
        residue_features = residue_features[0]
        return_batch['temp/lm_related'] = residue_temps * 0.

        # Ligand Encoder Module
        cstruct_features = self.layer['comp_encoder'](ecfp_words=ecfp_words,
                                                      ecfp_masks=ecfp_masks)
        cstruct_masks = cstruct_features[1]
        cstruct_features = cstruct_features[0]

        # Protein-Ligand Integration Module (ARK-MAB)
        residue_results = self.layer['intg_arkmab'](residue_features=residue_features, residue_masks=residue_masks,
                                                    ligelem_features=cstruct_features, ligelem_masks=cstruct_masks)
        residue_features, residue_masks, attention_weights = residue_results
        del residue_results
        torch.cuda.empty_cache()

        # Protein-Ligand Integration Module (Pooling Layer)
        complex_results = self.layer['intg_pooling'](residue_features=residue_features,
                                                     residue_masks=residue_masks,
                                                     attention_weights=attention_weights,
                                                     protein_features=protein_features)
        binding_complex, _, _, _ = complex_results
        del complex_results
        torch.cuda.empty_cache()

        # Drug-Target Outcome Predictor
        bav_predicted = self.layer['ba_predictor'](binding_complex=binding_complex)
        dti_predicted = self.layer['dt_predictor'](binding_complex=binding_complex)
        return_batch['task/ba_pred'] = bav_predicted.view(-1)
        return_batch['task/dt_pred'] = dti_predicted.view(-1)
        return_batch['task/ba_true'] = bav.view(-1)
        return_batch['task/dt_true'] = dti.view(-1)
        return_batch['meta/cid'] = cids
        return return_batch
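# Illustrative sketch (not part of the original source) of the batch layout that
# forward()/infer() above unpack by position. All shapes, lengths and dummy values
# below are hypothetical placeholders; the FASTA entries are (identifier, sequence)
# pairs as expected by the ESM batch converter used in FastaESM.
def _example_arkdta_batch(b: int = 2, n_res: int = 6, n_words: int = 4, e: int = 320):
    return [
        torch.randn(b, n_res, e),              # 0: residue features (not used when the protein encoder is FastaESM)
        torch.ones(b, n_res),                  # 1: residue masks
        [("p0", "MKTAYI"), ("p1", "MKVLAA")],  # 2: FASTA (id, sequence) pairs
        torch.randint(0, 1024, (b, n_words)),  # 3: ECFP substructure indices
        torch.ones(b, n_words),                # 4: ECFP masks
        torch.zeros(b, n_words, n_res),        # 5: atom-residue adjacency (placeholder shape)
        torch.zeros(b, n_words, n_res),        # 6: atom-residue masks (placeholder shape)
        torch.randn(b),                        # 7: binding affinity values
        torch.randint(0, 2, (b,)).float(),     # 8: drug-target interaction labels
        ["CID0", "CID1"],                      # -1: compound identifiers
    ]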
class GraphDenseSequential(nn.Sequential):
    def __init__(self, *args):
        super(GraphDenseSequential, self).__init__(*args)

    def forward(self, X, adj, mask):
        for module in self._modules.values():
            try:
                # Graph-aware layers take (X, adj, mask); plain layers only take X.
                X = module(X, adj, mask)
            except BaseException:
                X = module(X)
        return X
class MaskedGlobalPooling(nn.Module):
    def __init__(self, pooling='max'):
        super(MaskedGlobalPooling, self).__init__()
        self.pooling = pooling

    def forward(self, x, adj, masks):
        if x.dim() == 2:
            x = x.unsqueeze(0)
        masks = masks.unsqueeze(2).repeat(1, 1, x.size(2))
        if self.pooling == 'max':
            # Push masked positions to a large negative value so they never win the max.
            x[masks == 0] = -99999.99999
            x = x.max(1)[0]
        elif self.pooling == 'add':
            x = x.sum(1)
        else:
            raise NotImplementedError(f"Unsupported pooling option: {self.pooling}")
        return x
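# Illustrative usage sketch (not part of the original source): masked max-pooling over
# a padded node dimension. Shapes and values below are hypothetical; the adjacency
# argument is unused by the pooling itself and passed only to satisfy the signature.
def _demo_masked_global_pooling():
    pool = MaskedGlobalPooling(pooling='max')
    x = torch.randn(2, 4, 8)                      # batch x nodes x features
    adj = torch.zeros(2, 4, 4)                    # placeholder adjacency
    masks = torch.tensor([[1, 1, 0, 0],
                          [1, 1, 1, 0]]).float()  # 1 = real node, 0 = padding
    pooled = pool(x, adj, masks)                  # batch x features
    return pooled.shape                           # torch.Size([2, 8])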
class MaskedMean(nn.Module):
    def __init__(self):
        super(MaskedMean, self).__init__()

    def forward(self, X, m):
        if isinstance(m, torch.Tensor):
            X = X * m.unsqueeze(2)
        return X.mean(1)


class MaskedMax(nn.Module):
    def __init__(self):
        super(MaskedMax, self).__init__()

    def forward(self, X, m):
        if isinstance(m, torch.Tensor):
            X = X * m.unsqueeze(2)
        return torch.max(X, 1)[0]


class MaskedSum(nn.Module):
    def __init__(self):
        super(MaskedSum, self).__init__()

    def forward(self, X, m):
        if isinstance(m, torch.Tensor):
            X = X * m.unsqueeze(2)
        return X.sum(1)


class MaskedScaledAverage(nn.Module):
    def __init__(self):
        super(MaskedScaledAverage, self).__init__()

    def forward(self, X, m):
        if isinstance(m, torch.Tensor):
            X = X * m.unsqueeze(2)
        return X.sum(1) / (m.sum(1) ** 0.5).unsqueeze(1)
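# Illustrative sketch comparing the masked reduction variants above on the same padded
# input; values are hypothetical. Note (as written above) that MaskedMean divides by the
# full padded length, while MaskedScaledAverage divides by the square root of the number
# of unmasked positions.
def _demo_masked_reductions():
    X = torch.randn(2, 5, 3)                     # batch x elements x features
    m = torch.tensor([[1, 1, 1, 0, 0],
                      [1, 1, 1, 1, 1]]).float()  # 1 = valid, 0 = padding
    outs = {
        'mean': MaskedMean()(X, m),
        'max': MaskedMax()(X, m),
        'sum': MaskedSum()(X, m),
        'scaled': MaskedScaledAverage()(X, m),
    }
    return {k: v.shape for k, v in outs.items()}  # all torch.Size([2, 3])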
class Decoder(nn.Module):
    def __init__(self, analysis_mode):
        super(Decoder, self).__init__()
        self.representations = []  # populated by the forward hook in analysis mode
        self.output_representations = []
        self.query_representations = []
        self.kvpair_representations = []
        self.attention_weights = []
        if analysis_mode:
            self.register_forward_hook(store_decoder_representations)

    def show(self):
        print("Number of Saved Numpy Arrays: ", len(self.representations))
        for i, representation in enumerate(self.representations):
            print(f"Shape of {i}th Numpy Array: ", representation.shape)
        return self.representations

    def flush(self):
        del self.representations
        self.representations = []

    def release_qk(self):
        return None

    def forward(self, **kwargs):
        return kwargs['X'], kwargs['X'], kwargs['residue_features'], None
class DecoderPMA_Residue(Decoder):
    def __init__(self, h: int, num_heads: int, num_seeds: int, attn_option: str, analysis_mode: bool):
        super(DecoderPMA_Residue, self).__init__(analysis_mode)
        # Aggregate the residues into residue regions
        pma_args = (h, num_seeds, num_heads, RFF(h), attn_option, False, analysis_mode, False)
        self.decoder = PoolingMultiheadAttention(*pma_args)
        # Model region-region interaction through set attention
        sab_depth = 0 if num_seeds < 4 else int((num_seeds // 2) ** 0.5)
        sab_args = (h, num_heads, RFF(h), attn_option, False, analysis_mode, True)
        self.pairwise = nn.ModuleList([SetAttentionBlock(*sab_args) for _ in range(sab_depth)])
        # Concatenate, then reduce into an h-dimensional set representation
        self.aggregate = nn.Linear(h * num_seeds, h)
        self.apply(initialization)

    def forward(self, **kwargs):
        residue_features = kwargs['residue_features']
        residue_masks = kwargs['residue_masks']
        output, attention = self.decoder(residue_features, residue_masks)
        for sab in self.pairwise:
            output, _ = sab(output)
        b, n, d = output.shape
        output = self.aggregate(output.view(b, n * d))
        return output, None, residue_features, attention
class AffinityMLP(nn.Module):
    def __init__(self, h: int):
        super(AffinityMLP, self).__init__()
        self.mlp = nn.Sequential(nn.Linear(h, h), nn.Dropout(0.1), nn.LeakyReLU(), nn.Linear(h, 1))
        self.apply(initialization)

    def forward(self, **kwargs):
        '''
        binding_complex (X): batch_size x 1 x h
        '''
        X = kwargs['binding_complex']
        X = X.squeeze(1) if X.dim() == 3 else X
        yhat = self.mlp(X)
        return yhat


class InteractionMLP(nn.Module):
    def __init__(self, h: int):
        super(InteractionMLP, self).__init__()
        self.mlp = nn.Sequential(nn.Linear(h, h), nn.Dropout(0.1), nn.LeakyReLU(), nn.Linear(h, 1), nn.Sigmoid())
        self.apply(initialization)

    def forward(self, **kwargs):
        '''
        binding_complex (X): batch_size x 1 x h
        '''
        X = kwargs['binding_complex']
        X = X.squeeze(1) if X.dim() == 3 else X
        yhat = self.mlp(X)
        return yhat
class LigelemEncoder(nn.Module):
    def __init__(self):
        super(LigelemEncoder, self).__init__()
        self.representations = []

    def show(self):
        print("Number of Saved Numpy Arrays: ", len(self.representations))
        for i, representation in enumerate(self.representations):
            print(f"Shape of {i}th Numpy Array: ", representation.shape)
        return self.representations

    def flush(self):
        del self.representations
        self.representations = []
class EcfpConverter(LigelemEncoder):
    def __init__(self, h: int, sab_depth: int, ecfp_dim: int, analysis_mode: bool):
        super(EcfpConverter, self).__init__()
        K = 4  # number of attention heads
        self.ecfp_embeddings = nn.Embedding(ecfp_dim + 1, h, padding_idx=ecfp_dim)
        sab_args = (h, K, RFF(h), 'general_dot', False, analysis_mode, True)
        self.encoder = nn.ModuleList([SetAttentionBlock(*sab_args) for _ in range(sab_depth)])
        self.representations = []
        if analysis_mode:
            self.register_forward_hook(store_elemwise_representations)
        self.apply(initialization)

    def forward(self, **kwargs):
        '''
        ecfp_words: batch_size x ecfp_substructures (integer indices)
        ecfp_masks: batch_size x ecfp_substructures
        '''
        ecfp_words = kwargs['ecfp_words']
        ecfp_masks = kwargs['ecfp_masks']
        ecfp_words = self.ecfp_embeddings(ecfp_words)
        for sab in self.encoder:
            ecfp_words, _ = sab(ecfp_words, ecfp_masks)
        return [ecfp_words, ecfp_masks]
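# Standalone sketch (illustrative, not the original API): how padded ECFP substructure
# indices map to embeddings with a dedicated padding index, mirroring the embedding step
# in EcfpConverter above. ECFP_DIM and the index values are hypothetical.
def _demo_ecfp_embedding(ECFP_DIM: int = 1024, h: int = 16):
    emb = nn.Embedding(ECFP_DIM + 1, h, padding_idx=ECFP_DIM)
    ecfp_words = torch.tensor([[3, 17, 512, ECFP_DIM],       # last slot is padding
                               [8, 99, ECFP_DIM, ECFP_DIM]])
    ecfp_masks = (ecfp_words != ECFP_DIM).float()
    vectors = emb(ecfp_words)                                 # batch x words x h
    return vectors.shape, ecfp_masks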
class ResidueAddOn(nn.Module):
    def __init__(self):
        super(ResidueAddOn, self).__init__()
        self.representations = []

    def show(self):
        print("Number of Saved Numpy Arrays: ", len(self.representations))
        for i, representation in enumerate(self.representations):
            print(f"Shape of {i}th Numpy Array: ", representation.shape)
        return self.representations

    def flush(self):
        del self.representations
        self.representations = []

    def forward(self, **kwargs):
        X, Xm = kwargs['X'], kwargs['Xm']
        return X, Xm
class ARKMAB(ResidueAddOn):
    def __init__(self, h: int, num_heads: int, attn_option: str, analysis_mode: bool):
        super(ARKMAB, self).__init__()
        pmx_args = (h, num_heads, RFF(h), attn_option, False, analysis_mode, False)
        self.pmx = PoolingMultiheadCrossAttention(*pmx_args)
        # Learned "inactive-site" pseudo-substructure appended to every ligand
        # substructure set; its mask slot is always marked valid.
        self.inactive = nn.Parameter(torch.randn(1, 1, h))
        self.fillmask = nn.Parameter(torch.ones(1, 1), requires_grad=False)
        self.representations = []
        if analysis_mode:
            pass
        self.apply(initialization)

    def forward(self, **kwargs):
        '''
        X : batch_size x residues x h   (residue features)
        Xm: batch_size x residues       (residue masks)
        Y : batch_size x ecfpsubs x h   (ligand substructure features)
        Ym: batch_size x ecfpsubs       (ligand substructure masks)
        '''
        X, Xm = kwargs['residue_features'], kwargs['residue_masks']
        Y, Ym = kwargs['ligelem_features'], kwargs['ligelem_masks']
        pseudo_substructure = self.inactive.repeat(X.size(0), 1, 1)
        pseudo_masks = self.fillmask.repeat(X.size(0), 1)
        Y = torch.cat([Y, pseudo_substructure], 1)
        Ym = torch.cat([Ym, pseudo_masks], 1)
        X, attention = self.pmx(Y=Y, Ym=Ym, X=X, Xm=Xm)
        return X, Xm, attention
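# Standalone sketch (illustrative) of the ARK-MAB trick above: a learned "inactive-site"
# pseudo-substructure is appended to the ligand substructure set, and the mask is extended
# so every residue can attend to it. Shapes are hypothetical; the plain randn tensor stands
# in for the nn.Parameter used by the module.
def _demo_append_inactive_token(b: int = 2, n_sub: int = 5, h: int = 16):
    Y = torch.randn(b, n_sub, h)                        # ligand substructure features
    Ym = torch.ones(b, n_sub)                           # all substructures valid
    inactive = torch.randn(1, 1, h)                     # placeholder for the learned token
    Y = torch.cat([Y, inactive.repeat(b, 1, 1)], 1)     # b x (n_sub + 1) x h
    Ym = torch.cat([Ym, torch.ones(b, 1)], 1)           # mask keeps the new slot valid
    return Y.shape, Ym.shape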
class ResidueEncoder(nn.Module):
    def __init__(self):
        super(ResidueEncoder, self).__init__()
        self.representations = []

    def show(self):
        print("Number of Saved Numpy Arrays: ", len(self.representations))
        for i, representation in enumerate(self.representations):
            print(f"Shape of {i}th Numpy Array: ", representation.shape)
        return self.representations

    def flush(self):
        del self.representations
        self.representations = []
class AminoAcidSeqCNN(ResidueEncoder):
    def __init__(self, h: int, d: float, cnn_depth: int, kernel_size: int, analysis_mode: bool):
        super(AminoAcidSeqCNN, self).__init__()
        self.encoder = nn.ModuleList([nn.Sequential(nn.Linear(21, h),  # Warning
                                                    nn.Dropout(d),
                                                    nn.LeakyReLU(),
                                                    nn.Linear(h, h))])
        for _ in range(cnn_depth):
            self.encoder.append(nn.Conv1d(h, h, kernel_size, 1, (kernel_size - 1) // 2))
        self.representations = []
        if analysis_mode:
            self.register_forward_hook(store_representations)
        self.apply(initialization)

    def forward(self, **kwargs):
        X = kwargs['aaseqs']
        for i, module in enumerate(self.encoder):
            if i == 1:
                X = X.transpose(1, 2)  # (b, seq, h) -> (b, h, seq) before the Conv1d stack
            X = module(X)
        X = X.transpose(1, 2)  # back to (b, seq, h)
        return X
class FastaESM(ResidueEncoder):
    def __init__(self, h: int, esm_model: str, esm_freeze: bool, analysis_mode: bool):
        super(FastaESM, self).__init__()
        self.esm_version = 2 if 'esm2' in esm_model else 1
        if esm_model == 'esm1b_t33_650M_UR505':
            self.esm, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
            self.layer_idx, self.emb_dim = 33, 1280  # ESM-1b representations are 1280-dimensional
        elif esm_model == 'esm1_t12_85M_UR505':
            self.esm, alphabet = esm.pretrained.esm1_t12_85M_UR50S()
            self.layer_idx, self.emb_dim = 12, 768
        elif esm_model == 'esm2_t6_8M_UR50D':
            self.esm, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
            self.layer_idx, self.emb_dim = 6, 320
        elif esm_model == 'esm2_t12_35M_UR50D':
            self.esm, alphabet = esm.pretrained.esm2_t12_35M_UR50D()
            self.layer_idx, self.emb_dim = 12, 480
        elif esm_model == 'esm2_t30_150M_UR50D':
            self.esm, alphabet = esm.pretrained.esm2_t30_150M_UR50D()
            self.layer_idx, self.emb_dim = 30, 640
        else:
            raise ValueError(f"Unsupported ESM model: {esm_model}")
        self.batch_converter = alphabet.get_batch_converter()
        if esm_freeze == 'True':
            for p in self.esm.parameters():
                p.requires_grad = False
        assert h == self.emb_dim, f"The hidden dimension should be set to {self.emb_dim}, not {h}"
        self.representations = []
        if analysis_mode:
            self.register_forward_hook(store_elemwise_representations)

    def esm1_pooling(self, embeddings, masks):
        # Average over everything after the leading BOS token.
        return embeddings[:, 1:, :].sum(1) / masks[:, 1:].sum(1).view(-1, 1)

    def esm2_pooling(self, embeddings, masks):
        # Average over the positions between the leading BOS and trailing EOS tokens.
        return embeddings[:, 1:-1, :].sum(1) / masks[:, 1:-1].sum(1).view(-1, 1)

    def forward(self, **kwargs):
        fastas = kwargs['fastas']
        _, _, tokenized = self.batch_converter(fastas)
        tokenized = tokenized.cuda()
        if self.esm_version == 2:
            masks = torch.where(tokenized > 1, 1, 0).float()
        else:
            masks = torch.where((tokenized > 1) & (tokenized != 32), 1, 0).float()
        embeddings = self.esm(tokenized, repr_layers=[self.layer_idx], return_contacts=True)
        logits = embeddings["logits"].sum()
        contacts = embeddings["contacts"].sum()
        attentions = embeddings["attentions"].sum()
        embeddings = embeddings["representations"][self.layer_idx]
        assert masks.size(0) == embeddings.size(0), \
            f"Batch sizes of masks {masks.size(0)} and {embeddings.size(0)} do not match."
        assert masks.size(1) == embeddings.size(1), \
            f"Lengths of masks {masks.size(1)} and {embeddings.size(1)} do not match."
        if self.esm_version == 2:
            return [embeddings[:, 1:-1, :], masks[:, 1:-1], logits + contacts + attentions,
                    self.esm2_pooling(embeddings, masks)]
        else:
            return [embeddings[:, 1:, :], masks[:, 1:], logits + contacts + attentions,
                    self.esm1_pooling(embeddings, masks)]
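# Illustrative sketch of the inputs FastaESM expects: `fastas` is a list of
# (identifier, sequence) pairs, the format consumed by the ESM batch converter.
# The token ids below are made up purely to show the padding-mask rule used in
# forward() for ESM-2 (ids greater than 1 count as valid positions).
def _demo_fasta_inputs():
    fastas = [("prot_A", "MKTAYIAKQR"), ("prot_B", "MKV")]   # hypothetical identifiers/sequences
    tokenized = torch.tensor([[0, 5, 6, 7, 8, 2],
                              [0, 5, 6, 2, 1, 1]])           # made-up token ids; 1 plays the padding role
    masks = torch.where(tokenized > 1, 1, 0).float()
    return fastas, masks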
class DotProduct(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, queries, keys):
        return torch.bmm(queries, keys.transpose(1, 2))


class ScaledDotProduct(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, queries, keys):
        return torch.bmm(queries, keys.transpose(1, 2)) / (queries.size(2) ** 0.5)


class GeneralDotProduct(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.W = nn.Parameter(torch.randn(hidden_dim, hidden_dim))
        torch.nn.init.orthogonal_(self.W)

    def forward(self, queries, keys):
        return torch.bmm(queries @ self.W, keys.transpose(1, 2))
class ConcatDotProduct(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        raise NotImplementedError

    def forward(self, queries, keys):
        return
class Additive(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.U = nn.Parameter(torch.randn(hidden_dim, hidden_dim))
        self.T = nn.Parameter(torch.randn(hidden_dim, hidden_dim))
        self.b = nn.Parameter(torch.rand(hidden_dim).uniform_(-0.1, 0.1))
        self.W = nn.Sequential(nn.Tanh(), nn.Linear(hidden_dim, 1))
        torch.nn.init.orthogonal_(self.U)
        torch.nn.init.orthogonal_(self.T)

    def forward(self, queries, keys):
        return self.W(queries.unsqueeze(1) @ self.U + keys.unsqueeze(2) @ self.T + self.b).squeeze(-1).transpose(1, 2)
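# Illustrative sketch: each similarity module above maps queries of shape [b, n, d] and
# keys of shape [b, m, d] to a score matrix of shape [b, n, m]. Sizes are hypothetical.
def _demo_similarity_scores(b: int = 2, n: int = 3, m: int = 5, d: int = 8):
    q, k = torch.randn(b, n, d), torch.randn(b, m, d)
    dot = DotProduct()(q, k)
    scaled = ScaledDotProduct()(q, k)
    general = GeneralDotProduct(d)(q, k)
    additive = Additive(d)(q, k)
    return dot.shape, scaled.shape, general.shape, additive.shape  # all torch.Size([2, 3, 5])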
class Attention(nn.Module):
    def __init__(self, similarity, hidden_dim=1024, store_qk=False):
        super().__init__()
        self.softmax = nn.Softmax(dim=2)
        self.attention_maps = []
        self.store_qk = store_qk
        self.query_vectors, self.key_vectors = None, None
        assert similarity in ['dot', 'scaled_dot', 'general_dot', 'concat_dot', 'additive']
        if similarity == 'dot':
            self.similarity = DotProduct()
        elif similarity == 'scaled_dot':
            self.similarity = ScaledDotProduct()
        elif similarity == 'general_dot':
            self.similarity = GeneralDotProduct(hidden_dim)
        elif similarity == 'concat_dot':
            self.similarity = ConcatDotProduct(hidden_dim)
        elif similarity == 'additive':
            self.similarity = Additive(hidden_dim)
        else:
            raise NotImplementedError(f"Unsupported similarity option: {similarity}")

    def release_qk(self):
        Q, K = self.query_vectors, self.key_vectors
        self.query_vectors, self.key_vectors = None, None
        torch.cuda.empty_cache()
        return Q, K

    def forward(self, queries, keys, qmasks=None, kmasks=None):
        if self.store_qk:
            self.query_vectors = queries
            self.key_vectors = keys
        # If only one of the two masks is given, fill the other with ones.
        if torch.is_tensor(qmasks) and not torch.is_tensor(kmasks):
            dim0, dim1 = qmasks.size(0), keys.size(1)
            kmasks = torch.ones(dim0, dim1).cuda()
        elif not torch.is_tensor(qmasks) and torch.is_tensor(kmasks):
            dim0, dim1 = kmasks.size(0), queries.size(1)
            qmasks = torch.ones(dim0, dim1).cuda()
        else:
            pass

        attention = self.similarity(queries, keys)
        if torch.is_tensor(qmasks) and torch.is_tensor(kmasks):
            # Masked normalization: exponentiate clipped scores, zero out masked
            # query-key pairs, then renormalize over the key dimension.
            qmasks = qmasks.repeat(queries.size(0) // qmasks.size(0), 1).unsqueeze(2)
            kmasks = kmasks.repeat(keys.size(0) // kmasks.size(0), 1).unsqueeze(2)
            attnmasks = torch.bmm(qmasks, kmasks.transpose(1, 2))
            attention = torch.clip(attention, min=-10, max=10)
            attention = attention.exp()
            attention = attention * attnmasks
            attention = attention / (attention.sum(2).unsqueeze(2) + 1e-5)
        else:
            attention = self.softmax(attention)
        return attention

    def save_attention_maps(self, input, output):
        self.attention_maps.append(output.data.detach().cpu().numpy())
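# Illustrative sketch: with both masks supplied, Attention normalises scores so that
# padded key positions receive (near-)zero weight and each query row sums to roughly 1.
# Shapes and mask patterns are hypothetical; this runs on CPU because the internal
# ones-mask (created on CUDA) is only built when exactly one mask is None.
def _demo_masked_attention():
    attn = Attention('scaled_dot', hidden_dim=8)
    q, k = torch.randn(2, 3, 8), torch.randn(2, 4, 8)
    qm = torch.ones(2, 3)
    km = torch.tensor([[1, 1, 0, 0],
                       [1, 1, 1, 1]]).float()
    w = attn(q, k, qm, km)            # shape [2, 3, 4]
    return w.sum(2)                   # each entry close to 1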
class MultiheadAttention(nn.Module):
    def __init__(self, d, h, sim='dot', analysis=False, store_qk=False):
        super().__init__()
        assert d % h == 0, f"{d} dimension, {h} heads"
        self.h = h
        p = d // h
        self.project_queries = nn.Linear(d, d)
        self.project_keys = nn.Linear(d, d)
        self.project_values = nn.Linear(d, d)
        self.concatenation = nn.Linear(d, d)
        self.attention = Attention(sim, p, store_qk)
        if analysis:
            self.attention.register_forward_hook(save_attention_maps)

    def release_qk(self):
        Q, K = self.attention.release_qk()
        Qb = Q.size(0) // self.h
        Qn, Qd = Q.size(1), Q.size(2)
        Kb = K.size(0) // self.h
        Kn, Kd = K.size(1), K.size(2)
        Q = Q.view(self.h, Qb, Qn, Qd)
        K = K.view(self.h, Kb, Kn, Kd)
        Q = Q.permute(1, 2, 0, 3).contiguous().view(Qb, Qn, Qd * self.h)
        K = K.permute(1, 2, 0, 3).contiguous().view(Kb, Kn, Kd * self.h)
        return Q, K

    def forward(self, queries, keys, values, qmasks=None, kmasks=None):
        h = self.h
        b, n, d = queries.size()
        _, m, _ = keys.size()
        p = d // h
        queries = self.project_queries(queries)  # shape [b, n, d]
        keys = self.project_keys(keys)           # shape [b, m, d]
        values = self.project_values(values)     # shape [b, m, d]
        queries = queries.view(b, n, h, p)
        keys = keys.view(b, m, h, p)
        values = values.view(b, m, h, p)
        queries = queries.permute(2, 0, 1, 3).contiguous().view(h * b, n, p)
        keys = keys.permute(2, 0, 1, 3).contiguous().view(h * b, m, p)
        values = values.permute(2, 0, 1, 3).contiguous().view(h * b, m, p)
        attn_w = self.attention(queries, keys, qmasks, kmasks)  # shape [h * b, n, m]
        output = torch.bmm(attn_w, values)                      # shape [h * b, n, p]
        output = output.view(h, b, n, p)
        output = output.permute(1, 2, 0, 3).contiguous().view(b, n, d)
        output = self.concatenation(output)                     # shape [b, n, d]
        return output, attn_w
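# Illustrative sketch: head splitting and merging in MultiheadAttention. With d=8 and
# h=2 heads, attention weights come back per head ([h*b, n, m]) while the merged output
# is [b, n, d]. All sizes are hypothetical; no masks are passed, so the plain softmax
# branch is used and the example stays CPU-only.
def _demo_multihead_attention():
    mha = MultiheadAttention(d=8, h=2, sim='scaled_dot')
    q = torch.randn(2, 3, 8)
    kv = torch.randn(2, 5, 8)
    out, attn_w = mha(q, kv, kv)
    return out.shape, attn_w.shape    # torch.Size([2, 3, 8]), torch.Size([4, 3, 5])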
class MultiheadAttentionExpanded(nn.Module):
    def __init__(self, d, h, sim='dot', analysis=False):
        super().__init__()
        self.project_queries = nn.ModuleList([nn.Linear(d, d) for _ in range(h)])
        self.project_keys = nn.ModuleList([nn.Linear(d, d) for _ in range(h)])
        self.project_values = nn.ModuleList([nn.Linear(d, d) for _ in range(h)])
        self.concatenation = nn.Linear(h * d, d)
        self.attention = Attention(sim, d)
        if analysis:
            self.attention.register_forward_hook(save_attention_maps)

    def forward(self, queries, keys, values, qmasks=None, kmasks=None):
        output = []
        for Wq, Wk, Wv in zip(self.project_queries, self.project_keys, self.project_values):
            Pq, Pk, Pv = Wq(queries), Wk(keys), Wv(values)
            output.append(torch.bmm(self.attention(Pq, Pk, qmasks, kmasks), Pv))
        # Concatenate per-head outputs along the feature dimension to match nn.Linear(h * d, d).
        output = self.concatenation(torch.cat(output, 2))
        return output
class EmptyModule(nn.Module):
    def __init__(self, args):
        super().__init__()

    def forward(self, x):
        return 0.


class RFF(nn.Module):
    def __init__(self, h):
        super().__init__()
        self.rff = nn.Sequential(nn.Linear(h, h), nn.ReLU(), nn.Linear(h, h), nn.ReLU(), nn.Linear(h, h), nn.ReLU())

    def forward(self, x):
        return self.rff(x)
class MultiheadAttentionBlock(nn.Module):
    def __init__(self, d, h, rff, similarity='dot', full_head=False, analysis=False, store_qk=False):
        super().__init__()
        self.multihead = (MultiheadAttention(d, h, similarity, analysis, store_qk) if not full_head
                          else MultiheadAttentionExpanded(d, h, similarity, analysis))
        self.layer_norm1 = nn.LayerNorm(d)
        self.layer_norm2 = nn.LayerNorm(d)
        self.rff = rff

    def release_qk(self):
        Q, K = self.multihead.release_qk()
        return Q, K

    def forward(self, x, y, xm=None, ym=None, layer_norm=True):
        h, a = self.multihead(x, y, y, xm, ym)
        if layer_norm:
            h = self.layer_norm1(x + h)
            return self.layer_norm2(h + self.rff(h)), a
        else:
            h = x + h
            return h + self.rff(h), a
class SetAttentionBlock(nn.Module):
    def __init__(self, d, h, rff, similarity='dot', full_head=False, analysis=False, store_qk=False):
        super().__init__()
        self.mab = MultiheadAttentionBlock(d, h, rff, similarity, full_head, analysis, store_qk)

    def release_qk(self):
        Q, K = self.mab.release_qk()
        return Q, K

    def forward(self, x, m=None, ln=True):
        return self.mab(x, x, m, m, ln)
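# Illustrative sketch: a SetAttentionBlock is self-attention over a set, so the output
# keeps the input shape. Sizes are hypothetical; RFF(h) is the feed-forward block defined
# above, and no mask is passed so the example stays CPU-only.
def _demo_set_attention_block():
    h = 8
    sab = SetAttentionBlock(h, 2, RFF(h), similarity='scaled_dot')
    x = torch.randn(2, 6, h)
    out, attn = sab(x)
    return out.shape, attn.shape      # torch.Size([2, 6, 8]), torch.Size([4, 6, 6])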
class InducedSetAttentionBlock(nn.Module):
    def __init__(self, d, m, h, rff1, rff2, similarity='dot', full_head=False, analysis=False, store_qk=False):
        super().__init__()
        self.mab1 = MultiheadAttentionBlock(d, h, rff1, similarity, full_head, analysis, store_qk)
        self.mab2 = MultiheadAttentionBlock(d, h, rff2, similarity, full_head, analysis, store_qk)
        self.inducing_points = nn.Parameter(torch.randn(1, m, d))

    def release_qk(self):
        raise NotImplementedError

    def forward(self, x, m=None, ln=True):
        b = x.size(0)
        p = self.inducing_points
        p = p.repeat([b, 1, 1])              # shape [b, m, d]
        h, _ = self.mab1(p, x, None, m, ln)  # shape [b, m, d]
        return self.mab2(x, h, m, None, ln)
class PoolingMultiheadAttention(nn.Module):
    def __init__(self, d, k, h, rff, similarity='dot', full_head=False, analysis=False, store_qk=False):
        super().__init__()
        self.mab = MultiheadAttentionBlock(d, h, rff, similarity, full_head, analysis, store_qk)
        self.seed_vectors = nn.Parameter(torch.randn(1, k, d))
        torch.nn.init.xavier_uniform_(self.seed_vectors)

    def release_qk(self):
        Q, K = self.mab.release_qk()
        return Q, K

    def forward(self, z, m=None, ln=True):
        b = z.size(0)
        s = self.seed_vectors
        s = s.repeat([b, 1, 1])  # learned seed vectors: shape [b, k, d]
        return self.mab(s, z, None, m, ln)
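# Illustrative sketch: PMA pools a variable-sized set ([b, n, d]) into a fixed number of
# seed vectors ([b, k, d]). Sizes are hypothetical; no mask is passed here so the example
# stays CPU-only.
def _demo_pooling_multihead_attention():
    d, k = 8, 4
    pma = PoolingMultiheadAttention(d, k, 2, RFF(d), similarity='scaled_dot')
    z = torch.randn(3, 10, d)
    pooled, attn = pma(z)
    return pooled.shape, attn.shape   # torch.Size([3, 4, 8]), torch.Size([6, 4, 10])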
class PoolingMultiheadCrossAttention(nn.Module):
    def __init__(self, d, h, rff, similarity='dot', full_head=False, analysis=False, store_qk=False):
        super().__init__()
        self.mab = MultiheadAttentionBlock(d, h, rff, similarity, full_head, analysis, store_qk)

    def release_qk(self):
        Q, K = self.mab.release_qk()
        return Q, K

    def forward(self, X, Y, Xm=None, Ym=None, ln=True):
        return self.mab(X, Y, Xm, Ym, ln)
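# Illustrative sketch: PoolingMultiheadCrossAttention is the cross-attention used inside
# ARK-MAB, with residues as queries (X) and ligand substructures as keys/values (Y).
# Both masks are supplied, so the masked-normalisation branch runs and the example stays
# CPU-only. Sizes are hypothetical.
def _demo_pooling_multihead_cross_attention():
    d = 8
    pmx = PoolingMultiheadCrossAttention(d, 2, RFF(d), similarity='scaled_dot')
    X, Xm = torch.randn(2, 7, d), torch.ones(2, 7)   # residues and residue masks
    Y, Ym = torch.randn(2, 5, d), torch.ones(2, 5)   # ligand substructures and masks
    Ym[:, -1] = 0.                                   # mark one substructure slot as padding
    out, attn = pmx(X=X, Y=Y, Xm=Xm, Ym=Ym)
    return out.shape, attn.shape      # torch.Size([2, 7, 8]), torch.Size([4, 7, 5])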