"""Training entry point: a Lightning wrapper around the model plus the `train` driver."""
import json
import sys
from typing import Optional

# This import must be on top to set the environment variables before importing other modules
import env_consts
import time
import os

from lightning.pytorch import seed_everything
import lightning.pytorch as pl
import torch
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.profilers import AdvancedProfiler

from dockformerpp.config import model_config
from dockformerpp.data.data_modules import OpenFoldDataModule, DockFormerDataModule
from dockformerpp.model.model import AlphaFold
from dockformerpp.utils import residue_constants
from dockformerpp.utils.exponential_moving_average import ExponentialMovingAverage
from dockformerpp.utils.loss import AlphaFoldLoss, lddt_ca
from dockformerpp.utils.lr_schedulers import AlphaFoldLRScheduler
from dockformerpp.utils.script_utils import get_latest_checkpoint
from dockformerpp.utils.superimposition import superimpose
from dockformerpp.utils.tensor_utils import tensor_tree_map
from dockformerpp.utils.validation_metrics import (
    drmsd,
    gdt_ts,
    gdt_ha,
    rmsd,
)


class ModelWrapper(pl.LightningModule):
    """LightningModule that owns the model, its loss and an EMA copy of its weights.

    Training runs on the raw model weights; during validation the EMA weights
    are swapped in and the raw weights are restored when the validation epoch
    ends.
    """

    def __init__(self, config):
        """Build the model, the loss and the EMA tracker from ``config``.

        Args:
            config: Model configuration; ``config.loss``, ``config.ema.decay``
                and ``config.globals`` are read by this class.
        """
        super().__init__()
        self.config = config
        self.model = AlphaFold(config)
        self.loss = AlphaFoldLoss(config.loss)
        self.ema = ExponentialMovingAverage(
            model=self.model, decay=config.ema.decay
        )

        # Raw (non-EMA) weights stashed while validation runs on EMA weights.
        self.cached_weights = None
        # Global step to resume the LR scheduler from; -1 means a fresh run.
        self.last_lr_step = -1

        # Running per-step values for the "<phase>/<name>_agg" training curves.
        self.aggregated_metrics = {}
        self.log_agg_every_n_steps = 50  # match Trainer(log_every_n_steps=50)

    def forward(self, batch):
        """Run the wrapped model on ``batch`` and return its outputs."""
        return self.model(batch)

    def _log(self, loss_breakdown, batch, outputs, train=True):
        """Log the loss terms and the structural metrics for one step.

        Args:
            loss_breakdown: Mapping of loss-term name to its value.
            batch: Feature dict with the recycling dimension already removed.
            outputs: Model outputs for ``batch``.
            train: ``True`` logs under ``train/`` (per step, plus ``_epoch``
                and ``_agg`` variants); ``False`` logs under ``val/`` per epoch
                and additionally enables the superimposition metrics.
        """
        phase = "train" if train else "val"

        for loss_name, indiv_loss in loss_breakdown.items():
            self.log(
                f"{phase}/{loss_name}",
                indiv_loss,
                on_step=train, on_epoch=(not train), logger=True, sync_dist=True
            )

            if train:
                self.aggregated_metrics.setdefault(
                    f"{phase}/{loss_name}_agg", []
                ).append(float(indiv_loss))
                self.log(
                    f"{phase}/{loss_name}_epoch",
                    indiv_loss,
                    on_step=False, on_epoch=True, logger=True, sync_dist=True
                )

        with torch.no_grad():
            other_metrics = self._compute_validation_metrics(
                batch,
                outputs,
                superimposition_metrics=(not train)
            )

        for k, v in other_metrics.items():
            metric_mean = torch.mean(v)
            if train:
                self.aggregated_metrics.setdefault(
                    f"{phase}/{k}_agg", []
                ).append(float(metric_mean))
            self.log(
                f"{phase}/{k}",
                metric_mean,
                on_step=False, on_epoch=True, logger=True, sync_dist=True
            )

        # Once any aggregate has collected a full window, flush all of them.
        window_full = any(
            len(v) >= self.log_agg_every_n_steps
            for v in self.aggregated_metrics.values()
        )
        if train and window_full:
            for k, v in self.aggregated_metrics.items():
                if not v:
                    # Nothing accumulated for this key since the last flush;
                    # skip it instead of dividing by zero.
                    continue
                window_mean = sum(v) / len(v)
                print("logging agg", k, len(v), window_mean, flush=True)
                self.log(k, window_mean, on_step=True, on_epoch=False, logger=True, sync_dist=True)
                self.aggregated_metrics[k] = []

    def training_step(self, batch, batch_idx):
        """Run one training step and return the scalar loss."""
        # Keep the EMA copy on the same device as the incoming batch.
        if self.ema.device != batch["aatype"].device:
            self.ema.to(batch["aatype"].device)

        # Run the model
        outputs = self(batch)

        # Remove the recycling dimension
        batch = tensor_tree_map(lambda t: t[..., -1], batch)

        # Compute loss
        loss, loss_breakdown = self.loss(
            outputs, batch, _return_breakdown=True
        )

        # Log it
        self._log(loss_breakdown, batch, outputs)

        return loss

    def on_before_zero_grad(self, *args, **kwargs):
        """Fold the current model weights into the EMA."""
        self.ema.update(self.model)

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch using the EMA weights."""
        # At the start of validation, load the EMA weights
        if self.cached_weights is None:
            # model.state_dict() contains references to model weights rather
            # than copies. Therefore, we need to clone them before calling
            # load_state_dict().
            clone_param = lambda t: t.detach().clone()
            self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
            self.model.load_state_dict(self.ema.state_dict()["params"])

        # Run the model
        outputs = self(batch)
        batch = tensor_tree_map(lambda t: t[..., -1], batch)

        batch["use_clamped_fape"] = 0.

        # Compute loss and other metrics
        _, loss_breakdown = self.loss(
            outputs, batch, _return_breakdown=True
        )

        self._log(loss_breakdown, batch, outputs, train=False)

    def on_validation_epoch_end(self):
        """Restore the raw model weights that validation replaced with EMA weights."""
        if self.cached_weights is None:
            # No validation step ran, so there is nothing to restore.
            return
        self.model.load_state_dict(self.cached_weights)
        self.cached_weights = None

    def _compute_validation_metrics(self,
                                    batch,
                                    outputs,
                                    superimposition_metrics=False
                                    ):
        """Compute structure, contact and affinity metrics for one batch.

        Args:
            batch: Feature dict (recycling dimension removed) holding the
                ground-truth positions, masks, contacts and affinity.
            outputs: Model outputs; ``final_atom_positions``,
                ``inter_contact_logits``, ``affinity_2d_logits`` and
                ``affinity_cls_logits`` are read.
            superimposition_metrics: When true, also compute the alignment
                RMSD / GDT metrics, which need a superimposition.

        Returns:
            Dict mapping metric name to a tensor.
        """
        metrics = {}

        gt_positions = batch["atom37_gt_positions"]
        pred_positions = outputs["final_atom_positions"]
        joined_all_atom_mask = batch["atom37_atom_exists_in_gt"]

        def _per_atom_mask(residue_mask):
            # Expand a per-residue mask to the 37-atom layout of the joined mask.
            return torch.repeat_interleave(
                residue_mask, 37, dim=-1
            ).view(*joined_all_atom_mask.shape)

        atom_masks = (
            ("joined", joined_all_atom_mask),
            ("r", _per_atom_mask(batch["protein_r_mask"])),
            ("l", _per_atom_mask(batch["protein_l_mask"])),
        )
        for suffix, atom_mask in atom_masks:
            metrics[f"lddt_ca_{suffix}"] = lddt_ca(
                pred_positions,
                gt_positions,
                atom_mask,
                eps=self.config.globals.eps,
                per_residue=False,
            )

        ca_pos = residue_constants.atom_order["CA"]
        gt_coords_ca = gt_positions[..., ca_pos, :]
        pred_coords_ca = pred_positions[..., ca_pos, :]

        residue_masks = (
            ("joined", batch["structural_mask"]),  # still required here to compute n
            ("r", batch["protein_r_mask"]),
            ("l", batch["protein_l_mask"]),
        )
        for suffix, residue_mask in residue_masks:
            metrics[f"drmsd_ca_{suffix}"] = drmsd(
                pred_coords_ca,
                gt_coords_ca,
                mask=residue_mask,
            )

        # --- inter contacts
        gt_contacts = batch["gt_inter_contacts"]
        pred_contacts = torch.sigmoid(outputs["inter_contact_logits"].clone().detach()).squeeze(-1)
        pred_contacts = (pred_contacts > 0.5).float()
        pred_contacts = pred_contacts * batch["inter_pair_mask"]

        # Calculate True Positives, False Positives, and False Negatives
        tp = torch.sum((gt_contacts == 1) & (pred_contacts == 1))
        fp = torch.sum((gt_contacts == 0) & (pred_contacts == 1))
        fn = torch.sum((gt_contacts == 1) & (pred_contacts == 0))

        # Calculate Recall and Precision (fall back to tp, i.e. 0, on an empty denominator)
        recall = tp / (tp + fn) if (tp + fn) > 0 else tp.float()
        precision = tp / (tp + fp) if (tp + fp) > 0 else tp.float()

        metrics["inter_contacts_recall"] = recall.clone().detach()
        metrics["inter_contacts_precision"] = precision.clone().detach()

        # --- Affinity
        gt_affinity = batch["affinity"].squeeze(-1)
        # NOTE(review): 32 bins over [0, 15] presumably mirror the affinity
        # heads' binning - confirm against the model/loss config.
        affinity_linspace = torch.linspace(0, 15, 32, device=batch["affinity"].device)
        aff_loss_factor = batch["affinity_loss_factor"].squeeze()

        def _expected_affinity(logits):
            # Expectation of the bin centers under the softmax distribution.
            probs = torch.softmax(logits.clone().detach(), -1)
            return torch.sum(probs * affinity_linspace, dim=-1)

        def _weighted_abs_error(pred_affinity):
            # Mean absolute error weighted by the per-sample loss factor.
            abs_err = torch.abs(gt_affinity - pred_affinity) * aff_loss_factor
            return abs_err.sum() / aff_loss_factor.sum()

        pred_affinity_2d = _expected_affinity(outputs["affinity_2d_logits"])
        pred_affinity_cls = _expected_affinity(outputs["affinity_cls_logits"])

        metrics["affinity_dist_2d"] = _weighted_abs_error(pred_affinity_2d)
        metrics["affinity_dist_cls"] = _weighted_abs_error(pred_affinity_cls)
        metrics["affinity_dist_avg"] = _weighted_abs_error(
            (pred_affinity_cls + pred_affinity_2d) / 2
        )

        if superimposition_metrics:
            superimposed_pred, alignment_rmsd, rots, transs = superimpose(
                gt_coords_ca, pred_coords_ca, batch["structural_mask"],
            )
            gdt_ts_score = gdt_ts(
                superimposed_pred, gt_coords_ca, batch["structural_mask"]
            )
            gdt_ha_score = gdt_ha(
                superimposed_pred, gt_coords_ca, batch["structural_mask"]
            )

            metrics["alignment_rmsd_joined"] = alignment_rmsd
            metrics["gdt_ts_joined"] = gdt_ts_score
            metrics["gdt_ha_joined"] = gdt_ha_score

            superimposed_pred_l, alignment_rmsd, rots, transs = superimpose(
                gt_coords_ca, pred_coords_ca, batch["protein_l_mask"],
            )
            metrics["alignment_rmsd_l"] = alignment_rmsd

            superimposed_pred_r, alignment_rmsd, rots, transs = superimpose(
                gt_coords_ca, pred_coords_ca, batch["protein_r_mask"],
            )
            metrics["alignment_rmsd_r"] = alignment_rmsd

            # Apply the transform fitted on the "r" mask to the whole prediction
            # and measure the RMSD on the "l" mask only.
            superimposed_l_by_r_trans_coords = pred_coords_ca @ rots + transs[:, None, :]
            l_by_r_alignment_rmsds = rmsd(
                gt_coords_ca, superimposed_l_by_r_trans_coords, mask=batch["protein_l_mask"]
            )
            metrics["alignment_rmsd_l_by_r"] = l_by_r_alignment_rmsds.mean()
            metrics["alignment_rmsd_l_by_r_under_2"] = torch.mean((l_by_r_alignment_rmsds < 2).float())
            metrics["alignment_rmsd_l_by_r_under_5"] = torch.mean((l_by_r_alignment_rmsds < 5).float())

            print("ligand rmsd:", l_by_r_alignment_rmsds)

        return metrics

    def configure_optimizers(self,
                             learning_rate: Optional[float] = None,
                             eps: float = 1e-5,
                             ) -> dict:
        """Create the Adam optimizer and the step-wise LR scheduler.

        Args:
            learning_rate: Initial learning rate; defaults to
                ``config.globals.max_lr`` when ``None``.
            eps: Adam epsilon.

        Returns:
            Lightning optimizer config dict with ``optimizer`` and
            ``lr_scheduler`` entries.
        """
        if learning_rate is None:
            learning_rate = self.config.globals.max_lr

        optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=learning_rate,
            eps=eps
        )

        # When resuming (last_epoch != -1) the scheduler needs 'initial_lr'
        # to be present in every param group.
        if self.last_lr_step != -1:
            for group in optimizer.param_groups:
                group.setdefault('initial_lr', learning_rate)

        lr_scheduler = AlphaFoldLRScheduler(
            optimizer,
            last_epoch=self.last_lr_step,
            max_lr=self.config.globals.max_lr,
            start_decay_after_n_steps=10000,
            decay_every_n_steps=10000,
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
                "name": "AlphaFoldLRScheduler",
            }
        }

    def on_load_checkpoint(self, checkpoint):
        """Restore the EMA state stored under the checkpoint's ``"ema"`` key."""
        ema = checkpoint["ema"]
        self.ema.load_state_dict(ema)

    def on_save_checkpoint(self, checkpoint):
        """Store the EMA state under the checkpoint's ``"ema"`` key."""
        checkpoint["ema"] = self.ema.state_dict()

    def resume_last_lr_step(self, lr_step):
        """Record the global step the LR scheduler should resume from."""
        self.last_lr_step = lr_step


def override_config(base_config, overriding_config):
    """Recursively overwrite entries of ``base_config`` in place.

    Nested dicts in ``overriding_config`` are merged into the matching nested
    entry of ``base_config`` (which must already exist); every other value
    replaces the base value.

    Returns:
        ``base_config`` itself, after modification.
    """
    for k, v in overriding_config.items():
        if isinstance(v, dict):
            base_config[k] = override_config(base_config[k], v)
        else:
            base_config[k] = v
    return base_config


def train(override_config_path: str):
    """Train the model as described by a JSON run configuration.

    Keys read from the run config: ``train_output_dir`` (required), ``stage``,
    ``override_conf``, ``accumulate_grad_batches``, either ``train_input_dir``
    / ``val_input_dir`` / ``train_epoch_len`` or ``train_input_file`` /
    ``val_input_file``, ``ckpt_path``, ``multi_node`` (``num_nodes``,
    ``devices``) and ``check_val_every_n_epoch``.

    Args:
        override_config_path: Path to the JSON run configuration file.
    """
    # Use a context manager so the config file handle is closed promptly.
    with open(override_config_path, "r") as config_file:
        run_config = json.load(config_file)

    seed = 42
    seed_everything(seed, workers=True)
    output_dir = run_config["train_output_dir"]
    os.makedirs(output_dir, exist_ok=True)

    print("Starting train", time.time())
    config = model_config(
        run_config.get("stage", "initial_training"),
        train=True,
        low_prec=True
    )
    config = override_config(config, run_config.get("override_conf", {}))
    accumulate_grad_batches = run_config.get("accumulate_grad_batches", 1)
    print("config loaded", time.time())

    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    # device_name = "mps" if device_name == "cpu" and torch.backends.mps.is_available() else device_name
    model_module = ModelWrapper(config)
    print("model loaded", time.time())

    # device_name = "cpu"

    # for debugging memory:
    # torch.cuda.memory._record_memory_history()

    if "train_input_dir" in run_config:
        data_module = OpenFoldDataModule(
            config=config.data,
            batch_seed=seed,
            train_data_dir=run_config["train_input_dir"],
            val_data_dir=run_config["val_input_dir"],
            train_epoch_len=run_config.get("train_epoch_len", 1000),
        )
    else:
        data_module = DockFormerDataModule(
            config=config.data,
            batch_seed=seed,
            train_data_file=run_config["train_input_file"],
            val_data_file=run_config["val_input_file"],
        )
    print("data module loaded", time.time())

    checkpoint_dir = os.path.join(output_dir, "checkpoint")
    ckpt_path = run_config.get("ckpt_path", get_latest_checkpoint(checkpoint_dir))

    if ckpt_path:
        print(f"Resuming from checkpoint: {ckpt_path}")
        # Only the step counter is needed here, so keep the tensors on the CPU
        # instead of materializing the whole checkpoint on its saved device.
        sd = torch.load(ckpt_path, map_location="cpu")
        last_global_step = int(sd['global_step'])
        del sd
        model_module.resume_last_lr_step(last_global_step)

    # Do we need this?
    data_module.prepare_data()
    data_module.setup("fit")

    callbacks = []

    mc = ModelCheckpoint(
        dirpath=checkpoint_dir,
        # every_n_epochs=1,
        every_n_train_steps=250,
        auto_insert_metric_name=False,
        save_top_k=1,
        save_on_train_epoch_end=True,  # before validation
    )

    mc2 = ModelCheckpoint(
        dirpath=checkpoint_dir,  # Directory to save checkpoints
        filename="step{step}_rmsd{val/alignment_rmsd_l_by_r:.2f}",  # Filename format for best
        monitor="val/alignment_rmsd_l_by_r",  # Metric to monitor
        mode="min",  # We want the lowest `ligand_rmsd`
        save_top_k=1,  # Save only the best model based on `ligand_rmsd`
        every_n_epochs=1,  # Save a checkpoint every epoch
        auto_insert_metric_name=False
    )
    callbacks.append(mc)
    callbacks.append(mc2)

    lr_monitor = LearningRateMonitor(logging_interval="step")
    callbacks.append(lr_monitor)

    loggers = []

    wandb_project_name = "DockFormerPP"
    wandb_run_id_path = os.path.join(output_dir, "wandb_run_id.txt")

    # Initialize WandbLogger and save run_id
    # Slurm exposes the node-local task id as SLURM_LOCALID and the job-wide
    # task id as SLURM_PROCID, so they back the local and global rank respectively.
    local_rank = int(os.getenv('LOCAL_RANK', os.getenv("SLURM_LOCALID", '0')))
    global_rank = int(os.getenv('GLOBAL_RANK', os.getenv("SLURM_PROCID", '0')))

    print("ranks", os.getenv('LOCAL_RANK', 'd0'), os.getenv('local_rank', 'd0'), os.getenv('GLOBAL_RANK', 'd0'),
          os.getenv('global_rank', 'd0'), os.getenv("SLURM_PROCID", 'd0'), os.getenv('SLURM_LOCALID', 'd0'), flush=True)

    if local_rank == 0 and global_rank == 0 and not os.path.exists(wandb_run_id_path):
        # The main process creates the run and publishes its id through a file.
        wandb_logger = WandbLogger(project=wandb_project_name, save_dir=output_dir)
        with open(wandb_run_id_path, 'w') as f:
            f.write(wandb_logger.experiment.id)
        wandb_logger.experiment.config.update(run_config, allow_val_change=True)
    else:
        # Necessary for multi-node training https://github.com/rstrudel/segmenter/issues/22
        while not os.path.exists(wandb_run_id_path):
            print(f"Waiting for run_id file to be created ({local_rank})", flush=True)
            time.sleep(1)
        with open(wandb_run_id_path, 'r') as f:
            run_id = f.read().strip()
        wandb_logger = WandbLogger(project=wandb_project_name, save_dir=output_dir, resume='must', id=run_id)
    loggers.append(wandb_logger)

    strategy_params = {"strategy": "auto"}
    if run_config.get("multi_node", False):
        strategy_params["strategy"] = "ddp"
        # strategy_params["strategy"] = "ddp_find_unused_parameters_true"  # this causes issues with checkpointing...
        strategy_params["num_nodes"] = run_config["multi_node"]["num_nodes"]
        strategy_params["devices"] = run_config["multi_node"]["devices"]

    trainer = pl.Trainer(
        accelerator=device_name,
        default_root_dir=output_dir,
        **strategy_params,
        reload_dataloaders_every_n_epochs=1,
        accumulate_grad_batches=accumulate_grad_batches,
        check_val_every_n_epoch=run_config.get("check_val_every_n_epoch", 10),
        callbacks=callbacks,
        logger=loggers,
        # profiler=AdvancedProfiler(),
    )

    print("Starting fit", time.time())
    trainer.fit(
        model_module,
        datamodule=data_module,
        ckpt_path=ckpt_path,
    )

    # profiler_results = trainer.profiler.summary()
    # print(profiler_results)

    # torch.cuda.memory._dump_snapshot("my_train_snapshot.pickle")
    # view on https://pytorch.org/memory_viz


if __name__ == "__main__":
    if len(sys.argv) > 1:
        train(sys.argv[1])
    else:
        train(os.path.join(os.path.dirname(__file__), "run_config.json"))