# BioM3 / Stage3_source / PL_wrapper.py
# Author: Niksa Praljak
# Scripts for ProteoScribe sampling (Stage 3)
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.distributions import OneHotCategorical
from transformers.optimization import Adafactor
# PL functions
import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import EarlyStopping
import functools
import math
#from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import (
size_based_auto_wrap_policy,
enable_wrap,
wrap
)
import deepspeed
from deepspeed.ops.adam import DeepSpeedCPUAdam
from sklearn.model_selection import train_test_split
from Stage3_source.DSEma import moving_average, clone_zero_model
import Stage3_source.transformer_training_helper as trainer_tools
import Stage3_source.helper_funcs as helper_tools
import Stage3_source.eval_metrics as eval_funcs
import Stage3_source.preprocess as prep
import copy
from torch.utils.data import DataLoader
import pandas as pd
from transformers import get_cosine_schedule_with_warmup
class PL_ProtARDM(pl.LightningModule):
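    """PyTorch Lightning wrapper around the conditional protein ARDM used in ProteoScribe (Stage 3).

    Handles the forward pass, optimizer/scheduler configuration, the order-agnostic
    masked-prediction training objective (``cond_elbo_objective``), and metric logging.
    """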
def __init__(
self,
args: any,
model: nn.Module,
#ema_model: nn.Module,
):
super().__init__()
#self.save_hyperparameters()
# arguments
self.script_args = args
# the whole model
self.model = model
#self.ema_model = ema_model
#clone_zero_model(self.model, self.ema_model, zero_stage=3)
##self.ema_model = copy.deepcopy(self.model)
def forward(
self,
x: torch.Tensor,
t: torch.Tensor,
y_c: torch.Tensor,
ema=False,
) -> torch.Tensor:
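        """Return per-position token logits (vocabulary on dim 1), using the EMA copy of the model when ``ema=True``."""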
if ema:
logits = self.ema_model(x=x, t=t.view(-1,), y_c=y_c)
else:
logits = self.model(x=x, t=t.view(-1,), y_c=y_c)
return logits
#return F.softmax(logits, dim=1)
#def on_train_batch_end(self, *args, **kwargs):
# clone_zero_model(self.model, self.ema_model, zero_stage=3)
# #moving_average(self.model, self.ema_model, beta=0.0, zero_stage=3)
    def configure_optimizers(self):
        if self.script_args.choose_optim == 'AdamW':
            optimizer = torch.optim.AdamW(self.parameters(), lr=self.script_args.lr, weight_decay=self.script_args.weight_decay)
elif self.script_args.choose_optim == 'AdaFactor':
optimizer = Adafactor(self.parameters(), lr=self.script_args.lr, weight_decay=self.script_args.weight_decay, relative_step=False)
elif self.script_args.choose_optim == 'Adam':
optimizer = torch.optim.Adam(self.parameters(), lr=self.script_args.lr)
elif self.script_args.choose_optim == 'DeepSpeedCPUAdam':
optimizer = DeepSpeedCPUAdam(self.parameters(), lr=self.script_args.lr, weight_decay=self.script_args.weight_decay)
if self.script_args.scheduler_gamma is not None:
if isinstance(self.script_args.scheduler_gamma, str):
if 'coswarmup' == self.script_args.scheduler_gamma.lower():
                    print('Using cosine warmup scheduler with decay')
num_warmup_steps=self.script_args.traindata_len
num_training_steps=self.script_args.traindata_len*self.script_args.epochs
print(f'Num_warmup_steps={num_warmup_steps}')
print(f'Num_training_steps={num_training_steps}')
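                    # Linear warmup over the first num_warmup_steps optimizer steps,
                    # followed by cosine decay toward zero over the remaining training steps.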
def _get_cosine_schedule_with_warmup_lr_lambda(
current_step: int, num_warmup_steps: int, num_training_steps: int, num_cycles: float
):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
lr_lambda = functools.partial(
_get_cosine_schedule_with_warmup_lr_lambda,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
num_cycles=0.5,
)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1),
"interval": "step",
},
}
#return {
# "optimizer": optimizer,
# "lr_scheduler": {
# "scheduler": get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps),
# "interval": "step",
# },
#}
else:
print(f'Using Exponential learning rate decay / epoch with factor: {self.script_args.scheduler_gamma}')
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": optim.lr_scheduler.ExponentialLR(optimizer, gamma=self.script_args.scheduler_gamma, verbose=True),
"interval": "epoch",
},
}
else:
return optimizer
#else:
# print("Please make choose_option variable from these options: 'AdamW', 'AdaFactor', 'Adam', 'DeepSpeedCPUAdam'")
def common_step(
self,
realization: torch.Tensor,
realization_idx: any,
stage: str) -> dict:
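        """Shared train/validation step: compute the masked-prediction loss, log it, and log the per-region accuracy/perplexity metrics."""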
if isinstance(realization, list):
# class labels
y_c = realization[1]#.long()
# input samples
realization = realization[0]
batch_size, seq_length = realization.size()
realization = realization.reshape(batch_size, 1, seq_length).long()
train_tuple = self.cond_elbo_objective(
realization=realization,
y_c=y_c,
realization_idx=realization_idx,
stage=stage,
ema=True if 'ema' in stage.lower() else False,
)
        loss = train_tuple[0]
        if len(train_tuple) > 1:
            metrics = train_tuple[1]
if realization_idx == 0:
gpu_memory_usage = helper_tools.print_gpu_initialization()
self.log(f"{stage}_gpu_memory_usage", gpu_memory_usage, sync_dist=True)
        sync_dist = 'val' in stage
# track loss
self.log(f"{stage}_loss", loss, prog_bar=True, on_step=True, on_epoch=True, sync_dist=sync_dist)
# track performance metrics
if len(train_tuple) > 1:
self.log(f"{stage}_prev_hard_acc", metrics[0], prog_bar=True, on_step=True, on_epoch=True, sync_dist=sync_dist)
self.log(f"{stage}_prev_soft_acc", metrics[1], on_step=True, on_epoch=True, sync_dist=sync_dist)
self.log(f"{stage}_fut_hard_acc", metrics[2], prog_bar=True, on_step=True, on_epoch=True, sync_dist=sync_dist)
self.log(f"{stage}_fut_soft_acc", metrics[3], on_step=True, on_epoch=True, sync_dist=sync_dist)
self.log(f"{stage}_current_hard_acc", metrics[4], prog_bar=True, on_step=True, on_epoch=True, sync_dist=sync_dist)
self.log(f"{stage}_current_soft_acc", metrics[5], on_step=True, on_epoch=True, sync_dist=sync_dist)
self.log(f"{stage}_current_ppl", metrics[6], on_step=True, on_epoch=True, sync_dist=sync_dist)
self.log(f"{stage}_prev_ppl", metrics[7], on_step=True, on_epoch=True, sync_dist=sync_dist)
self.log(f"{stage}_fut_ppl", metrics[8], on_step=True, on_epoch=True, sync_dist=sync_dist)
self.log(f"{stage}_pos_entropy", metrics[9], on_step=True, on_epoch=True, sync_dist=sync_dist)
torch.cuda.empty_cache()
return {'loss': loss}
def training_step(
self,
realization: torch.Tensor,
realization_idx: any):
return self.common_step(realization, realization_idx, stage='train')
def validation_step(
self,
realization: torch.Tensor,
realization_idx: any):
self.common_step(realization, realization_idx, stage='val')
#self.common_step(realization, realization_idx, stage='EMA_val')
def apply_OneHotCat(self, probs: torch.Tensor) -> any:
return OneHotCategorical(probs=probs.permute(0,2,1))
#return OneHotCategorical(probs=F.softmax(probs.permute(0,2,1), dim=-1))
def cond_elbo_objective(
self,
realization: torch.Tensor,
y_c: torch.Tensor,
realization_idx: any,
stage: str,
ema=False,
):
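        """Conditional ELBO-style objective.

        A random generation order and timestep are sampled per sequence; positions already
        generated stay visible, the remainder are masked, and the model predicts every token
        from the masked input and the conditioning y_c. The log-likelihood over the
        not-yet-sampled positions is reweighted by the timestep before averaging into the loss.
        """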
bs, channel, seq_length = realization.size()
# get a batch of random sampling paths
sampled_random_path = trainer_tools.sample_random_path(bs, seq_length, device=self.script_args.device)
        # sample a random sampling step (timestep) for each sequence in the current batch
idx = trainer_tools.sample_random_index_for_sampling(bs, seq_length, device=self.script_args.device, option='random')
        # create a mask over the locations where we've already sampled
random_path_mask = trainer_tools.create_mask_at_random_path_index(sampled_random_path, idx, bs, seq_length)
# create a mask that masks the location where we are currently sampling
current_path_mask = trainer_tools.create_sampling_location_mask(sampled_random_path, idx, bs, seq_length)
# future sampling locations (i.e. >t)
future_path_mask = trainer_tools.create_mask_at_future_path_index(sampled_random_path, idx, bs, seq_length)
# tokenize realization
real_tokens, bs, seq_length = trainer_tools.create_token_labels(self.script_args, realization)
#real_tokens = realization.clone().squeeze(1)
# mask realizations
real_token_masked = trainer_tools.mask_realizations(real_tokens, random_path_mask)
# conditional probs
#probs = self(x=real_token_masked, t=idx, y_c=y_c, ema=ema)
logits = self(x=real_token_masked, t=idx, y_c=y_c, ema=ema)
conditional_prob = OneHotCategorical(logits=logits.permute(0,2,1))
#conditional_prob = self.apply_OneHotCat(probs=probs)
# evaluate the value of the log prob for the given realization
log_prob = trainer_tools.log_prob_of_realization(self.script_args, conditional_prob, real_tokens)
        # average the log probability over all not-yet-sampled locations
#log_prob_unsampled = trainer_tools.log_prob_of_unsampled_locations(log_prob.to(self.script_args.device), real_token_masked.to(self.script_args.device))
log_prob_unsampled = trainer_tools.log_prob_of_unsampled_locations(log_prob, real_token_masked)
#log_prob_unsampled = trainer_tools.log_prob_of_unsampled_locations(log_prob, real_token_masked, real_tokens)
        # reweight the unsampled log probability according to the current timestep
log_prob_weighted = trainer_tools.weight_log_prob(log_prob_unsampled, idx, seq_length)
# compute an average loss i.e. negative average log-likelihood over the batch elements
loss = trainer_tools.compute_average_loss_for_batch(log_prob_weighted)
#if 'val' in stage:
probs = F.softmax(logits, dim=1)
metrics = self.performance_step(
real_tokens=real_tokens.cpu(),
idx=idx.cpu(),
sampled_random_path=sampled_random_path.cpu().float(),
probs=probs.cpu().float(),
conditional_prob=conditional_prob)
return loss, metrics
# return loss,
@torch.no_grad()
def performance_step(
self,
real_tokens: torch.Tensor,
idx: torch.Tensor,
sampled_random_path: torch.Tensor,
probs: torch.Tensor,
conditional_prob: torch.Tensor
) -> tuple:
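        """Report hard/soft accuracy and perplexity split by position relative to the current
        sampling index (previous / current / future), plus the mean positional entropy.
        """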
# create numerical token sequence
sample_seq = torch.argmax(trainer_tools.sample_from_conditional(conditional_prob).cpu(), dim=1)
# eval prev positions in terms of time
prev_B_hard_acc, prev_B_soft_acc, fut_B_hard_acc, fut_B_soft_acc, current_B_hard_acc, current_B_soft_acc = eval_funcs.compute_acc_given_time_pos(
real_tokens=real_tokens,
sample_seq=sample_seq,
sample_path=sampled_random_path,
idx=idx
)
# compute ppl given time position
current_ppl, prev_ppl, fut_ppl = eval_funcs.compute_ppl_given_time_pos(
probs=probs,
sample_path=sampled_random_path,
idx=idx
)
# average positional entropy
pos_entropy = trainer_tools.compute_pos_entropy(probs=probs).mean().item()
metric_evals = (
prev_B_hard_acc,
prev_B_soft_acc,
fut_B_hard_acc,
fut_B_soft_acc,
current_B_hard_acc,
current_B_soft_acc,
current_ppl,
prev_ppl,
fut_ppl,
pos_entropy
)
return metric_evals
class PFamDataModule(pl.LightningDataModule):
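    """LightningDataModule that loads SwissProt and/or Pfam tensors, merges them when both are
    available, performs a random train/validation split, and serves (sequence, text-embedding) batches.
    """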
def __init__(self, args):
super().__init__()
self.args = args
#df = pd.read_csv(args.data_root)
#data = torch.load(args.data_root)
data = self.load_data()
num_seq_list, text_emb_list = prep.prepare_protein_data(
args=args,
data_dict=data
)
        print(f'Performing random train/val split (validation fraction: {args.valid_size})')
num_seq_list_train, num_seq_list_val, text_emb_train, text_emb_val = train_test_split(num_seq_list,
text_emb_list,
test_size=args.valid_size,
#stratify=class_label_list,
random_state=args.seed)
print(f'Number of training samples: {len(num_seq_list_train)}')
print(f'Number of validation samples: {len(num_seq_list_val)}')
self.train_dataset = prep.protein_dataset(
num_seq_list=num_seq_list_train,
text_emb=text_emb_train
)
self.val_dataset = prep.protein_dataset(
num_seq_list=num_seq_list_val,
text_emb=text_emb_val
)
def load_data(self):
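        """Load the SwissProt and/or Pfam data dictionaries from disk; merge them when both paths are provided."""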
        try:
            print(self.args.swissprot_data_root, self.args.pfam_data_root)
            swissprot_data, pfam_data = None, None
            if self.args.swissprot_data_root != "None":
                swissprot_data = torch.load(self.args.swissprot_data_root)
            if self.args.pfam_data_root != "None":
                pfam_data = torch.load(self.args.pfam_data_root)
            if (swissprot_data is not None) and (pfam_data is not None):
                return self.merge_and_append_values(dict1=swissprot_data, dict2=pfam_data)
            elif swissprot_data is not None:
                return swissprot_data
            elif pfam_data is not None:
                return pfam_data
            else:
                raise ValueError('Both SwissProt and Pfam datasets are unavailable.')
except FileNotFoundError as e:
raise FileNotFoundError(f"Data file not found: {e}")
def merge_and_append_values(self, dict1, dict2):
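        """Union the keys of two data dictionaries, concatenating each key's (list or scalar) values into one flat list."""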
merged_dict = {}
# Combine all keys from both dictionaries
all_keys = set(dict1) | set(dict2)
for key in all_keys:
values = []
if key in dict1:
values.append(dict1[key])
if key in dict2:
values.append(dict2[key])
# Merge values for each key
# This merges lists or appends non-list values
merged_dict[key] = [item for sublist in values for item in (sublist if isinstance(sublist, list) else [sublist])]
return merged_dict
def train_dataloader(self):
return DataLoader(
self.train_dataset,
batch_size=self.args.batch_size,
num_workers=self.args.num_workers,
shuffle=True
)
def val_dataloader(self):
return DataLoader(
self.val_dataset,
batch_size=self.args.batch_size,
num_workers=self.args.num_workers,
shuffle=False
)
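

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the training pipeline).
# It assumes `build_args()` and `build_model()` are hypothetical user-supplied
# helpers returning (a) a namespace with the fields referenced above (lr,
# choose_optim, weight_decay, scheduler_gamma, epochs, traindata_len,
# batch_size, num_workers, valid_size, seed, device, data roots, ...) and
# (b) an nn.Module with the forward(x, t, y_c) signature expected by PL_ProtARDM.
#
# if __name__ == "__main__":
#     args = build_args()
#     model = build_model(args)
#     pl_module = PL_ProtARDM(args=args, model=model)
#     datamodule = PFamDataModule(args)
#     trainer = Trainer(max_epochs=args.epochs)
#     trainer.fit(pl_module, datamodule=datamodule)
# ---------------------------------------------------------------------------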