|
|
|
import torch |
|
from torch import nn, optim |
|
from torch.nn import functional as F |
|
import torch.distributed as dist |
|
|
|
|
|
import pytorch_lightning as pl |
|
from pytorch_lightning import Trainer, seed_everything |
|
|
|
|
|
import itertools |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import sys |
|
from tqdm import tqdm |
|
import time |
|
|
|
|
|
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score |
|
|
|
|
|
import Stage1_source.helper_funcs as helper_tools |
|
import Stage1_source.preprocess as prep |
|
import Stage1_source.model as mod |
|
|
|
|
|
|
|
|
|
|
|
|
|
class PL_PEN_CL(pl.LightningModule): |
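    """
    PyTorch Lightning module for the PEN-CL contrastive pretraining stage: the wrapped
    `model` embeds natural-language annotations (x_t) and protein sequences (x_s) into a
    shared joint latent space, and training aligns the two modalities with a CLIP-style
    contrastive loss computed over the batch gathered from all devices (matching pairs
    lie on the diagonal of the logit matrix).
    """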
|
|
|
|
|
def __init__( |
|
self, |
|
args: any, |
|
model: nn.Module, |
|
text_tokenizer: any, |
|
sequence_tokenizer: any |
|
): |
|
|
|
super().__init__() |
|
|
|
self.script_args = args |
|
|
|
|
|
self.model = model |
|
|
|
|
|
self.text_tokenizer = text_tokenizer |
|
self.sequence_tokenizer = sequence_tokenizer |
|
|
|
|
|
self.val_text_joint_latents = [] |
|
self.val_seq_joint_latents = [] |
|
|
|
|
|
self.predict_text_joint_latents = [] |
|
self.predict_seq_joint_latents = [] |
|
|
|
def forward( |
|
self, |
|
x_t: torch.Tensor, |
|
x_s: torch.Tensor |
|
) -> ( |
|
            torch.Tensor,
            torch.Tensor
|
): |
|
|
|
outputs = self.model( |
|
x_t=x_t, |
|
x_s=x_s |
|
) |
|
|
|
return ( |
|
outputs['text_joint_latent'], |
|
outputs['seq_joint_latent'], |
|
) |
|
|
|
def training_step( |
|
self, |
|
batch: torch.Tensor, |
|
batch_idx: any, |
|
) -> dict: |
|
|
|
if isinstance(batch, list): |
|
|
|
text_batch, protein_batch = batch |
|
|
|
|
|
z_t, z_s = self( |
|
x_t=text_batch, |
|
x_s=protein_batch |
|
) |
|
dist.barrier() |
|
|
|
|
|
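            # gather the latents from every process so the contrastive loss sees the full
            # global batch; sync_grads=True keeps the gathered tensors in the autograd graph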
z_t_all = self.all_gather(z_t, sync_grads=True) |
|
dist.barrier() |
|
z_s_all = self.all_gather(z_s, sync_grads=True) |
|
|
|
|
|
z_t_all = z_t_all.view(-1, z_t.shape[-1]) |
|
z_s_all = z_s_all.view(-1, z_s.shape[-1]) |
|
|
|
|
|
loss, logits = self.model.compute_loss( |
|
protein_embeddings=z_s_all, |
|
text_embeddings=z_t_all |
|
) |
|
|
|
|
|
self.log('train_loss', loss, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True) |
|
|
|
|
|
            metric_dict = self.performance_metrics(logits=logits)
            for key, value in metric_dict.items():
                final_key = 'train_' + key
                self.log(final_key, value, prog_bar='f1' in key, on_step=True, on_epoch=True, sync_dist=True)
|
|
|
if batch_idx == 0: |
|
gpu_memory_usage = helper_tools.print_gpu_initialization() |
|
self.log(f'gpu_memory_usage', gpu_memory_usage, sync_dist=True) |
|
|
|
return {'loss': loss} |
|
|
|
|
|
def validation_step( |
|
self, |
|
batch: list, |
|
batch_idx: any |
|
) -> dict: |
|
|
|
|
|
if isinstance(batch, list): |
|
|
|
text_batch, protein_batch = batch |
|
|
|
|
|
z_t, z_s = self( |
|
x_t=text_batch, |
|
x_s=protein_batch |
|
) |
|
|
|
            dist.barrier()

            z_t_all = self.all_gather(z_t, sync_grads=True).view(-1, z_t.shape[-1])
            dist.barrier()
            z_s_all = self.all_gather(z_s, sync_grads=True).view(-1, z_s.shape[-1])
|
|
|
|
|
loss, logits = self.model.compute_loss( |
|
protein_embeddings=z_s_all, |
|
text_embeddings=z_t_all |
|
) |
|
|
|
|
|
|
|
self.log('valid_loss', loss, prog_bar=True, sync_dist=True) |
|
|
|
|
|
            metric_dict = self.performance_metrics(logits=logits.detach().cpu())
            for key, value in metric_dict.items():
                final_key = 'valid_' + key
                self.log(final_key, value, prog_bar='f1' in key, sync_dist=True)
|
|
|
|
|
self.val_text_joint_latents.append(z_t_all.detach().cpu()) |
|
self.val_seq_joint_latents.append(z_s_all.detach().cpu()) |
|
|
|
return {'valid_loss': loss} |
|
|
|
def on_validation_epoch_end(self): |
|
|
|
|
|
val_z_t_joint = torch.cat(self.val_text_joint_latents, dim=0) |
|
val_z_s_joint = torch.cat(self.val_seq_joint_latents, dim=0) |
|
|
|
|
|
text_log_sigma_k, S_text = self.compute_singular(val_z_t_joint.detach().cpu()) |
|
protein_log_sigma_k, S_protein = self.compute_singular(val_z_s_joint.detach().cpu()) |
|
|
|
|
|
self.save_png_to_tensorboard( |
|
data=text_log_sigma_k.numpy(), |
|
title='text', |
|
) |
|
self.save_png_to_tensorboard( |
|
data=protein_log_sigma_k.numpy(), |
|
title='protein' |
|
) |
|
|
|
|
|
self.val_text_joint_latents.clear() |
|
self.val_seq_joint_latents.clear() |
|
|
|
|
|
|
|
erank_text = self.compute_effective_rank(sigma_ks=S_text) |
|
erank_protein = self.compute_effective_rank(sigma_ks=S_protein) |
|
|
|
|
|
self.log('valid_erank_text', erank_text, sync_dist=True) |
|
self.log('valid_erank_protein', erank_protein, sync_dist=True) |
|
|
|
|
|
def configure_optimizers(self,): |
|
|
|
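        # per-component learning rates: protein encoder, text encoder, and the two projection heads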
params = [ |
|
{"params": self.model.protein_encoder.parameters(), "lr": self.script_args.protein_encoder_lr}, |
|
{"params": self.model.text_encoder.parameters(), "lr": self.script_args.text_encoder_lr}, |
|
{"params": itertools.chain( |
|
self.model.protein_projection.parameters(), |
|
self.model.text_projection.parameters() |
|
), |
|
"lr": self.script_args.head_lr, |
|
"weight_decay": self.script_args.weight_decay} |
|
] |
|
|
|
optimizer = torch.optim.AdamW(params, weight_decay=self.script_args.weight_decay) |
|
|
|
return { |
|
"optimizer": optimizer, |
|
} |
|
|
|
@torch.no_grad() |
|
def compute_class_metrics( |
|
self, |
|
outputs: torch.Tensor, |
|
targets: torch.Tensor, |
|
source: str |
|
) -> dict: |
|
|
|
|
|
outputs_np = outputs.numpy() |
|
targets_np = targets.numpy() |
|
|
|
|
|
accuracy = accuracy_score(targets_np, outputs_np.round()) |
|
precision = precision_score(targets_np, outputs_np.round(), average='micro') |
|
recall = recall_score(targets_np, outputs_np.round(), average='micro') |
|
f1 = f1_score(targets_np, outputs_np.round(), average='micro') |
|
|
|
return { |
|
f'{source}_accuracy': accuracy, |
|
f'{source}_precision': precision, |
|
f'{source}_recall': recall, |
|
f'{source}_f1': f1 |
|
} |
|
|
|
@torch.no_grad() |
|
    def performance_metrics(self, logits: torch.Tensor) -> dict:
|
|
|
logits = logits.cpu().float() |
|
|
|
|
|
p_text = F.softmax(logits, dim=-1) |
|
p_seq = F.softmax(logits.T, dim=-1) |
|
p_tot = (p_seq + p_text) / 2 |
|
|
|
|
|
y_pred_text = torch.argmax(p_text, dim=-1) |
|
y_pred_seq = torch.argmax(p_seq, dim=-1) |
|
y_pred = torch.argmax(p_tot, dim=-1) |
|
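        # matching text/protein pairs lie on the diagonal of the logit matrix,
        # so the ground-truth label for row i is simply index i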
y_true = torch.arange(y_pred_text.shape[0]) |
|
|
|
|
|
text_metrics = self.compute_class_metrics( |
|
outputs=y_pred_text, |
|
targets=y_true, |
|
source='text' |
|
) |
|
seq_metrics = self.compute_class_metrics( |
|
outputs=y_pred_seq, |
|
targets=y_true, |
|
source='seq' |
|
) |
|
total_metrics = self.compute_class_metrics( |
|
outputs=y_pred, |
|
targets=y_true, |
|
source='total' |
|
) |
|
|
|
|
|
combined_dict = {} |
|
combined_dict.update(text_metrics) |
|
combined_dict.update(seq_metrics) |
|
combined_dict.update(total_metrics) |
|
|
|
return combined_dict |
|
|
|
@torch.no_grad() |
|
def compute_singular(self, inputs: torch.Tensor) -> ( |
|
torch.Tensor, |
|
torch.Tensor |
|
): |
|
|
|
|
|
|
|
|
|
mean_inputs = torch.mean(inputs, dim=0) |
|
norm_inputs = inputs - mean_inputs |
|
|
|
|
|
        # accumulate the covariance matrix C = (1/N) * sum_i x_i x_i^T over the centered latents
        C = torch.zeros((norm_inputs.shape[-1], norm_inputs.shape[-1]))
        for sample_idx in range(norm_inputs.shape[0]):
            norm_vector = norm_inputs[sample_idx, :].unsqueeze(0)
            C += norm_vector.T @ norm_vector
        C *= 1 / norm_inputs.shape[0]  # normalize by the number of samples to obtain the covariance
|
|
|
_, S, _ = torch.linalg.svd(C, full_matrices=False) |
|
|
|
|
|
log_sigma_k, _ = torch.sort(torch.log(S), descending=True) |
|
return ( |
|
log_sigma_k, |
|
S |
|
) |
|
|
|
def compute_effective_rank(self, sigma_ks: torch.Tensor) -> torch.Tensor: |
|
""" |
|
references: |
|
- Roy et al. The effective rank: a measure of effective dimensionality |
|
- Garrido et al. RankMe: Assessing the Downstream Performnace of Pretrained SS Reps by their Rank. |
|
""" |
|
|
|
sigma_ks, _ = torch.sort(sigma_ks, descending=True) |
|
|
|
|
|
l1_norm_sigma = torch.norm(sigma_ks, p=1) |
|
|
|
|
|
p_k = sigma_ks / l1_norm_sigma + torch.finfo(torch.float).eps |
|
|
|
|
|
entropy = - torch.sum(p_k * torch.log(p_k)) |
|
|
|
|
|
erank = torch.exp(entropy) |
|
|
|
return erank |
|
|
|
def save_png_to_tensorboard( |
|
self, |
|
            data: np.ndarray,
|
title: str, |
|
x_axis_label: str='Singular Value Rank Index', |
|
y_axis_label: str='Log of singular values', |
|
): |
|
|
|
current_epoch = self.trainer.current_epoch |
|
|
|
|
|
fig, ax = plt.subplots(dpi=300) |
|
ax.plot(data) |
|
ax.set_xlabel(x_axis_label) |
|
ax.set_ylabel(y_axis_label) |
|
ax.set_title(title) |
|
ax.set_ylim([-25,3]) |
|
|
|
|
|
        self.logger.experiment.add_figure(f'{title}_SingularValues_{current_epoch}', fig, current_epoch)
        plt.close(fig)
|
|
|
def predict_step( |
|
self, |
|
batch: torch.Tensor, |
|
            batch_idx: int,
            dataloader_idx: int = 0
|
) -> ( |
|
torch.Tensor, |
|
torch.Tensor |
|
): |
|
|
|
|
|
if isinstance(batch, list): |
|
|
|
text_batch, protein_batch = batch |
|
outputs = self( |
|
x_t=text_batch, |
|
x_s=protein_batch, |
|
) |
|
|
|
z_t_joint, z_p_joint = outputs |
|
|
|
self.predict_text_joint_latents.append(z_t_joint.detach().cpu()) |
|
self.predict_seq_joint_latents.append(z_p_joint.detach().cpu()) |
|
|
|
return outputs |
|
|
|
def on_predict_epoch_end(self, outputs=None): |
|
|
|
self.predict_text_joint_latents = torch.cat(self.predict_text_joint_latents).cpu() |
|
self.predict_seq_joint_latents = torch.cat(self.predict_seq_joint_latents).cpu() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class mask_PL_PEN_CL(pl.LightningModule): |
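    """
    Variant of PL_PEN_CL that adds masked-language-modeling objectives: in addition to the
    contrastive alignment loss, each training/validation step runs a second forward pass on
    masked text and masked protein sequences and adds the corresponding MLM losses.
    """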
|
|
|
|
|
def __init__( |
|
self, |
|
args: any, |
|
model: nn.Module, |
|
text_tokenizer: any, |
|
sequence_tokenizer: any |
|
): |
|
|
|
super().__init__() |
|
|
|
self.script_args = args |
|
|
|
|
|
self.model = model |
|
|
|
|
|
self.text_tokenizer = text_tokenizer |
|
self.sequence_tokenizer = sequence_tokenizer |
|
|
|
|
|
self.val_text_joint_latents = [] |
|
self.val_seq_joint_latents = [] |
|
|
|
|
|
self.predict_text_joint_latents = [] |
|
self.predict_seq_joint_latents = [] |
|
|
|
def forward( |
|
self, |
|
x_t: torch.Tensor, |
|
x_s: torch.Tensor, |
|
compute_masked_logits: bool=False |
|
) -> ( |
|
            torch.Tensor,
            torch.Tensor
|
): |
|
|
|
outputs = self.model( |
|
x_t=x_t, |
|
x_s=x_s, |
|
compute_masked_logits=compute_masked_logits |
|
) |
|
|
|
if compute_masked_logits: |
|
|
|
return ( |
|
outputs['text_masked_logits'], |
|
outputs['protein_masked_logits'] |
|
) |
|
else: |
|
|
|
return ( |
|
outputs['text_joint_latent'], |
|
outputs['seq_joint_latent'], |
|
) |
|
|
|
def training_step( |
|
self, |
|
batch: torch.Tensor, |
|
batch_idx: any, |
|
) -> dict: |
|
|
|
if isinstance(batch, list): |
|
|
|
text_batch, protein_batch, text_mask_batch, protein_mask_batch = batch |
|
|
|
|
|
z_t, z_s = self( |
|
x_t=text_batch, |
|
x_s=protein_batch, |
|
compute_masked_logits=False |
|
) |
|
dist.barrier() |
|
|
|
|
|
z_t_all = self.all_gather(z_t, sync_grads=True) |
|
dist.barrier() |
|
z_s_all = self.all_gather(z_s, sync_grads=True) |
|
|
|
|
|
z_t_all = z_t_all.view(-1, z_t.shape[-1]) |
|
z_s_all = z_s_all.view(-1, z_s.shape[-1]) |
|
|
|
|
|
loss_align, logits = self.model.compute_loss( |
|
protein_embeddings=z_s_all, |
|
text_embeddings=z_t_all |
|
) |
|
|
|
|
|
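            # second forward pass on the masked inputs to obtain MLM logits for both modalities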
logits_t_mask, logits_s_mask = self( |
|
x_t=text_mask_batch, |
|
x_s=protein_mask_batch, |
|
compute_masked_logits=True |
|
) |
|
|
|
|
|
loss_text_mask = self.model.compute_masked_lang_loss( |
|
logits_masked=logits_t_mask, |
|
targets=text_batch, |
|
targets_masked=text_mask_batch, |
|
mask_token_id=self.text_tokenizer.mask_token_id |
|
) |
|
|
|
|
|
loss_sequence_mask = self.model.compute_masked_lang_loss( |
|
logits_masked=logits_s_mask, |
|
targets=protein_batch, |
|
targets_masked=protein_mask_batch, |
|
mask_token_id=self.sequence_tokenizer.mask_idx |
|
) |
|
|
|
|
|
|
|
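            # total objective: contrastive alignment + text MLM + protein MLM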
loss = loss_align + loss_text_mask + loss_sequence_mask |
|
|
|
|
|
self.log('train_loss', loss, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True) |
|
self.log('train_loss_align', loss_align, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True) |
|
self.log('train_loss_text_mask', loss_text_mask, prog_bar=False, on_step=True, on_epoch=True, sync_dist=True) |
|
self.log('train_loss_seq_mask', loss_sequence_mask, prog_bar=False, on_step=True, on_epoch=True, sync_dist=True) |
|
|
|
|
|
            metric_dict = self.performance_metrics(logits=logits)
            for key, value in metric_dict.items():
                final_key = 'train_' + key
                self.log(final_key, value, prog_bar='f1' in key, on_step=True, on_epoch=True, sync_dist=True)
|
|
|
if batch_idx == 0: |
|
gpu_memory_usage = helper_tools.print_gpu_initialization() |
|
self.log(f'gpu_memory_usage', gpu_memory_usage, sync_dist=True) |
|
|
|
return {'loss': loss} |
|
|
|
|
|
def validation_step( |
|
self, |
|
batch: list, |
|
batch_idx: any |
|
) -> dict: |
|
|
|
|
|
if isinstance(batch, list): |
|
|
|
text_batch, protein_batch, text_mask_batch, protein_mask_batch = batch |
|
|
|
|
|
z_t, z_s = self( |
|
x_t=text_batch, |
|
x_s=protein_batch |
|
) |
|
|
|
            dist.barrier()

            z_t_all = self.all_gather(z_t, sync_grads=True).view(-1, z_t.shape[-1])
            dist.barrier()
            z_s_all = self.all_gather(z_s, sync_grads=True).view(-1, z_s.shape[-1])
|
|
|
|
|
loss_align, logits = self.model.compute_loss( |
|
protein_embeddings=z_s_all, |
|
text_embeddings=z_t_all |
|
) |
|
|
|
|
|
logits_t_mask, logits_s_mask = self( |
|
x_t=text_mask_batch, |
|
x_s=protein_mask_batch, |
|
compute_masked_logits=True |
|
) |
|
|
|
|
|
loss_text_mask = self.model.compute_masked_lang_loss( |
|
logits_masked=logits_t_mask, |
|
targets=text_batch, |
|
targets_masked=text_mask_batch, |
|
mask_token_id=self.text_tokenizer.mask_token_id |
|
) |
|
|
|
|
|
loss_sequence_mask = self.model.compute_masked_lang_loss( |
|
logits_masked=logits_s_mask, |
|
targets=protein_batch, |
|
targets_masked=protein_mask_batch, |
|
mask_token_id=self.sequence_tokenizer.mask_idx |
|
) |
|
|
|
|
|
loss = loss_align + loss_text_mask + loss_sequence_mask |
|
|
|
|
|
self.log('valid_loss', loss, prog_bar=True, sync_dist=True) |
|
self.log('valid_loss_align', loss_align, prog_bar=True, sync_dist=True) |
|
self.log('valid_loss_text_mask', loss_text_mask, prog_bar=False, sync_dist=True) |
|
self.log('valid_loss_seq_mask', loss_sequence_mask, prog_bar=False, sync_dist=True) |
|
|
|
|
|
            metric_dict = self.performance_metrics(logits=logits.detach().cpu())
            for key, value in metric_dict.items():
                final_key = 'valid_' + key
                self.log(final_key, value, prog_bar='f1' in key, sync_dist=True)
|
|
|
|
|
self.val_text_joint_latents.append(z_t_all.detach().cpu()) |
|
self.val_seq_joint_latents.append(z_s_all.detach().cpu()) |
|
|
|
return {'valid_loss': loss} |
|
|
|
    def on_validation_epoch_end(self):

        self.val_text_joint_latents.clear()
        self.val_seq_joint_latents.clear()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def configure_optimizers(self,): |
|
|
|
params = [ |
|
{"params": self.model.protein_encoder.parameters(), "lr": self.script_args.protein_encoder_lr}, |
|
{"params": self.model.text_encoder.parameters(), "lr": self.script_args.text_encoder_lr}, |
|
{"params": itertools.chain( |
|
self.model.protein_projection.parameters(), |
|
self.model.text_projection.parameters() |
|
), |
|
"lr": self.script_args.head_lr, |
|
"weight_decay": self.script_args.weight_decay} |
|
] |
|
|
|
optimizer = torch.optim.AdamW(params, weight_decay=self.script_args.weight_decay) |
|
|
|
return { |
|
"optimizer": optimizer, |
|
} |
|
|
|
@torch.no_grad() |
|
def compute_class_metrics( |
|
self, |
|
outputs: torch.Tensor, |
|
targets: torch.Tensor, |
|
source: str |
|
) -> dict: |
|
|
|
|
|
outputs_np = outputs.numpy() |
|
targets_np = targets.numpy() |
|
|
|
|
|
accuracy = accuracy_score(targets_np, outputs_np.round()) |
|
precision = precision_score(targets_np, outputs_np.round(), average='micro') |
|
recall = recall_score(targets_np, outputs_np.round(), average='micro') |
|
f1 = f1_score(targets_np, outputs_np.round(), average='micro') |
|
|
|
return { |
|
f'{source}_accuracy': accuracy, |
|
f'{source}_precision': precision, |
|
f'{source}_recall': recall, |
|
f'{source}_f1': f1 |
|
} |
|
|
|
@torch.no_grad() |
|
    def performance_metrics(self, logits: torch.Tensor) -> dict:
|
|
|
logits = logits.cpu().float() |
|
|
|
|
|
p_text = F.softmax(logits, dim=-1) |
|
p_seq = F.softmax(logits.T, dim=-1) |
|
p_tot = (p_seq + p_text) / 2 |
|
|
|
|
|
y_pred_text = torch.argmax(p_text, dim=-1) |
|
y_pred_seq = torch.argmax(p_seq, dim=-1) |
|
y_pred = torch.argmax(p_tot, dim=-1) |
|
y_true = torch.arange(y_pred_text.shape[0]) |
|
|
|
|
|
text_metrics = self.compute_class_metrics( |
|
outputs=y_pred_text, |
|
targets=y_true, |
|
source='text' |
|
) |
|
seq_metrics = self.compute_class_metrics( |
|
outputs=y_pred_seq, |
|
targets=y_true, |
|
source='seq' |
|
) |
|
total_metrics = self.compute_class_metrics( |
|
outputs=y_pred, |
|
targets=y_true, |
|
source='total' |
|
) |
|
|
|
|
|
combined_dict = {} |
|
combined_dict.update(text_metrics) |
|
combined_dict.update(seq_metrics) |
|
combined_dict.update(total_metrics) |
|
|
|
return combined_dict |
|
|
|
@torch.no_grad() |
|
def compute_singular(self, inputs: torch.Tensor) -> ( |
|
torch.Tensor, |
|
torch.Tensor |
|
): |
|
|
|
|
|
|
|
|
|
mean_inputs = torch.mean(inputs, dim=0) |
|
norm_inputs = inputs - mean_inputs |
|
|
|
|
|
        # accumulate the covariance matrix C = (1/N) * sum_i x_i x_i^T over the centered latents
        C = torch.zeros((norm_inputs.shape[-1], norm_inputs.shape[-1]))
        for sample_idx in tqdm(range(norm_inputs.shape[0])):
            norm_vector = norm_inputs[sample_idx, :].unsqueeze(0)
            C += norm_vector.T @ norm_vector
        C *= 1 / norm_inputs.shape[0]  # normalize by the number of samples to obtain the covariance
|
|
|
_, S, _ = torch.linalg.svd(C, full_matrices=False) |
|
|
|
|
|
log_sigma_k, _ = torch.sort(torch.log(S), descending=True) |
|
return ( |
|
log_sigma_k, |
|
S |
|
) |
|
|
|
def compute_effective_rank(self, sigma_ks: torch.Tensor) -> torch.Tensor: |
|
""" |
|
references: |
|
- Roy et al. The effective rank: a measure of effective dimensionality |
|
- Garrido et al. RankMe: Assessing the Downstream Performnace of Pretrained SS Reps by their Rank. |
|
""" |
|
|
|
sigma_ks, _ = torch.sort(sigma_ks, descending=True) |
|
|
|
|
|
l1_norm_sigma = torch.norm(sigma_ks, p=1) |
|
|
|
|
|
p_k = sigma_ks / l1_norm_sigma + torch.finfo(torch.float).eps |
|
|
|
|
|
entropy = - torch.sum(p_k * torch.log(p_k)) |
|
|
|
|
|
erank = torch.exp(entropy) |
|
|
|
return erank |
|
|
|
def save_png_to_tensorboard( |
|
self, |
|
            data: np.ndarray,
|
title: str, |
|
x_axis_label: str='Singular Value Rank Index', |
|
y_axis_label: str='Log of singular values', |
|
): |
|
|
|
current_epoch = self.trainer.current_epoch |
|
|
|
|
|
fig, ax = plt.subplots(dpi=300) |
|
ax.plot(data) |
|
ax.set_xlabel(x_axis_label) |
|
ax.set_ylabel(y_axis_label) |
|
ax.set_title(title) |
|
ax.set_ylim([-25,3]) |
|
|
|
|
|
        self.logger.experiment.add_figure(f'{title}_SingularValues_{current_epoch}', fig, current_epoch)
        plt.close(fig)
|
|
|
def predict_step( |
|
self, |
|
batch: torch.Tensor, |
|
            batch_idx: int,
            dataloader_idx: int = 0
|
) -> ( |
|
torch.Tensor, |
|
torch.Tensor |
|
): |
|
|
|
|
|
if isinstance(batch, list): |
|
|
|
text_batch, protein_batch = batch |
|
outputs = self( |
|
x_t=text_batch, |
|
x_s=protein_batch, |
|
compute_masked_logits=False |
|
) |
|
|
|
z_t_joint, z_p_joint = outputs |
|
|
|
self.predict_text_joint_latents.append(z_t_joint.detach().cpu()) |
|
self.predict_seq_joint_latents.append(z_p_joint.detach().cpu()) |
|
|
|
return outputs |
|
|
|
def on_predict_epoch_end(self, outputs=None): |
|
|
|
self.predict_text_joint_latents = torch.cat(self.predict_text_joint_latents).cpu() |
|
self.predict_seq_joint_latents = torch.cat(self.predict_seq_joint_latents).cpu() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class pfam_PL_PEN_CL(pl.LightningModule): |
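    """
    Extension of the masked contrastive module to joint Swiss-Prot + Pfam batches: each step
    embeds both a Swiss-Prot pair and a Pfam pair, gathers them across devices, and combines
    an inter-modal alignment loss, an intra-modal (protein-protein) loss, and masked-language
    losses for text and protein sequences. The intra-modal term is dropped when
    `args.dataset_type == 'pfam_ablated'`.
    """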
|
|
|
|
|
def __init__( |
|
self, |
|
args: any, |
|
model: nn.Module, |
|
text_tokenizer: any, |
|
sequence_tokenizer: any |
|
): |
|
|
|
super().__init__() |
|
|
|
self.script_args = args |
|
|
|
|
|
self.model = model |
|
|
|
|
|
self.text_tokenizer = text_tokenizer |
|
self.sequence_tokenizer = sequence_tokenizer |
|
|
|
|
|
self.val_text_joint_latents = [] |
|
self.val_seq_joint_latents = [] |
|
|
|
|
|
self.predict_text_joint_latents = [] |
|
self.predict_seq_joint_latents = [] |
|
|
|
def forward( |
|
self, |
|
x_t: torch.Tensor, |
|
x_p: torch.Tensor, |
|
compute_masked_logits: bool=False |
|
) -> ( |
|
            torch.Tensor,
            torch.Tensor
|
): |
|
|
|
outputs = self.model( |
|
x_t=x_t, |
|
x_s=x_p, |
|
compute_masked_logits=compute_masked_logits |
|
) |
|
|
|
if compute_masked_logits: |
|
|
|
return ( |
|
outputs['text_masked_logits'], |
|
outputs['protein_masked_logits'] |
|
) |
|
else: |
|
|
|
return ( |
|
outputs['text_joint_latent'], |
|
outputs['seq_joint_latent'], |
|
) |
|
|
|
|
|
def on_train_batch_start(self, batch, batch_idx): |
|
self.batch_start_time = time.time() |
|
|
|
    def on_train_batch_end(self, outputs, batch, batch_idx):
        batch_end_time = time.time()
        batch_time = batch_end_time - self.batch_start_time
        # log the wall-clock time spent on this training batch
        self.log('train_batch_time', batch_time, on_step=True, on_epoch=True, sync_dist=True)
|
|
|
|
|
|
|
def training_step(self, batch: torch.Tensor, batch_idx: any) -> dict: |
|
""" |
|
Execute a single training step. |
|
|
|
Given a batch of data, this function processes both Swiss-Prot and Pfam data through the model, computes |
|
various loss values including inter-modal, intra-modal, and masked language model losses for both text |
|
and protein sequences. This function also computes and logs various metrics and GPU memory usage. |
|
|
|
Parameters: |
|
- batch: The input data batch. This can include multiple types of data. |
|
- batch_idx: Index of the current batch. |
|
|
|
Steps: |
|
1. Split the data into Swiss-Prot and Pfam batches, if the batch is a list. |
|
2. Forward pass the Swiss-Prot data through the model. |
|
3. Synchronize and gather embeddings from all GPUs. |
|
4. Forward pass the Pfam data through the model. |
|
5. Synchronize and gather Pfam embeddings from all GPUs. |
|
6. Concatenate Swiss-Prot and Pfam embeddings. |
|
7. Compute inter-modal and intra-modal loss values. |
|
8. Compute masked language model logits for the concatenated batch. |
|
9. Compute masked language loss for both text and protein sequences. |
|
10. Compute and log the total loss and individual loss components. |
|
11. Compute and log performance metrics. |
|
12. Log GPU memory usage at the start of training. |
|
|
|
Returns: |
|
- Dictionary containing the total loss value. |
|
|
|
        Note:
        This function is intended to be used within a distributed (multi-GPU) training context, as is evident
        from the barrier and gather operations. It is designed to handle batches that contain both Swiss-Prot and
        Pfam data, two widely used protein databases. The function combines inter-modal (between modalities) and
        intra-modal (within the same modality) contrastive losses with masked language modeling objectives similar
        to BERT's MLM objective.
        """
|
|
|
|
|
if isinstance(batch, list): |
|
text_batch, protein_batch, text_mask_batch, protein_mask_batch, \ |
|
pfam_text_batch, pfam_protein_batch, pfam_text_mask_batch, pfam_protein_mask_batch, \ |
|
bool_pfam_vector = batch |
|
|
|
|
|
|
|
|
|
|
|
|
|
z_t_swiss, z_p_swiss = self( |
|
x_t=text_batch, |
|
x_p=protein_batch, |
|
compute_masked_logits=False |
|
) |
|
|
|
|
|
|
|
|
|
|
|
dist.barrier() |
|
|
|
|
|
z_t_pfam, z_p_pfam = self( |
|
x_t=pfam_text_batch, |
|
x_p=pfam_protein_batch, |
|
compute_masked_logits=False |
|
) |
|
dist.barrier() |
|
|
|
|
|
z_t_swiss_all = self.all_gather(z_t_swiss, sync_grads=True) |
|
dist.barrier() |
|
z_p_swiss_all = self.all_gather(z_p_swiss, sync_grads=True) |
|
|
|
|
|
z_t_swiss_all = z_t_swiss_all.view(-1, z_t_swiss.shape[-1]) |
|
z_p_swiss_all = z_p_swiss_all.view(-1, z_p_swiss.shape[-1]) |
|
|
|
|
|
|
|
z_t_pfam_all = self.all_gather(z_t_pfam, sync_grads=True) |
|
dist.barrier() |
|
z_p_pfam_all = self.all_gather(z_p_pfam, sync_grads=True) |
|
|
|
|
|
z_t_pfam_all = z_t_pfam_all.view(-1, z_t_pfam.shape[-1]) |
|
z_p_pfam_all = z_p_pfam_all.view(-1, z_p_pfam.shape[-1]) |
|
|
|
|
|
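            # stack Swiss-Prot and Pfam latents so the first and second halves of the global
            # batch stay aligned (batch_size = z_p_all.shape[0] // 2 in the losses below)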
z_t_all = torch.cat((z_t_swiss_all, z_t_pfam_all), dim=0) |
|
z_p_all = torch.cat((z_p_swiss_all, z_p_pfam_all), dim=0) |
|
|
|
|
|
|
|
|
|
|
|
loss_align, logits = self.model.compute_inter_loss( |
|
protein_embeddings=z_p_all, |
|
text_embeddings=z_t_all, |
|
batch_size=z_p_all.shape[0] // 2 |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
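            # intra-modal contrastive term computed on the protein embeddings alone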
loss_intra, cosine_similarity = self.model.compute_intra_loss( |
|
protein_embeddings=z_p_all, |
|
batch_size=z_p_all.shape[0] // 2 |
|
) |
|
|
|
|
|
all_text_batch = torch.cat((text_batch, pfam_text_batch), dim=0) |
|
all_protein_batch = torch.cat((protein_batch, pfam_protein_batch), dim=0) |
|
all_text_mask_batch = torch.cat((text_mask_batch, pfam_text_mask_batch), dim=0) |
|
all_protein_mask_batch = torch.cat((protein_mask_batch, pfam_protein_mask_batch), dim=0) |
|
|
|
|
|
|
|
|
|
|
|
logits_t_mask, logits_s_mask = self( |
|
x_t=all_text_mask_batch, |
|
x_p=all_protein_mask_batch, |
|
compute_masked_logits=True |
|
) |
|
|
|
|
|
|
|
|
|
|
|
loss_text_mask = self.model.compute_masked_lang_loss( |
|
logits_masked=logits_t_mask, |
|
targets=all_text_batch, |
|
targets_masked=all_text_mask_batch, |
|
mask_token_id=self.text_tokenizer.mask_token_id |
|
) |
|
|
|
|
|
loss_sequence_mask = self.model.compute_masked_lang_loss( |
|
logits_masked=logits_s_mask, |
|
targets=all_protein_batch, |
|
targets_masked=all_protein_mask_batch, |
|
mask_token_id=self.sequence_tokenizer.mask_idx |
|
) |
|
|
|
|
|
if self.script_args.dataset_type == 'pfam': |
|
|
|
loss = loss_align + loss_intra + loss_text_mask + loss_sequence_mask |
|
|
|
elif self.script_args.dataset_type == 'pfam_ablated': |
|
|
|
loss = loss_align + loss_text_mask + loss_sequence_mask |
|
            else:
                # unknown dataset_type: report the offending value and abort
                sys.stderr.write(f"Unexpected dataset_type value: {self.script_args.dataset_type}\n")
                sys.exit(1)
|
|
|
|
|
self.log('train_loss', loss, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True) |
|
self.log('train_loss_align', loss_align, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True) |
|
self.log('train_loss_intra', loss_intra, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True) |
|
self.log('train_loss_text_mask', loss_text_mask, prog_bar=False, on_step=True, on_epoch=True, sync_dist=True) |
|
self.log('train_loss_seq_mask', loss_sequence_mask, prog_bar=False, on_step=True, on_epoch=True, sync_dist=True) |
|
|
|
|
|
            metric_dict = self.performance_metrics(logits=logits)
            for key, value in metric_dict.items():
                final_key = 'train_' + key
                self.log(final_key, value, prog_bar='f1' in key, on_step=True, on_epoch=True, sync_dist=True)
|
|
|
|
|
if batch_idx == 0: |
|
gpu_memory_usage = helper_tools.print_gpu_initialization() |
|
self.log(f'gpu_memory_usage', gpu_memory_usage, sync_dist=True) |
|
|
|
|
|
memory_usage = helper_tools.print_memory_usage() |
|
self.log(f'memory_usage', memory_usage, sync_dist=True) |
|
|
|
return {'loss': loss} |
|
|
|
|
|
def validation_step( |
|
self, |
|
batch: torch.Tensor, |
|
batch_idx: any, |
|
) -> dict: |
|
|
|
""" |
|
`validation_step()`: Validates a single batch of data and computes loss and performance metrics. |
|
|
|
Parameters: |
|
- `self`: Reference to the current instance of the model or module. |
|
- `batch`: Input data, which might contain text and protein sequences, their corresponding masks, and additional data from both Swiss-Prot and Pfam datasets. |
|
- `batch_idx`: Identifier for the current batch. |
|
|
|
Functionality: |
|
1. Extracts and processes data from the given batch. |
|
2. Computes embeddings for Swiss-Prot and Pfam datasets. |
|
3. Concatenates these embeddings to form a unified representation. |
|
4. Computes various loss values: inter-modal, intra-modal, and masked language losses for both biomedical texts and protein sequences. |
|
5. Logs the computed loss values and other performance metrics, highlighting metrics such as F1-score. |
|
6. Collects and appends the joint embeddings of the batch for potential future use. |
|
|
|
Returns: |
|
- A dictionary with the total validation loss for the current batch. |
|
""" |
|
|
|
if isinstance(batch, list): |
|
|
|
text_batch, protein_batch, text_mask_batch, protein_mask_batch, \ |
|
pfam_text_batch, pfam_protein_batch, pfam_text_mask_batch, pfam_protein_mask_batch, \ |
|
bool_pfam_vector = batch |
|
|
|
|
|
|
|
z_t_swiss, z_p_swiss = self( |
|
x_t=text_batch, |
|
x_p=protein_batch, |
|
compute_masked_logits=False |
|
) |
|
dist.barrier() |
|
|
|
|
|
z_t_swiss_all = self.all_gather(z_t_swiss, sync_grads=True) |
|
dist.barrier() |
|
z_p_swiss_all = self.all_gather(z_p_swiss, sync_grads=True) |
|
|
|
|
|
z_t_swiss_all = z_t_swiss_all.view(-1, z_t_swiss.shape[-1]) |
|
z_p_swiss_all = z_p_swiss_all.view(-1, z_p_swiss.shape[-1]) |
|
|
|
|
|
z_t_pfam, z_p_pfam = self( |
|
x_t=pfam_text_batch, |
|
x_p=pfam_protein_batch, |
|
compute_masked_logits=False |
|
) |
|
dist.barrier() |
|
|
|
|
|
z_t_pfam_all = self.all_gather(z_t_pfam, sync_grads=True) |
|
dist.barrier() |
|
z_p_pfam_all = self.all_gather(z_p_pfam, sync_grads=True) |
|
|
|
|
|
z_t_pfam_all = z_t_pfam_all.view(-1, z_t_pfam.shape[-1]) |
|
z_p_pfam_all = z_p_pfam_all.view(-1, z_p_pfam.shape[-1]) |
|
|
|
|
|
z_t_all = torch.cat((z_t_swiss_all, z_t_pfam_all), dim=0) |
|
z_p_all = torch.cat((z_p_swiss_all, z_p_pfam_all), dim=0) |
|
|
|
|
|
loss_align, logits = self.model.compute_inter_loss( |
|
protein_embeddings=z_p_all, |
|
text_embeddings=z_t_all, |
|
batch_size=z_p_all.shape[0] // 2 |
|
) |
|
|
|
|
|
loss_intra, cosine_similarity = self.model.compute_intra_loss( |
|
protein_embeddings=z_p_all, |
|
batch_size=z_p_all.shape[0] // 2 |
|
) |
|
|
|
|
|
all_text_batch = torch.cat((text_batch, pfam_text_batch), dim=0) |
|
all_protein_batch = torch.cat((protein_batch, pfam_protein_batch), dim=0) |
|
all_text_mask_batch = torch.cat((text_mask_batch, pfam_text_mask_batch), dim=0) |
|
all_protein_mask_batch = torch.cat((protein_mask_batch, pfam_protein_mask_batch), dim=0) |
|
|
|
|
|
logits_t_mask, logits_s_mask = self( |
|
x_t=all_text_mask_batch, |
|
x_p=all_protein_mask_batch, |
|
compute_masked_logits=True |
|
) |
|
|
|
|
|
loss_text_mask = self.model.compute_masked_lang_loss( |
|
logits_masked=logits_t_mask, |
|
targets=all_text_batch, |
|
targets_masked=all_text_mask_batch, |
|
mask_token_id=self.text_tokenizer.mask_token_id |
|
) |
|
|
|
|
|
loss_sequence_mask = self.model.compute_masked_lang_loss( |
|
logits_masked=logits_s_mask, |
|
targets=all_protein_batch, |
|
targets_masked=all_protein_mask_batch, |
|
mask_token_id=self.sequence_tokenizer.mask_idx |
|
) |
|
|
|
|
|
|
|
|
|
|
|
if self.script_args.dataset_type == 'pfam': |
|
|
|
loss = loss_align + loss_intra + loss_text_mask + loss_sequence_mask |
|
|
|
elif self.script_args.dataset_type == 'pfam_ablated': |
|
|
|
loss = loss_align + loss_text_mask + loss_sequence_mask |
|
            else:
                # unknown dataset_type: report the offending value and abort
                sys.stderr.write(f"Unexpected dataset_type value: {self.script_args.dataset_type}\n")
                sys.exit(1)
|
|
|
|
|
|
|
self.log('valid_loss', loss, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True) |
|
self.log('valid_loss_align', loss_align, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True) |
|
self.log('valid_loss_intra', loss_intra, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True) |
|
self.log('valid_loss_text_mask', loss_text_mask, prog_bar=False, on_step=True, on_epoch=True, sync_dist=True) |
|
self.log('valid_loss_seq_mask', loss_sequence_mask, prog_bar=False, on_step=True, on_epoch=True, sync_dist=True) |
|
|
|
memory_usage = helper_tools.print_memory_usage() |
|
self.log(f'memory_usage', memory_usage, sync_dist=True) |
|
|
|
|
|
            metric_dict = self.performance_metrics(logits=logits.detach().cpu())
            for key, value in metric_dict.items():
                final_key = 'valid_' + key
                self.log(final_key, value, prog_bar='f1' in key, on_step=True, on_epoch=True, sync_dist=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
return {'valid_loss': loss} |
|
|
|
|
|
|
def configure_optimizers(self,): |
|
|
|
params = [ |
|
{"params": self.model.protein_encoder.parameters(), "lr": self.script_args.protein_encoder_lr}, |
|
{"params": self.model.text_encoder.parameters(), "lr": self.script_args.text_encoder_lr}, |
|
{"params": itertools.chain( |
|
self.model.protein_projection.parameters(), |
|
self.model.text_projection.parameters() |
|
), |
|
"lr": self.script_args.head_lr, |
|
"weight_decay": self.script_args.weight_decay} |
|
] |
|
|
|
optimizer = torch.optim.AdamW(params, weight_decay=self.script_args.weight_decay) |
|
|
|
return { |
|
"optimizer": optimizer, |
|
} |
|
|
|
@torch.no_grad() |
|
def compute_class_metrics( |
|
self, |
|
outputs: torch.Tensor, |
|
targets: torch.Tensor, |
|
source: str |
|
) -> dict: |
|
|
|
|
|
outputs_np = outputs.numpy() |
|
targets_np = targets.numpy() |
|
|
|
|
|
accuracy = accuracy_score(targets_np, outputs_np.round()) |
|
precision = precision_score(targets_np, outputs_np.round(), average='micro') |
|
recall = recall_score(targets_np, outputs_np.round(), average='micro') |
|
f1 = f1_score(targets_np, outputs_np.round(), average='micro') |
|
|
|
return { |
|
f'{source}_accuracy': accuracy, |
|
f'{source}_precision': precision, |
|
f'{source}_recall': recall, |
|
f'{source}_f1': f1 |
|
} |
|
|
|
@torch.no_grad() |
|
    def performance_metrics(self, logits: torch.Tensor) -> dict:
|
|
|
logits = logits.cpu().float() |
|
|
|
|
|
p_text = F.softmax(logits, dim=-1) |
|
p_seq = F.softmax(logits.T, dim=-1) |
|
p_tot = (p_seq + p_text) / 2 |
|
|
|
|
|
y_pred_text = torch.argmax(p_text, dim=-1) |
|
y_pred_seq = torch.argmax(p_seq, dim=-1) |
|
y_pred = torch.argmax(p_tot, dim=-1) |
|
y_true = torch.arange(y_pred_text.shape[0]) |
|
|
|
|
|
text_metrics = self.compute_class_metrics( |
|
outputs=y_pred_text, |
|
targets=y_true, |
|
source='text' |
|
) |
|
seq_metrics = self.compute_class_metrics( |
|
outputs=y_pred_seq, |
|
targets=y_true, |
|
source='seq' |
|
) |
|
total_metrics = self.compute_class_metrics( |
|
outputs=y_pred, |
|
targets=y_true, |
|
source='total' |
|
) |
|
|
|
|
|
combined_dict = {} |
|
combined_dict.update(text_metrics) |
|
combined_dict.update(seq_metrics) |
|
combined_dict.update(total_metrics) |
|
|
|
return combined_dict |
|
|
|
@torch.no_grad() |
|
def compute_singular(self, inputs: torch.Tensor) -> ( |
|
torch.Tensor, |
|
torch.Tensor |
|
): |
|
|
|
|
|
|
|
|
|
mean_inputs = torch.mean(inputs, dim=0) |
|
norm_inputs = inputs - mean_inputs |
|
|
|
|
|
        # accumulate the covariance matrix C = (1/N) * sum_i x_i x_i^T over the centered latents
        C = torch.zeros((norm_inputs.shape[-1], norm_inputs.shape[-1]))
        for sample_idx in range(norm_inputs.shape[0]):
            norm_vector = norm_inputs[sample_idx, :].unsqueeze(0)
            C += norm_vector.T @ norm_vector
        C *= 1 / norm_inputs.shape[0]  # normalize by the number of samples to obtain the covariance
|
|
|
_, S, _ = torch.linalg.svd(C, full_matrices=False) |
|
|
|
|
|
log_sigma_k, _ = torch.sort(torch.log(S), descending=True) |
|
return ( |
|
log_sigma_k, |
|
S |
|
) |
|
|
|
def compute_effective_rank(self, sigma_ks: torch.Tensor) -> torch.Tensor: |
|
""" |
|
references: |
|
- Roy et al. The effective rank: a measure of effective dimensionality |
|
- Garrido et al. RankMe: Assessing the Downstream Performnace of Pretrained SS Reps by their Rank. |
|
""" |
|
|
|
sigma_ks, _ = torch.sort(sigma_ks, descending=True) |
|
|
|
|
|
l1_norm_sigma = torch.norm(sigma_ks, p=1) |
|
|
|
|
|
p_k = sigma_ks / l1_norm_sigma + torch.finfo(torch.float).eps |
|
|
|
|
|
entropy = - torch.sum(p_k * torch.log(p_k)) |
|
|
|
|
|
erank = torch.exp(entropy) |
|
|
|
return erank |
|
|
|
def save_png_to_tensorboard( |
|
self, |
|
            data: np.ndarray,
|
title: str, |
|
x_axis_label: str='Singular Value Rank Index', |
|
y_axis_label: str='Log of singular values', |
|
): |
|
|
|
current_epoch = self.trainer.current_epoch |
|
|
|
|
|
fig, ax = plt.subplots(dpi=300) |
|
ax.plot(data) |
|
ax.set_xlabel(x_axis_label) |
|
ax.set_ylabel(y_axis_label) |
|
ax.set_title(title) |
|
ax.set_ylim([-25,3]) |
|
|
|
|
|
self.logger.experiment.add_figure(f'{title}_SingularValues_{current_epoch}', fig, current_epoch) |
|
|
|
|
|
plt.close(fig) |
|
|
|
def predict_step( |
|
self, |
|
batch: torch.Tensor, |
|
            batch_idx: int,
            dataloader_idx: int = 0
|
) -> ( |
|
torch.Tensor, |
|
torch.Tensor |
|
): |
|
|
|
|
|
if isinstance(batch, list): |
|
|
|
text_batch, protein_batch = batch |
|
outputs = self( |
|
x_t=text_batch, |
|
x_p=protein_batch, |
|
compute_masked_logits=False |
|
) |
|
|
|
z_t_joint, z_p_joint = outputs |
|
|
|
self.predict_text_joint_latents.append(z_t_joint.detach().cpu()) |
|
self.predict_seq_joint_latents.append(z_p_joint.detach().cpu()) |
|
|
|
return outputs |
|
|
|
def on_predict_epoch_end(self, outputs=None): |
|
|
|
self.predict_text_joint_latents = torch.cat(self.predict_text_joint_latents).cpu() |
|
self.predict_seq_joint_latents = torch.cat(self.predict_seq_joint_latents).cpu() |
|
|
|
|
|
|
|
|
|
|
|
|
|
class PL_Facilitator(pl.LightningModule): |
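    """
    Lightning wrapper around `mod.Facilitator`, which maps text joint embeddings (z_t) into the
    protein joint embedding space; training minimizes the objective selected by `args.loss_type`.
    """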
|
|
|
def __init__( |
|
self, |
|
args: any |
|
): |
|
|
|
super().__init__() |
|
|
|
|
|
self.args = args |
|
|
|
|
|
self.model = mod.Facilitator( |
|
in_dim=self.args.emb_dim, |
|
hid_dim=self.args.hid_dim, |
|
out_dim=self.args.emb_dim, |
|
dropout=self.args.dropout |
|
) |
|
|
|
self.text_to_protein_joint_embeddings = [] |
|
|
|
def forward( |
|
self, |
|
z_t: torch.Tensor, |
|
) -> torch.Tensor: |
|
|
|
|
|
z_t_to_p = self.model(z_t) |
|
|
|
return z_t_to_p |
|
|
|
|
|
|
|
def training_step(self, batch: torch.Tensor, batch_id: any) -> dict: |
|
|
|
|
|
if isinstance(batch, list): |
|
text_embeddings, protein_embeddings = batch |
|
|
|
|
|
z_t_to_p = self(z_t=text_embeddings) |
|
|
|
|
|
loss = self.model.compute_loss( |
|
output=z_t_to_p, |
|
target=protein_embeddings, |
|
loss_option=self.args.loss_type |
|
) |
|
|
|
|
|
self.log('train_loss', loss, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True) |
|
|
|
return {'loss': loss} |
|
|
|
|
|
def validation_step(self, batch: torch.Tensor, batch_id: any) -> dict: |
|
|
|
|
|
if isinstance(batch, list): |
|
text_embeddings, protein_embeddings = batch |
|
|
|
|
|
z_t_to_p = self(z_t=text_embeddings) |
|
|
|
|
|
loss = self.model.compute_loss( |
|
output=z_t_to_p, |
|
target=protein_embeddings, |
|
loss_option=self.args.loss_type |
|
) |
|
|
|
|
|
self.log('valid_loss', loss, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True) |
|
|
|
return {'loss': loss} |
|
|
|
|
|
def configure_optimizers(self,): |
|
|
|
optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay) |
|
|
|
return { |
|
"optimizer": optimizer |
|
} |
|
|
|
|
|
def predict_step(self, batch: torch.Tensor, batch_idx: int, dataloader_idx: int = None) -> torch.Tensor: |
|
""" |
|
Defines a single prediction (inference) step. |
|
""" |
|
|
|
|
|
|
|
if isinstance(batch, list): |
|
text_embeddings, _ = batch |
|
else: |
|
text_embeddings = batch |
|
|
|
|
|
z_t_to_p = self(z_t=text_embeddings) |
|
self.text_to_protein_joint_embeddings.append(z_t_to_p.detach().cpu()) |
|
|
|
return z_t_to_p |
|
|
|
def on_predict_epoch_end(self, outputs=None): |
|
|
|
self.text_to_protein_joint_embeddings = torch.cat(self.text_to_protein_joint_embeddings).cpu() |
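

# -----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): how PL_Facilitator can be exercised
# in isolation. The SimpleNamespace fields below are assumptions standing in for
# the real script arguments; in the actual pipeline the Lightning modules above
# are driven by a pl.Trainer with the project's dataloaders.
# -----------------------------------------------------------------------------
if __name__ == '__main__':

    from types import SimpleNamespace

    # hypothetical hyper-parameters for the facilitator network
    demo_args = SimpleNamespace(emb_dim=320, hid_dim=640, dropout=0.1)

    facilitator = PL_Facilitator(args=demo_args)

    # map a random batch of text joint embeddings into the protein joint space
    z_t_demo = torch.randn(8, demo_args.emb_dim)
    z_t_to_p_demo = facilitator(z_t=z_t_demo)
    print(z_t_to_p_demo.shape)  # expected: torch.Size([8, 320]) since out_dim == emb_dim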
|
|