import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.distributions import OneHotCategorical
from transformers.optimization import Adafactor

import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import EarlyStopping

import functools
import math
from typing import Any

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import (
    size_based_auto_wrap_policy,
    enable_wrap,
    wrap
)

import deepspeed
from deepspeed.ops.adam import DeepSpeedCPUAdam

from sklearn.model_selection import train_test_split

from Stage3_source.DSEma import moving_average, clone_zero_model
import Stage3_source.transformer_training_helper as trainer_tools
import Stage3_source.helper_funcs as helper_tools
import Stage3_source.eval_metrics as eval_funcs
import Stage3_source.preprocess as prep

import copy

from torch.utils.data import DataLoader
import pandas as pd

from transformers import get_cosine_schedule_with_warmup


class PL_ProtARDM(pl.LightningModule):
    """LightningModule that wraps a conditional protein sequence model (and,
    optionally, an EMA copy of it) and trains it with the random-order ELBO
    objective implemented in cond_elbo_objective."""

    def __init__(
        self,
        args: Any,
        model: nn.Module,
    ):

        super().__init__()

        self.script_args = args
        self.model = model

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        y_c: torch.Tensor,
        ema=False,
    ) -> torch.Tensor:

        if ema:
            # self.ema_model is expected to be attached externally (e.g. with the
            # DSEma helpers imported above) before forward is called with ema=True.
            logits = self.ema_model(x=x, t=t.view(-1,), y_c=y_c)
        else:
            logits = self.model(x=x, t=t.view(-1,), y_c=y_c)
        return logits

    def configure_optimizers(self):

        if self.script_args.choose_optim == 'AdamW':
            # both branches currently build the same AdamW; the isinstance check
            # only gates the FSDP debug print
            if isinstance(self, FSDP):
                print("Enter FSDP")
                optimizer = torch.optim.AdamW(self.parameters(), lr=self.script_args.lr, weight_decay=self.script_args.weight_decay)
            else:
                optimizer = torch.optim.AdamW(self.parameters(), lr=self.script_args.lr, weight_decay=self.script_args.weight_decay)

        elif self.script_args.choose_optim == 'AdaFactor':
            optimizer = Adafactor(self.parameters(), lr=self.script_args.lr, weight_decay=self.script_args.weight_decay, relative_step=False)

        elif self.script_args.choose_optim == 'Adam':
            optimizer = torch.optim.Adam(self.parameters(), lr=self.script_args.lr)

        elif self.script_args.choose_optim == 'DeepSpeedCPUAdam':
            optimizer = DeepSpeedCPUAdam(self.parameters(), lr=self.script_args.lr, weight_decay=self.script_args.weight_decay)

        if self.script_args.scheduler_gamma is not None:
            if isinstance(self.script_args.scheduler_gamma, str):
                if 'coswarmup' == self.script_args.scheduler_gamma.lower():
                    print('Using cosine warmup scheduler with decay')
                    num_warmup_steps = self.script_args.traindata_len
                    num_training_steps = self.script_args.traindata_len * self.script_args.epochs
                    print(f'Num_warmup_steps={num_warmup_steps}')
                    print(f'Num_training_steps={num_training_steps}')

                    def _get_cosine_schedule_with_warmup_lr_lambda(
                        current_step: int, num_warmup_steps: int, num_training_steps: int, num_cycles: float
                    ):
                        # linear warmup followed by a cosine decay to zero
                        if current_step < num_warmup_steps:
                            return float(current_step) / float(max(1, num_warmup_steps))
                        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
                        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

                    lr_lambda = functools.partial(
                        _get_cosine_schedule_with_warmup_lr_lambda,
                        num_warmup_steps=num_warmup_steps,
                        num_training_steps=num_training_steps,
                        num_cycles=0.5,
                    )
                    return {
                        "optimizer": optimizer,
                        "lr_scheduler": {
                            "scheduler": optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1),
                            "interval": "step",
                        },
                    }

            else:
                print(f'Using Exponential learning rate decay / epoch with factor: {self.script_args.scheduler_gamma}')
                return {
                    "optimizer": optimizer,
                    "lr_scheduler": {
                        "scheduler": optim.lr_scheduler.ExponentialLR(optimizer, gamma=self.script_args.scheduler_gamma, verbose=True),
                        "interval": "epoch",
                    },
                }
        else:
            return optimizer

    def common_step(
        self,
        realization: torch.Tensor,
        realization_idx: Any,
        stage: str) -> dict:

        if isinstance(realization, list):
            # dataloader yields (sequence batch, conditioning text embedding)
            y_c = realization[1]
            realization = realization[0]

        batch_size, seq_length = realization.size()
        realization = realization.reshape(batch_size, 1, seq_length).long()

        train_tuple = self.cond_elbo_objective(
            realization=realization,
            y_c=y_c,
            realization_idx=realization_idx,
            stage=stage,
            ema='ema' in stage.lower(),
        )

        if len(train_tuple) == 1:
            loss = train_tuple[0]
        else:
            loss = train_tuple[0]
            metrics = train_tuple[1]

        if realization_idx == 0:
            gpu_memory_usage = helper_tools.print_gpu_initialization()
            self.log(f"{stage}_gpu_memory_usage", gpu_memory_usage, sync_dist=True)

        sync_dist = 'val' in stage

        self.log(f"{stage}_loss", loss, prog_bar=True, on_step=True, on_epoch=True, sync_dist=sync_dist)

        if len(train_tuple) > 1:
            self.log(f"{stage}_prev_hard_acc", metrics[0], prog_bar=True, on_step=True, on_epoch=True, sync_dist=sync_dist)
            self.log(f"{stage}_prev_soft_acc", metrics[1], on_step=True, on_epoch=True, sync_dist=sync_dist)
            self.log(f"{stage}_fut_hard_acc", metrics[2], prog_bar=True, on_step=True, on_epoch=True, sync_dist=sync_dist)
            self.log(f"{stage}_fut_soft_acc", metrics[3], on_step=True, on_epoch=True, sync_dist=sync_dist)
            self.log(f"{stage}_current_hard_acc", metrics[4], prog_bar=True, on_step=True, on_epoch=True, sync_dist=sync_dist)
            self.log(f"{stage}_current_soft_acc", metrics[5], on_step=True, on_epoch=True, sync_dist=sync_dist)
            self.log(f"{stage}_current_ppl", metrics[6], on_step=True, on_epoch=True, sync_dist=sync_dist)
            self.log(f"{stage}_prev_ppl", metrics[7], on_step=True, on_epoch=True, sync_dist=sync_dist)
            self.log(f"{stage}_fut_ppl", metrics[8], on_step=True, on_epoch=True, sync_dist=sync_dist)
            self.log(f"{stage}_pos_entropy", metrics[9], on_step=True, on_epoch=True, sync_dist=sync_dist)

        torch.cuda.empty_cache()
        return {'loss': loss}

    def training_step(
        self,
        realization: torch.Tensor,
        realization_idx: Any):
        return self.common_step(realization, realization_idx, stage='train')

    def validation_step(
        self,
        realization: torch.Tensor,
        realization_idx: Any):
        self.common_step(realization, realization_idx, stage='val')

    def apply_OneHotCat(self, probs: torch.Tensor) -> OneHotCategorical:
        return OneHotCategorical(probs=probs.permute(0, 2, 1))

    def cond_elbo_objective(
        self,
        realization: torch.Tensor,
        y_c: torch.Tensor,
        realization_idx: Any,
        stage: str,
        ema=False,
    ):

        bs, channel, seq_length = realization.size()

        # sample a random generation order and a random position index along it
        sampled_random_path = trainer_tools.sample_random_path(bs, seq_length, device=self.script_args.device)
        idx = trainer_tools.sample_random_index_for_sampling(bs, seq_length, device=self.script_args.device, option='random')

        # masks for the already-sampled, current, and future positions along the path
        random_path_mask = trainer_tools.create_mask_at_random_path_index(sampled_random_path, idx, bs, seq_length)
        current_path_mask = trainer_tools.create_sampling_location_mask(sampled_random_path, idx, bs, seq_length)
        future_path_mask = trainer_tools.create_mask_at_future_path_index(sampled_random_path, idx, bs, seq_length)

        # tokenize the realization and mask out the not-yet-sampled positions
        real_tokens, bs, seq_length = trainer_tools.create_token_labels(self.script_args, realization)
        real_token_masked = trainer_tools.mask_realizations(real_tokens, random_path_mask)

        # predict the per-position token distribution, conditioned on the text embedding y_c
        logits = self(x=real_token_masked, t=idx, y_c=y_c, ema=ema)
        conditional_prob = OneHotCategorical(logits=logits.permute(0, 2, 1))

        # log-likelihood of the true tokens at unsampled locations, re-weighted by the sampling index
        log_prob = trainer_tools.log_prob_of_realization(self.script_args, conditional_prob, real_tokens)
        log_prob_unsampled = trainer_tools.log_prob_of_unsampled_locations(log_prob, real_token_masked)
        log_prob_weighted = trainer_tools.weight_log_prob(log_prob_unsampled, idx, seq_length)
        loss = trainer_tools.compute_average_loss_for_batch(log_prob_weighted)

        # accuracy / perplexity diagnostics (no gradients needed)
        probs = F.softmax(logits, dim=1)
        metrics = self.performance_step(
            real_tokens=real_tokens.cpu(),
            idx=idx.cpu(),
            sampled_random_path=sampled_random_path.cpu().float(),
            probs=probs.cpu().float(),
            conditional_prob=conditional_prob)

        return loss, metrics

    @torch.no_grad()
    def performance_step(
        self,
        real_tokens: torch.Tensor,
        idx: torch.Tensor,
        sampled_random_path: torch.Tensor,
        probs: torch.Tensor,
        conditional_prob: OneHotCategorical
    ) -> tuple:

        # draw a sample from the predicted conditional and convert the one-hot sample to token indices
        sample_seq = torch.argmax(trainer_tools.sample_from_conditional(conditional_prob).cpu(), dim=1)

        # hard/soft accuracy at previously sampled, future, and current positions
        prev_B_hard_acc, prev_B_soft_acc, fut_B_hard_acc, fut_B_soft_acc, current_B_hard_acc, current_B_soft_acc = eval_funcs.compute_acc_given_time_pos(
            real_tokens=real_tokens,
            sample_seq=sample_seq,
            sample_path=sampled_random_path,
            idx=idx
        )

        # perplexity at current, previous, and future positions
        current_ppl, prev_ppl, fut_ppl = eval_funcs.compute_ppl_given_time_pos(
            probs=probs,
            sample_path=sampled_random_path,
            idx=idx
        )

        # mean per-position entropy of the predicted distribution
        pos_entropy = trainer_tools.compute_pos_entropy(probs=probs).mean().item()

        metric_evals = (
            prev_B_hard_acc,
            prev_B_soft_acc,
            fut_B_hard_acc,
            fut_B_soft_acc,
            current_B_hard_acc,
            current_B_soft_acc,
            current_ppl,
            prev_ppl,
            fut_ppl,
            pos_entropy
        )

        return metric_evals


class PFamDataModule(pl.LightningDataModule):
    """LightningDataModule that loads SwissProt and/or Pfam sequence data, pairs each
    sequence with its text embedding, and serves train/validation DataLoaders."""

    def __init__(self, args):
        super().__init__()
        self.args = args

        data = self.load_data()

        num_seq_list, text_emb_list = prep.prepare_protein_data(
            args=args,
            data_dict=data
        )

        print(f'Performing random train/val split (validation fraction: {args.valid_size})')
        num_seq_list_train, num_seq_list_val, text_emb_train, text_emb_val = train_test_split(
            num_seq_list,
            text_emb_list,
            test_size=args.valid_size,
            random_state=args.seed
        )
        print(f'Number of training samples: {len(num_seq_list_train)}')
        print(f'Number of validation samples: {len(num_seq_list_val)}')

        self.train_dataset = prep.protein_dataset(
            num_seq_list=num_seq_list_train,
            text_emb=text_emb_train
        )

        self.val_dataset = prep.protein_dataset(
            num_seq_list=num_seq_list_val,
            text_emb=text_emb_val
        )

    def load_data(self):

        try:
            print(self.args.swissprot_data_root, self.args.pfam_data_root)

            if self.args.swissprot_data_root != "None":
                swissprot_data = torch.load(self.args.swissprot_data_root)
            else:
                swissprot_data = None

            if self.args.pfam_data_root != "None":
                pfam_data = torch.load(self.args.pfam_data_root)
            else:
                pfam_data = None

            if (swissprot_data is not None) and (pfam_data is not None):
                return self.merge_and_append_values(dict1=swissprot_data, dict2=pfam_data)
            elif pfam_data is not None:
                return pfam_data
            elif swissprot_data is not None:
                return swissprot_data
            else:
                raise ValueError('Both SwissProt and Pfam datasets are unavailable.')

        except FileNotFoundError as e:
            raise FileNotFoundError(f"Data file not found: {e}")

    def merge_and_append_values(self, dict1, dict2):

        merged_dict = {}

        all_keys = set(dict1) | set(dict2)

        for key in all_keys:
            values = []
            if key in dict1:
                values.append(dict1[key])
            if key in dict2:
                values.append(dict2[key])

            # flatten: list-valued entries are extended, scalar entries are appended
            merged_dict[key] = [item for sublist in values for item in (sublist if isinstance(sublist, list) else [sublist])]

        return merged_dict

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.args.batch_size,
            num_workers=self.args.num_workers,
            shuffle=True
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.args.batch_size,
            num_workers=self.args.num_workers,
            shuffle=False
        )
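

# Minimal wiring sketch (illustrative only). The concrete model class and the full
# argument list live in the repo's training entry point; the attribute names below
# simply mirror what this file accesses, and every value is a placeholder assumption
# rather than a recommended setting.
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(
        choose_optim='AdamW', lr=1e-4, weight_decay=1e-2, scheduler_gamma=None,
        epochs=10, traindata_len=None, device='cuda', seed=42,
        batch_size=32, num_workers=4, valid_size=0.2,
        swissprot_data_root="None", pfam_data_root="None",
    )
    seed_everything(args.seed)

    # `model` must expose model(x, t, y_c) -> logits, matching forward() above.
    # model = ...                                    # built elsewhere in the repo
    # pl_module = PL_ProtARDM(args=args, model=model)
    # datamodule = PFamDataModule(args)
    # Trainer(max_epochs=args.epochs).fit(pl_module, datamodule=datamodule)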