import os

import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel, BertConfig, BertTokenizer, BertForMaskedLM
import torch.distributed as dist
import esm
from torch.nn.utils.weight_norm import weight_norm


"""
Functions and classes adapted from the following:
1. https://keras.io/examples/vision/nl_image_search/
2. https://colab.research.google.com/drive/1hYHb0FTdKQCXZs3qCwVZnSuVGrZU2Z1w?usp=sharing
"""

class ProteinEncoder(nn.Module):
    """
    Encoder for a protein sequence to a fixed-size vector --> z_s
    """

    def __init__(self, args: any):
        super().__init__()

        self.seq_model_path = args.seq_model_path
        self.pretrained = args.pretrained_seq
        self.trainable = args.trainable_seq
        self.n_layers_to_finetune = args.pLM_n_layers_to_finetune
        self.rep_layer = args.rep_layer
        self.model, self.alphabet = self.get_ESM_model()

        # Freeze everything unless the whole model is trainable and no layer subset is given.
        full_finetune = self.trainable and self.n_layers_to_finetune == 0
        for p in self.model.parameters():
            p.requires_grad = full_finetune

        # Otherwise, unfreeze only the last n_layers_to_finetune transformer layers.
        if self.trainable and self.n_layers_to_finetune != 0:
            for layer in self.model.layers[-self.n_layers_to_finetune:]:
                for p in layer.parameters():
                    p.requires_grad = True

        # Index of the beginning-of-sequence token, used as the sequence-level representation.
        self.target_token_idx = 0

    def get_ESM_model(self):
        return esm.pretrained.load_model_and_alphabet(
            os.path.expanduser(self.seq_model_path)
        )

    def forward(self, x_s: torch.Tensor, compute_logits: bool = False):
        # Drop the singleton dimension added by the data loader: (B, 1, L) -> (B, L).
        x_s = x_s.squeeze(1)

        outputs = self.model(
            x_s,
            repr_layers=[self.rep_layer],
            return_contacts=False
        )

        if compute_logits:
            # Per-token logits over the ESM vocabulary, used for the masked-LM objective.
            return outputs['logits']

        # Hidden state of the BOS token at the chosen representation layer: (B, embed_dim).
        cls_hidden = outputs['representations'][self.rep_layer][:, self.target_token_idx, :]
        return cls_hidden
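

# Illustrative usage sketch for ProteinEncoder (not called anywhere in this module).
# Assumes `args` points at an ESM checkpoint via `seq_model_path`; the batch-converter
# call mirrors the fair-esm API, but the repo's actual tokenisation pipeline may differ.
def _example_protein_encoding(args, sequence: str = "MKTAYIAKQR") -> torch.Tensor:
    encoder = ProteinEncoder(args=args)
    batch_converter = encoder.alphabet.get_batch_converter()
    _, _, tokens = batch_converter([("protein_1", sequence)])  # (1, seq_len)
    # forward() squeezes dim 1, so mimic the singleton dim used by the training pipeline.
    return encoder(tokens.unsqueeze(1))  # (1, embed_dim) BOS representation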


class TextEncoder(nn.Module):
    """
    Encoder for a protein's natural-language text to a fixed-size vector --> z_t
    """

    def __init__(self, args: any):
        super().__init__()

        self.model_name = args.text_model_path
        self.pretrained = args.pretrained_text
        self.trainable = args.trainable_text
        self.n_layers_to_finetune = args.bLM_n_layers_to_finetune
        self.tokenizer = AutoTokenizer.from_pretrained(args.text_model_path)

        if self.pretrained:
            self.model = BertForMaskedLM.from_pretrained(self.model_name)
        else:
            # Same architecture as the checkpoint, but with randomly initialised weights.
            self.model = BertForMaskedLM(BertConfig.from_pretrained(self.model_name))

        # Freeze everything unless the whole model is trainable and no layer subset is given.
        full_finetune = self.trainable and self.n_layers_to_finetune == 0
        for p in self.model.parameters():
            p.requires_grad = full_finetune

        # Otherwise, unfreeze only the last n_layers_to_finetune encoder layers.
        if self.trainable and self.n_layers_to_finetune != 0:
            for layer in self.model.bert.encoder.layer[-self.n_layers_to_finetune:]:
                for p in layer.parameters():
                    p.requires_grad = True

        # Index of the [CLS] token, used as the sentence-level representation.
        self.target_token_idx = 0

    def forward(self, inputs: torch.Tensor, compute_logits: bool = False) -> torch.Tensor:
        # Drop the singleton dimension added by the data loader: (B, 1, L) -> (B, L).
        inputs = inputs.squeeze(1)

        if compute_logits:
            # Per-token logits over the BERT vocabulary, used for the masked-LM objective.
            outputs = self.model(inputs)
            return outputs.logits
        else:
            outputs = self.model(inputs, output_hidden_states=True)
            last_hidden_state = outputs.hidden_states[-1]
            return last_hidden_state[:, self.target_token_idx, :]
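

# Illustrative usage sketch for TextEncoder (not called anywhere in this module).
# The tokenizer settings below (truncation length, no padding) are placeholders; the
# repo's real data pipeline may tokenize differently.
def _example_text_encoding(args, caption: str = "ATP-binding protein kinase.") -> torch.Tensor:
    encoder = TextEncoder(args=args)
    token_ids = encoder.tokenizer(
        caption, return_tensors="pt", truncation=True, max_length=512
    )["input_ids"]
    return encoder(token_ids.unsqueeze(1))  # (1, hidden_size) [CLS] representation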


class ProjectionHead(nn.Module):
    """
    g(.) which maps z_t --> h_t or z_s --> h_s

    Note: h is the joint embedding representation, h_t
    is the joint embedding for the text caption, and
    h_s is the joint embedding for the protein sequence.
    """

    def __init__(self, embedding_dim: int, args: any):
        super().__init__()
        self.projection_dim = args.proj_embedding_dim
        self.embedding_dim = embedding_dim

        self.projection = nn.Linear(self.embedding_dim, self.projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(self.projection_dim, self.projection_dim)
        self.dropout = nn.Dropout(args.dropout)
        self.layer_norm = nn.LayerNorm(self.projection_dim)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        # Linear projection, then a GELU/linear/dropout block with a residual connection
        # back to the projection, followed by layer norm.
        projection = self.projection(z)
        h = self.gelu(projection)
        h = self.fc(h)
        h = self.dropout(h)
        h = h + projection
        h = self.layer_norm(h)
        return h
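

# Shape sketch for ProjectionHead (illustrative; 1280 is the ESM-2 650M embedding width,
# an assumption here -- the real value comes from args.protein_encoder_embedding).
def _example_projection(args) -> torch.Size:
    head = ProjectionHead(embedding_dim=1280, args=args)
    z = torch.randn(4, 1280)   # a dummy batch of 4 encoder outputs
    return head(z).shape       # torch.Size([4, args.proj_embedding_dim])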


class pfam_PEN_CL(nn.Module):
    """
    Protein Embeddings with Natural language using Contrastive Learning (PEN-CL),
    including Pfam contrastive learning.
    """

    def __init__(self, args: any):
        super().__init__()

        self.protein_embedding = args.protein_encoder_embedding
        self.text_embedding = args.text_encoder_embedding
        self.temperature = args.temperature

        self.protein_encoder = ProteinEncoder(args=args)
        self.text_encoder = TextEncoder(args=args)

        self.protein_projection = ProjectionHead(
            embedding_dim=self.protein_embedding,
            args=args
        )
        self.text_projection = ProjectionHead(
            embedding_dim=self.text_embedding,
            args=args
        )

    def forward(
            self,
            x_t: torch.Tensor,
            x_s: torch.Tensor,
            compute_masked_logits: bool = False
    ) -> dict:

        if compute_masked_logits:
            # Masked-language-model pass: return per-token logits for both modalities.
            protein_logits = self.protein_encoder(x_s, compute_logits=True)
            text_logits = self.text_encoder(x_t, compute_logits=True)

            return {
                'text_masked_logits': text_logits,
                'protein_masked_logits': protein_logits
            }

        else:
            # Contrastive pass: encode each modality and project into the joint space.
            z_t = self.text_encoder(x_t, compute_logits=False)
            z_s = self.protein_encoder(x_s, compute_logits=False)

            z_t_joint = self.text_projection(z_t)
            z_s_joint = self.protein_projection(z_s)

            return {
                'text_joint_latent': z_t_joint,
                'seq_joint_latent': z_s_joint,
            }
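
    # Illustrative forward contract (sketch; the singleton dim in the input shapes matches
    # the squeeze(1) in the encoders and is an assumption about the data loader):
    #
    #   out = model(x_t, x_s)                      # x_t: (B, 1, L_text), x_s: (B, 1, L_seq)
    #   out['text_joint_latent'].shape             # (B, proj_embedding_dim)
    #   out['seq_joint_latent'].shape              # (B, proj_embedding_dim)
    #
    #   mlm = model(x_t_masked, x_s_masked, compute_masked_logits=True)
    #   mlm['text_masked_logits'].shape            # (B, L_text, text_vocab_size)
    #   mlm['protein_masked_logits'].shape         # (B, L_seq, esm_vocab_size)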
    def compute_inter_loss(
            self,
            protein_embeddings: torch.Tensor,
            text_embeddings: torch.Tensor,
            batch_size: int
    ) -> (
            torch.Tensor,
            torch.Tensor
    ):
        """
        Compute the inter-modal contrastive InfoNCE loss between protein and text embeddings.

        Parameters:
        - protein_embeddings: A tensor representing the embeddings of the protein sequences.
        - text_embeddings: A tensor representing the embeddings of the text descriptions.
        - batch_size: The number of samples in the batch.

        Steps:
        1. Generate a masking matrix to identify off-diagonal elements.
        2. Compute cosine similarities (i.e., logits) between text and protein embeddings.
        3. Compute self-similarities for both protein and text embeddings.
        4. Mask off-diagonal elements between Swiss-Prot and Pfam in the similarity matrices.
        5. Define ground truth by averaging the masked protein and text similarity matrices.
        6. Compute the contrastive loss for the protein and text embeddings using the ground truth.

        Returns:
        - Mean contrastive loss for the given batch of protein and text embeddings.
        - The logits (cosine similarity matrix between text and protein embeddings).

        Note: This function assumes a specific structure in the input batches, where corresponding
        positive samples in the protein and text embeddings are arranged in a particular way,
        allowing for masking and contrastive loss calculation.
        """

        # Mask the Swiss-Prot <-> Pfam "extra positive" pairs, which sit one half-batch
        # off the main diagonal of the (2*batch_size, 2*batch_size) similarity matrices.
        mask = torch.zeros((2 * batch_size, 2 * batch_size))
        mask[batch_size:, :batch_size] = torch.eye(batch_size)
        mask[:batch_size, batch_size:] = torch.eye(batch_size)
        mask = mask.to(protein_embeddings.device).bool()

        # Scaled text-protein similarities and within-modality self-similarities.
        logits = (text_embeddings @ protein_embeddings.T) / self.temperature
        protein_similarity = protein_embeddings @ protein_embeddings.T
        text_similarity = text_embeddings @ text_embeddings.T

        mask_protein_similarity = self.set_inf(protein_similarity, mask)
        mask_text_similarity = self.set_inf(text_similarity, mask)
        mask_logits = self.set_inf(logits, mask)

        # Soft targets: the average of the two masked self-similarity matrices.
        targets = F.softmax(
            (mask_protein_similarity + mask_text_similarity) / (2 * self.temperature), dim=-1
        )

        # Symmetric cross-entropy over rows (text) and columns (protein).
        text_loss = self.cross_entropy(mask_logits, targets, reduction='none')
        protein_loss = self.cross_entropy(mask_logits.T, targets.T, reduction='none')
        loss = (protein_loss + text_loss) / 2.0

        return (
            loss.mean(),
            mask_logits.detach().cpu()
        )
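
    # Worked example of the mask above (sketch, batch_size=2, where rows/cols 0-1 are
    # presumed to be the Swiss-Prot half of the batch and rows/cols 2-3 the paired Pfam half):
    #
    #   mask = [[0, 0, 1, 0],
    #           [0, 0, 0, 1],
    #           [1, 0, 0, 0],
    #           [0, 1, 0, 0]]
    #
    # True entries are pushed to -inf by set_inf() so these extra positive pairs do not
    # compete with the diagonal positives inside the softmax.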
    def compute_intra_loss(
            self,
            protein_embeddings,
            batch_size
    ) -> (
            torch.Tensor,
            torch.Tensor,
    ):
        """
        Compute the intra-modal contrastive InfoNCE loss for protein embeddings.

        Parameters:
        - protein_embeddings: A tensor representing the embeddings of the protein sequences.
        - batch_size: Batch size used for training.

        Steps:
        1. Normalize the protein embeddings using L2 normalization.
        2. Compute the cosine similarity between the normalized embeddings.
        3. Mask the diagonal of the cosine similarity matrix to avoid using a protein's similarity with itself.
        4. Define positive examples by rolling the mask. The positive example for a given protein embedding is determined by an embedding half the batch size away.
        5. Compute the InfoNCE loss using the masked cosine similarity matrix.

        Returns:
        - Mean InfoNCE loss for the given batch of protein embeddings.
        - The cosine similarity matrix.

        Note: The underlying assumption is that in each batch, corresponding positive samples for a given protein embedding
        lie half the batch size away. The function computes the negative log likelihood loss between these positive samples
        and the entire batch.
        """

        # NOTE: despite step 1 of the docstring, no L2 normalization is applied here; the
        # embeddings are used as-is, so "cosine_similarity" is a scaled dot-product similarity.
        norm_protein_embeddings = protein_embeddings

        cosine_similarity = (norm_protein_embeddings @ norm_protein_embeddings.T) / self.temperature

        # Remove self-similarity from the candidate set.
        sample_size = protein_embeddings.shape[0]
        mask = torch.eye(sample_size, device=cosine_similarity.device, dtype=torch.bool)
        cosine_similarity = self.set_inf(cosine_similarity, mask)

        # Positives sit half a batch away (a sequence vs. its counterpart from the other source).
        pos_mask = mask.roll(shifts=mask.shape[0] // 2, dims=0)

        # InfoNCE: negative log-likelihood of the positive against all candidates.
        nll = -cosine_similarity[pos_mask] + torch.logsumexp(cosine_similarity, dim=-1)

        return (
            nll.mean(),
            cosine_similarity.cpu(),
        )
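
    # Worked example of pos_mask above (sketch, sample_size=4): rolling the identity by
    # sample_size // 2 rows puts the True entry of row i at column (i + 2) % 4, i.e. the
    # same off-diagonal pattern as the inter-modal mask sketched earlier, so each
    # Swiss-Prot embedding is paired with its Pfam counterpart (and vice versa).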
    def set_inf(
            self,
            tensor: torch.Tensor,
            mask: torch.Tensor
    ) -> torch.Tensor:
        # Fill masked positions with a large negative value so they vanish under softmax.
        if tensor.dtype == torch.float32:
            replace_value = -9e15
        elif tensor.dtype == torch.float16:
            replace_value = -1e4
        else:
            raise ValueError("Unsupported tensor dtype for this operation.")

        tensor.masked_fill_(mask, replace_value)
        return tensor
    def cross_entropy(
            self,
            preds: torch.Tensor,
            targets: torch.Tensor,
            reduction: str = 'none'
    ) -> torch.Tensor:
        # Cross entropy against soft (probability-distribution) targets.
        log_softmax = nn.LogSoftmax(dim=-1)
        loss = (-targets * log_softmax(preds)).sum(1)

        if reduction == 'none':
            return loss
        elif reduction == 'mean':
            return loss.mean()
        else:
            raise ValueError('Choose either "none" or "mean" for the reduction argument')
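
    # Tiny numeric sketch of the soft-target cross entropy above:
    #   preds = [[2.0, 0.0]], targets = [[0.75, 0.25]]
    #   log_softmax(preds) = [[-0.1269, -2.1269]]
    #   loss = -(0.75 * -0.1269 + 0.25 * -2.1269) ≈ 0.627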
    def compute_masked_lang_loss(
            self,
            logits_masked: torch.Tensor,
            targets: torch.Tensor,
            targets_masked: torch.Tensor,
            mask_token_id: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute the masked language model loss for BERT-like architectures.

        Given a batch of logits predicted for masked positions and their corresponding target tokens,
        this function computes the cross-entropy loss between the predicted logits and the true labels,
        but only for positions that have been masked in the input.

        Parameters:
        - logits_masked: Predicted token logits for masked positions from the model.
                         Shape: (batch_size, seq_len, vocab_size).
        - targets: True token IDs for each position in the input sequence.
                   Shape: (batch_size, seq_len).
        - targets_masked: Token IDs for the input sequence, including masked positions.
                          Shape: (batch_size, seq_len).
        - mask_token_id: The ID corresponding to the [MASK] token in the vocabulary.

        Steps:
        1. Compute the cross-entropy loss between predicted logits and true labels across all positions.
        2. For each sample in the batch, locate the positions that were masked.
        3. Extract the loss values corresponding to these masked positions.
        4. Compute and return the mean of these extracted loss values across the batch.

        Returns:
        - Mean cross-entropy loss for masked positions across the batch.

        Note: This function focuses exclusively on masked positions in the input, as is typical for the
        MLM objective in BERT-like models. It disregards unmasked positions.
        """

        # Per-position cross entropy; CrossEntropyLoss expects (B, vocab_size, seq_len).
        loss_func = nn.CrossEntropyLoss(reduction='none')
        loss_mask = loss_func(
            logits_masked.permute(0, 2, 1),
            targets.squeeze(1)
        )

        batch_loss = []

        for ii, target_mask_sample in enumerate(targets_masked):
            # Positions that were replaced by the [MASK] token in this sample.
            masked_positions = (target_mask_sample == mask_token_id).tolist()
            loss_mask_sample = loss_mask[ii][masked_positions]

            if loss_mask_sample.numel() > 0:
                batch_loss.append(torch.mean(loss_mask_sample).unsqueeze(0))

        # Average over samples that actually contained masked positions; fall back to zero
        # if no position in the batch was masked.
        if len(batch_loss) > 0:
            loss_mask_mean = torch.mean(torch.cat(batch_loss))
        else:
            loss_mask_mean = torch.tensor(0.0, device=logits_masked.device)

        return loss_mask_mean
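
# Illustrative training-step sketch for the masked-language objective above (not called in
# this module). `x_t_masked`/`x_s_masked` are masked token ids, `x_t`/`x_s` the original ids;
# all variable names here are placeholders for whatever the training loop provides.
#
#   out = model(x_t_masked, x_s_masked, compute_masked_logits=True)
#   text_mlm_loss = model.compute_masked_lang_loss(
#       logits_masked=out['text_masked_logits'],
#       targets=x_t,
#       targets_masked=x_t_masked,
#       mask_token_id=model.text_encoder.tokenizer.mask_token_id,
#   )
#   protein_mlm_loss = model.compute_masked_lang_loss(
#       logits_masked=out['protein_masked_logits'],
#       targets=x_s,
#       targets_masked=x_s_masked,
#       mask_token_id=model.protein_encoder.alphabet.mask_idx,
#   )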


class Facilitator(nn.Module):

    def __init__(self,
                 in_dim: int,
                 hid_dim: int,
                 out_dim: int,
                 dropout: float = 0.
                 ):
        super().__init__()

        # Two weight-normalised linear layers with a GELU non-linearity in between.
        self.main = nn.Sequential(
            weight_norm(nn.Linear(in_dim, hid_dim), dim=None),
            nn.GELU(),
            nn.Dropout(dropout, inplace=True),
            weight_norm(nn.Linear(hid_dim, out_dim), dim=None)
        )

    def forward(self, x):
        return self.main(x)

    def compute_loss(self, output: torch.Tensor, target: torch.Tensor, loss_option='MSE') -> torch.Tensor:
        if loss_option == 'MSE':
            return Facilitator.compute_MSE(output, target)
        elif loss_option == 'MMD':
            return Facilitator.compute_mmd(output, target)
        else:
            raise ValueError("Invalid loss option")
    @staticmethod
    def compute_MSE(output, target):
        # Mean squared error between the facilitated embedding and its target.
        mse_loss = nn.MSELoss()
        loss = mse_loss(output, target)
        return loss

    @staticmethod
    def compute_kernel(
            x: torch.FloatTensor,
            y: torch.FloatTensor
    ) -> torch.FloatTensor:
        """
        Compute the Gaussian RBF kernel between tensors x and y
        """
        x_size, y_size = x.shape[0], y.shape[0]
        dim = x.shape[1]

        # Broadcast x and y to all pairwise combinations: (x_size, y_size, dim).
        x = x.view(x_size, 1, dim)
        y = y.view(1, y_size, dim)
        x_core = x.expand(x_size, y_size, dim)
        y_core = y.expand(x_size, y_size, dim)

        # RBF kernel with bandwidth tied to the embedding dimensionality.
        return torch.exp(-(x_core - y_core).pow(2).mean(2) / dim)
    @staticmethod
    def compute_mmd(
            x: torch.FloatTensor,
            y: torch.FloatTensor
    ) -> torch.FloatTensor:
        """
        Compute the Maximum Mean Discrepancy (MMD) between two distributions.
        Args:
            x: Samples from the first distribution (z_t_to_p ~ q(z_p))
            y: Samples from the second distribution (z_p ~ p(z_p))
        Returns:
            MMD_loss: The MMD loss between the sampled distributions
        """
        x_kernel = Facilitator.compute_kernel(x, x)
        y_kernel = Facilitator.compute_kernel(y, y)
        xy_kernel = Facilitator.compute_kernel(x, y)

        return x_kernel.mean() + y_kernel.mean() - 2 * xy_kernel.mean()
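

# Illustrative Facilitator usage sketch (assumptions: the facilitator maps text joint
# embeddings into the protein joint-embedding space, and all dims equal
# args.proj_embedding_dim, here taken as 512; neither is fixed by this file):
#
#   facilitator = Facilitator(in_dim=512, hid_dim=512, out_dim=512, dropout=0.1)
#   out = model(x_t, x_s)                              # pfam_PEN_CL forward pass
#   z_t_to_p = facilitator(out['text_joint_latent'])
#   loss = facilitator.compute_loss(z_t_to_p, out['seq_joint_latent'], loss_option='MMD')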