FoldMark / runner /train.py
Zaixi's picture
Add large file
89c0b51
# Copyright 2024 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import logging
import os
import time
from contextlib import nullcontext
import torch
import torch.distributed as dist
import wandb
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm
from configs.configs_base import configs as configs_base
from configs.configs_data import data_configs
from protenix.config import parse_configs, parse_sys_args
from protenix.config.config import save_config
from protenix.data.dataloader import get_dataloaders
from protenix.metrics.lddt_metrics import LDDTMetrics
from protenix.model.loss import ProtenixLoss
from protenix.model.protenix_edit import Protenix
from protenix.utils.distributed import DIST_WRAPPER
from protenix.utils.lr_scheduler import get_lr_scheduler
from protenix.utils.metrics import SimpleMetricAggregator
from protenix.utils.permutation.permutation import SymmetricPermutation
from protenix.utils.seed import seed_everything
from protenix.utils.torch_utils import autocasting_disable_decorator, to_device
from protenix.utils.training import get_optimizer, is_loss_nan_check
from runner.ema import EMAWrapper
# Disable WANDB's console output capture to reduce unnecessary logging
os.environ["WANDB_CONSOLE"] = "off"
class AF3Trainer(object):
def __init__(self, configs):
self.configs = configs
self.init_env()
self.init_basics()
self.init_log()
self.init_model()
self.init_loss()
self.init_data()
self.try_load_checkpoint()
def init_basics(self):
# Step means effective step considering accumulation
self.step = 0
# Global_step equals to self.step * self.iters_to_accumulate
self.global_step = 0
self.start_step = 0
# Add for grad accumulation, it can increase real batch size
self.iters_to_accumulate = self.configs.iters_to_accumulate
self.run_name = self.configs.run_name + "_" + time.strftime("%Y%m%d_%H%M%S")
run_names = DIST_WRAPPER.all_gather_object(
self.run_name if DIST_WRAPPER.rank == 0 else None
)
self.run_name = [name for name in run_names if name is not None][0]
self.run_dir = f"{self.configs.base_dir}/{self.run_name}"
self.checkpoint_dir = f"{self.run_dir}/checkpoints"
self.prediction_dir = f"{self.run_dir}/predictions"
self.structure_dir = f"{self.run_dir}/structures"
self.dump_dir = f"{self.run_dir}/dumps"
self.error_dir = f"{self.run_dir}/errors"
if DIST_WRAPPER.rank == 0:
os.makedirs(self.run_dir)
os.makedirs(self.checkpoint_dir)
os.makedirs(self.prediction_dir)
os.makedirs(self.structure_dir)
os.makedirs(self.dump_dir)
os.makedirs(self.error_dir)
save_config(
self.configs,
os.path.join(self.configs.base_dir, self.run_name, "config.yaml"),
)
self.print(
f"Using run name: {self.run_name}, run dir: {self.run_dir}, checkpoint_dir: "
+ f"{self.checkpoint_dir}, prediction_dir: {self.prediction_dir}, structure_dir: "
+ f"{self.structure_dir}, error_dir: {self.error_dir}"
)
def init_log(self):
if self.configs.use_wandb and DIST_WRAPPER.rank == 0:
wandb.init(
project=self.configs.project,
name=self.run_name,
config=vars(self.configs),
id=self.configs.wandb_id or None,
)
self.train_metric_wrapper = SimpleMetricAggregator(["avg"])
def init_env(self):
"""Init pytorch/cuda envs."""
logging.info(
f"Distributed environment: world size: {DIST_WRAPPER.world_size}, "
+ f"global rank: {DIST_WRAPPER.rank}, local rank: {DIST_WRAPPER.local_rank}"
)
self.use_cuda = torch.cuda.device_count() > 0
if self.use_cuda:
self.device = torch.device("cuda:{}".format(DIST_WRAPPER.local_rank))
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
all_gpu_ids = ",".join(str(x) for x in range(torch.cuda.device_count()))
devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids)
logging.info(
f"LOCAL_RANK: {DIST_WRAPPER.local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]"
)
torch.cuda.set_device(self.device)
else:
self.device = torch.device("cpu")
if DIST_WRAPPER.world_size > 1:
timeout_seconds = int(os.environ.get("NCCL_TIMEOUT_SECOND", 600))
dist.init_process_group(
backend="nccl", timeout=datetime.timedelta(seconds=timeout_seconds)
)
# All ddp process got the same seed
seed_everything(
seed=self.configs.seed,
deterministic=self.configs.deterministic,
)
if self.configs.use_deepspeed_evo_attention:
env = os.getenv("CUTLASS_PATH", None)
print(f"env: {env}")
assert (
env is not None
), "if use ds4sci, set env as https://www.deepspeed.ai/tutorials/ds4sci_evoformerattention/"
logging.info("Finished init ENV.")
def init_loss(self):
self.loss = ProtenixLoss(self.configs)
self.symmetric_permutation = SymmetricPermutation(
self.configs, error_dir=self.error_dir
)
self.lddt_metrics = LDDTMetrics(self.configs)
def init_model(self):
self.raw_model = Protenix(self.configs).to(self.device)
self.use_ddp = False
if DIST_WRAPPER.world_size > 1:
self.print(f"Using DDP")
self.use_ddp = True
# Fix DDP/checkpoint https://discuss.pytorch.org/t/ddp-and-gradient-checkpointing/132244
self.model = DDP(
self.raw_model,
find_unused_parameters=self.configs.find_unused_parameters,
device_ids=[DIST_WRAPPER.local_rank],
output_device=DIST_WRAPPER.local_rank,
static_graph=True,
)
else:
self.model = self.raw_model
if self.configs.get("ema_decay", -1) > 0:
assert self.configs.ema_decay < 1
self.ema_wrapper = EMAWrapper(
self.model,
self.configs.ema_decay,
self.configs.ema_mutable_param_keywords,
)
self.ema_wrapper.register()
torch.cuda.empty_cache()
self.optimizer = get_optimizer(self.configs, self.model)
self.init_scheduler()
def init_scheduler(self, **kwargs):
self.lr_scheduler = get_lr_scheduler(self.configs, self.optimizer, **kwargs)
def init_data(self):
self.train_dl, self.test_dls = get_dataloaders(
self.configs,
DIST_WRAPPER.world_size,
seed=self.configs.seed,
error_dir=self.error_dir,
)
def save_checkpoint(self, ema_suffix=""):
if DIST_WRAPPER.rank == 0:
path = f"{self.checkpoint_dir}/{self.step}{ema_suffix}.pt"
checkpoint = {
"model": self.model.state_dict(),
"optimizer": self.optimizer.state_dict(),
"scheduler": (
self.lr_scheduler.state_dict()
if self.lr_scheduler is not None
else None
),
"step": self.step,
}
torch.save(checkpoint, path)
self.print(f"Saved checkpoint to {path}")
def try_load_checkpoint(self):
def _load_checkpoint(
checkpoint_path: str,
load_params_only: bool,
skip_load_optimizer: bool = False,
skip_load_step: bool = False,
skip_load_scheduler: bool = False,
):
if not os.path.exists(checkpoint_path):
raise Exception(f"Given checkpoint path not exist [{checkpoint_path}]")
self.print(
f"Loading from {checkpoint_path}, strict: {self.configs.load_strict}"
)
checkpoint = torch.load(checkpoint_path, self.device)
sample_key = [k for k in checkpoint["model"].keys()][0]
self.print(f"Sampled key: {sample_key}")
if sample_key.startswith("module.") and not self.use_ddp:
# DDP checkpoint has module. prefix
checkpoint["model"] = {
k[len("module.") :]: v for k, v in checkpoint["model"].items()
}
self.model.load_state_dict(
state_dict=checkpoint["model"],
strict=self.configs.load_strict,
)
if not load_params_only:
if not skip_load_optimizer:
self.print(f"Loading optimizer state")
self.optimizer.load_state_dict(checkpoint["optimizer"])
if not skip_load_step:
self.print(f"Loading checkpoint step")
self.step = checkpoint["step"] + 1
self.start_step = self.step
self.global_step = self.step * self.iters_to_accumulate
if not skip_load_scheduler:
self.print(f"Loading scheduler state")
self.lr_scheduler.load_state_dict(checkpoint["scheduler"])
else:
# reinitialize LR scheduler using the updated optimizer and step
self.init_scheduler(last_epoch=self.step - 1)
self.print(f"Finish loading checkpoint, current step: {self.step}")
# Load EMA model parameters
if self.configs.load_ema_checkpoint_path:
_load_checkpoint(
self.configs.load_ema_checkpoint_path,
load_params_only=True,
)
self.ema_wrapper.register()
# Load model
if self.configs.load_checkpoint_path:
_load_checkpoint(
self.configs.load_checkpoint_path,
self.configs.load_params_only,
skip_load_optimizer=self.configs.skip_load_optimizer,
skip_load_scheduler=self.configs.skip_load_scheduler,
skip_load_step=self.configs.skip_load_step,
)
def print(self, msg: str):
if DIST_WRAPPER.rank == 0:
logging.info(msg)
def model_forward(self, batch: dict, mode: str = "train") -> tuple[dict, dict]:
assert mode in ["train", "eval"]
batch["pred_dict"], batch["label_dict"], log_dict = self.model(
input_feature_dict=batch["input_feature_dict"],
label_dict=batch["label_dict"],
label_full_dict=batch["label_full_dict"],
mode=mode,
current_step=self.step if mode == "train" else None,
symmetric_permutation=self.symmetric_permutation,
)
return batch, log_dict
def get_loss(
self, batch: dict, mode: str = "train"
) -> tuple[torch.Tensor, dict, dict]:
assert mode in ["train", "eval"]
loss, loss_dict = autocasting_disable_decorator(self.configs.skip_amp.loss)(
self.loss
)(
feat_dict=batch["input_feature_dict"],
pred_dict=batch["pred_dict"],
label_dict=batch["label_dict"],
mode=mode,
)
return loss, loss_dict, batch
@torch.no_grad()
def get_metrics(self, batch: dict) -> dict:
lddt_dict = self.lddt_metrics.compute_lddt(
batch["pred_dict"], batch["label_dict"]
)
return lddt_dict
@torch.no_grad()
def aggregate_metrics(self, lddt_dict: dict, batch: dict) -> dict:
simple_metrics, _ = self.lddt_metrics.aggregate_lddt(
lddt_dict, batch["pred_dict"]["summary_confidence"]
)
return simple_metrics
@torch.no_grad()
def get_recovery(self, pred_code, gt_code):
pred_code = torch.clamp(pred_code, min=-10, max=10)
epsilon = 1e-6
predicted_classes = (torch.sigmoid(pred_code) > 0.5).float()
# Use the epsilon to ensure no division by zero in recovery computation
recovery = ((predicted_classes == gt_code.float()).float().mean(dim=-1)).clamp(min=epsilon)
return recovery.mean()
@torch.no_grad()
def evaluate(self, mode: str = "eval"):
if not self.configs.eval_ema_only:
self._evaluate()
if hasattr(self, "ema_wrapper"):
self.ema_wrapper.apply_shadow()
self._evaluate(ema_suffix=f"ema{self.ema_wrapper.decay}_", mode=mode)
self.ema_wrapper.restore()
@torch.no_grad()
def _evaluate(self, ema_suffix: str = "", mode: str = "eval"):
# Init Metric Aggregator
simple_metric_wrapper = SimpleMetricAggregator(["avg"])
eval_precision = {
"fp32": torch.float32,
"bf16": torch.bfloat16,
"fp16": torch.float16,
}[self.configs.dtype]
enable_amp = (
torch.autocast(device_type="cuda", dtype=eval_precision)
if torch.cuda.is_available()
else nullcontext()
)
self.model.eval()
for test_name, test_dl in self.test_dls.items():
self.print(f"Testing on {test_name}")
evaluated_pids = []
total_batch_num = len(test_dl)
for index, batch in enumerate(tqdm(test_dl)):
batch = to_device(batch, self.device)
pid = batch["basic"]["pdb_id"]
if index + 1 == total_batch_num and DIST_WRAPPER.world_size > 1:
# Gather all pids across ranks for avoiding duplicated evaluations when drop_last = False
all_data_ids = DIST_WRAPPER.all_gather_object(evaluated_pids)
dedup_ids = set(sum(all_data_ids, []))
if pid in dedup_ids:
print(
f"Rank {DIST_WRAPPER.rank}: Drop data_id {pid} as it is already evaluated."
)
break
evaluated_pids.append(pid)
simple_metrics = {}
with enable_amp:
# Model forward
batch, _ = self.model_forward(batch, mode=mode)
# Loss forward
loss, loss_dict, batch = self.get_loss(batch, mode="eval")
# lDDT metrics
lddt_dict = self.get_metrics(batch)
lddt_metrics = self.aggregate_metrics(lddt_dict, batch)
simple_metrics.update(
{k: v for k, v in lddt_metrics.items() if "diff" not in k}
)
simple_metrics.update(loss_dict)
# Metrics
for key, value in simple_metrics.items():
simple_metric_wrapper.add(
f"{ema_suffix}{key}", value, namespace=test_name
)
del batch, simple_metrics
if index % 5 == 0:
# Release some memory periodically
torch.cuda.empty_cache()
metrics = simple_metric_wrapper.calc()
self.print(f"Step {self.step}, eval {test_name}: {metrics}")
if self.configs.use_wandb and DIST_WRAPPER.rank == 0:
wandb.log(metrics, step=self.step)
def update(self):
# Clip the gradient
if self.configs.grad_clip_norm != 0.0:
torch.nn.utils.clip_grad_norm_(
self.model.parameters(), self.configs.grad_clip_norm
)
def train_step(self, batch: dict):
self.model.train()
# FP16 training has not been verified yet
train_precision = {
"fp32": torch.float32,
"bf16": torch.bfloat16,
"fp16": torch.float16,
}[self.configs.dtype]
enable_amp = (
torch.autocast(
device_type="cuda", dtype=train_precision, cache_enabled=False
)
if torch.cuda.is_available()
else nullcontext()
)
scaler = torch.GradScaler(
device="cuda" if torch.cuda.is_available() else "cpu",
enabled=(self.configs.dtype == "float16"),
)
with enable_amp:
batch, _ = self.model_forward(batch, mode="train")
loss, loss_dict, _ = self.get_loss(batch, mode="train")
recovery = self.get_recovery(batch["pred_dict"]['watermark'], batch["label_dict"]['watermark'])
if self.configs.dtype in ["bf16", "fp32"]:
if is_loss_nan_check(loss):
self.print(f"Skip iteration with NaN loss: {self.step} steps")
loss = torch.tensor(0.0, device=loss.device, requires_grad=True)
scaler.scale(loss / self.iters_to_accumulate).backward()
# For simplicity, the global training step is used
if (self.global_step + 1) % self.iters_to_accumulate == 0:
self.print(
f"self.step {self.step}, self.iters_to_accumulate: {self.iters_to_accumulate}"
)
# Unscales the gradients of optimizer's assigned parameters in-place
scaler.unscale_(self.optimizer)
# Do grad clip only
self.update()
scaler.step(self.optimizer)
scaler.update()
self.optimizer.zero_grad(set_to_none=True)
self.lr_scheduler.step()
for key, value in loss_dict.items():
if "loss" not in key:
continue
self.train_metric_wrapper.add(key, value, namespace="train")
self.train_metric_wrapper.add('recovery', recovery, namespace="train")
torch.cuda.empty_cache()
def progress_bar(self, desc: str = ""):
if DIST_WRAPPER.rank != 0:
return
if self.global_step % (
self.configs.eval_interval * self.iters_to_accumulate
) == 0 or (not hasattr(self, "_ipbar")):
# Start a new progress bar
self._pbar = tqdm(
range(
self.global_step
% (self.iters_to_accumulate * self.configs.eval_interval),
self.iters_to_accumulate * self.configs.eval_interval,
)
)
self._ipbar = iter(self._pbar)
step = next(self._ipbar)
self._pbar.set_description(
f"[step {self.step}: {step}/{self.iters_to_accumulate * self.configs.eval_interval}] {desc}"
)
return
def run(self):
"""
Main entry for the AF3Trainer.
This function handles the training process, evaluation, logging, and checkpoint saving.
"""
if self.configs.eval_only or self.configs.eval_first:
self.evaluate()
if self.configs.eval_only:
return
use_ema = hasattr(self, "ema_wrapper")
self.print(f"Using ema: {use_ema}")
while True:
for batch in self.train_dl:
is_update_step = (self.global_step + 1) % self.iters_to_accumulate == 0
is_last_step = (self.step + 1) == self.configs.max_steps
step_need_log = (self.step + 1) % self.configs.log_interval == 0
step_need_eval = (
self.configs.eval_interval > 0
and (self.step + 1) % self.configs.eval_interval == 0
)
step_need_save = (
self.configs.checkpoint_interval > 0
and (self.step + 1) % self.configs.checkpoint_interval == 0
)
is_last_step &= is_update_step
step_need_log &= is_update_step
step_need_eval &= is_update_step
step_need_save &= is_update_step
batch = to_device(batch, self.device)
self.progress_bar()
self.train_step(batch)
if use_ema:
self.ema_wrapper.update()
if step_need_log or is_last_step:
metrics = self.train_metric_wrapper.calc()
self.print(f"Step {self.step} train: {metrics}")
last_lr = self.lr_scheduler.get_last_lr()[0]
if DIST_WRAPPER.rank == 0:
if self.configs.use_wandb:
wandb.log(
{"train/lr": last_lr},
step=self.step,
)
self.print(f"Step {self.step}, lr: {last_lr}")
if self.configs.use_wandb and DIST_WRAPPER.rank == 0:
wandb.log(metrics, step=self.step)
if step_need_save or is_last_step:
self.save_checkpoint()
if use_ema:
self.ema_wrapper.apply_shadow()
self.save_checkpoint(
ema_suffix=f"_ema_{self.ema_wrapper.decay}"
)
self.ema_wrapper.restore()
if step_need_eval or is_last_step:
self.evaluate()
self.global_step += 1
if self.global_step % self.iters_to_accumulate == 0:
self.step += 1
if self.step >= self.configs.max_steps:
self.print(f"Finish training after {self.step} steps")
break
if self.step >= self.configs.max_steps:
break
def main():
LOG_FORMAT = "%(asctime)s,%(msecs)-3d %(levelname)-8s [%(filename)s:%(lineno)s %(funcName)s] %(message)s"
logging.basicConfig(
format=LOG_FORMAT,
level=logging.INFO,
datefmt="%Y-%m-%d %H:%M:%S",
filemode="w",
)
configs_base["use_deepspeed_evo_attention"] = (
os.environ.get("USE_DEEPSPEED_EVO_ATTTENTION", False) == "true"
)
configs = {**configs_base, **{"data": data_configs}}
configs = parse_configs(
configs,
parse_sys_args(),
)
print(configs.run_name)
print(configs)
trainer = AF3Trainer(configs)
trainer.run()
if __name__ == "__main__":
main()