import datetime
import logging
import os
import time
from contextlib import nullcontext

import torch
import torch.distributed as dist
import wandb
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm

from configs.configs_base import configs as configs_base
from configs.configs_data import data_configs
from protenix.config import parse_configs, parse_sys_args
from protenix.config.config import save_config
from protenix.data.dataloader import get_dataloaders
from protenix.metrics.lddt_metrics import LDDTMetrics
from protenix.model.loss import ProtenixLoss
from protenix.model.protenix_edit import Protenix
from protenix.utils.distributed import DIST_WRAPPER
from protenix.utils.lr_scheduler import get_lr_scheduler
from protenix.utils.metrics import SimpleMetricAggregator
from protenix.utils.permutation.permutation import SymmetricPermutation
from protenix.utils.seed import seed_everything
from protenix.utils.torch_utils import autocasting_disable_decorator, to_device
from protenix.utils.training import get_optimizer, is_loss_nan_check
from runner.ema import EMAWrapper

os.environ["WANDB_CONSOLE"] = "off"


class AF3Trainer(object):
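    """Trainer for the Protenix (AlphaFold3-style) model.

    Sets up the distributed environment, run directories, logging, the model
    (optionally wrapped in DDP and EMA), loss, dataloaders and checkpoint
    loading, then drives the training/evaluation loop via `run()`.
    """
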
    def __init__(self, configs):
        self.configs = configs
        self.init_env()
        self.init_basics()
        self.init_log()
        self.init_model()
        self.init_loss()
        self.init_data()
        self.try_load_checkpoint()

    def init_basics(self):
        self.step = 0
        self.global_step = 0
        self.start_step = 0
        self.iters_to_accumulate = self.configs.iters_to_accumulate

        self.run_name = self.configs.run_name + "_" + time.strftime("%Y%m%d_%H%M%S")
        run_names = DIST_WRAPPER.all_gather_object(
            self.run_name if DIST_WRAPPER.rank == 0 else None
        )
        self.run_name = [name for name in run_names if name is not None][0]
        self.run_dir = f"{self.configs.base_dir}/{self.run_name}"
        self.checkpoint_dir = f"{self.run_dir}/checkpoints"
        self.prediction_dir = f"{self.run_dir}/predictions"
        self.structure_dir = f"{self.run_dir}/structures"
        self.dump_dir = f"{self.run_dir}/dumps"
        self.error_dir = f"{self.run_dir}/errors"

        if DIST_WRAPPER.rank == 0:
            os.makedirs(self.run_dir)
            os.makedirs(self.checkpoint_dir)
            os.makedirs(self.prediction_dir)
            os.makedirs(self.structure_dir)
            os.makedirs(self.dump_dir)
            os.makedirs(self.error_dir)
            save_config(
                self.configs,
                os.path.join(self.configs.base_dir, self.run_name, "config.yaml"),
            )

        self.print(
            f"Using run name: {self.run_name}, run dir: {self.run_dir}, checkpoint_dir: "
            + f"{self.checkpoint_dir}, prediction_dir: {self.prediction_dir}, structure_dir: "
            + f"{self.structure_dir}, error_dir: {self.error_dir}"
        )

    def init_log(self):
        if self.configs.use_wandb and DIST_WRAPPER.rank == 0:
            wandb.init(
                project=self.configs.project,
                name=self.run_name,
                config=vars(self.configs),
                id=self.configs.wandb_id or None,
            )
        self.train_metric_wrapper = SimpleMetricAggregator(["avg"])

    def init_env(self):
        """Init pytorch/cuda envs."""
        logging.info(
            f"Distributed environment: world size: {DIST_WRAPPER.world_size}, "
            + f"global rank: {DIST_WRAPPER.rank}, local rank: {DIST_WRAPPER.local_rank}"
        )
        self.use_cuda = torch.cuda.device_count() > 0
        if self.use_cuda:
            self.device = torch.device("cuda:{}".format(DIST_WRAPPER.local_rank))
            os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
            all_gpu_ids = ",".join(str(x) for x in range(torch.cuda.device_count()))
            devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids)
            logging.info(
                f"LOCAL_RANK: {DIST_WRAPPER.local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]"
            )
            torch.cuda.set_device(self.device)
        else:
            self.device = torch.device("cpu")
        if DIST_WRAPPER.world_size > 1:
            timeout_seconds = int(os.environ.get("NCCL_TIMEOUT_SECOND", 600))
            dist.init_process_group(
                backend="nccl", timeout=datetime.timedelta(seconds=timeout_seconds)
            )

        seed_everything(
            seed=self.configs.seed,
            deterministic=self.configs.deterministic,
        )

        if self.configs.use_deepspeed_evo_attention:
            env = os.getenv("CUTLASS_PATH", None)
            logging.info(f"CUTLASS_PATH: {env}")
            assert env is not None, (
                "To use the DS4Sci EvoformerAttention kernel, set CUTLASS_PATH as described at "
                "https://www.deepspeed.ai/tutorials/ds4sci_evoformerattention/"
            )
        logging.info("Finished init ENV.")

    def init_loss(self):
        self.loss = ProtenixLoss(self.configs)
        self.symmetric_permutation = SymmetricPermutation(
            self.configs, error_dir=self.error_dir
        )
        self.lddt_metrics = LDDTMetrics(self.configs)

    def init_model(self):
        self.raw_model = Protenix(self.configs).to(self.device)
        self.use_ddp = False
        if DIST_WRAPPER.world_size > 1:
            self.print("Using DDP")
            self.use_ddp = True
            self.model = DDP(
                self.raw_model,
                find_unused_parameters=self.configs.find_unused_parameters,
                device_ids=[DIST_WRAPPER.local_rank],
                output_device=DIST_WRAPPER.local_rank,
                static_graph=True,
            )
        else:
            self.model = self.raw_model

        if self.configs.get("ema_decay", -1) > 0:
            assert self.configs.ema_decay < 1
            self.ema_wrapper = EMAWrapper(
                self.model,
                self.configs.ema_decay,
                self.configs.ema_mutable_param_keywords,
            )
            self.ema_wrapper.register()

        torch.cuda.empty_cache()
        self.optimizer = get_optimizer(self.configs, self.model)
        self.init_scheduler()

    def init_scheduler(self, **kwargs):
        self.lr_scheduler = get_lr_scheduler(self.configs, self.optimizer, **kwargs)

    def init_data(self):
        self.train_dl, self.test_dls = get_dataloaders(
            self.configs,
            DIST_WRAPPER.world_size,
            seed=self.configs.seed,
            error_dir=self.error_dir,
        )

    def save_checkpoint(self, ema_suffix=""):
        if DIST_WRAPPER.rank == 0:
            path = f"{self.checkpoint_dir}/{self.step}{ema_suffix}.pt"
            checkpoint = {
                "model": self.model.state_dict(),
                "optimizer": self.optimizer.state_dict(),
                "scheduler": (
                    self.lr_scheduler.state_dict()
                    if self.lr_scheduler is not None
                    else None
                ),
                "step": self.step,
            }
            torch.save(checkpoint, path)
            self.print(f"Saved checkpoint to {path}")

    def try_load_checkpoint(self):
        def _load_checkpoint(
            checkpoint_path: str,
            load_params_only: bool,
            skip_load_optimizer: bool = False,
            skip_load_step: bool = False,
            skip_load_scheduler: bool = False,
        ):
            if not os.path.exists(checkpoint_path):
                raise FileNotFoundError(
                    f"Given checkpoint path does not exist [{checkpoint_path}]"
                )
            self.print(
                f"Loading from {checkpoint_path}, strict: {self.configs.load_strict}"
            )
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            sample_key = next(iter(checkpoint["model"].keys()))
            self.print(f"Sampled key: {sample_key}")
            if sample_key.startswith("module.") and not self.use_ddp:
                # Checkpoint was saved from a DDP-wrapped model; strip the prefix.
                checkpoint["model"] = {
                    k[len("module.") :]: v for k, v in checkpoint["model"].items()
                }
            self.model.load_state_dict(
                state_dict=checkpoint["model"],
                strict=self.configs.load_strict,
            )
            if not load_params_only:
                if not skip_load_optimizer:
                    self.print("Loading optimizer state")
                    self.optimizer.load_state_dict(checkpoint["optimizer"])
                if not skip_load_step:
                    self.print("Loading checkpoint step")
                    self.step = checkpoint["step"] + 1
                    self.start_step = self.step
                    self.global_step = self.step * self.iters_to_accumulate
                if not skip_load_scheduler:
                    self.print("Loading scheduler state")
                    self.lr_scheduler.load_state_dict(checkpoint["scheduler"])
                else:
                    # Keep the scheduler counter in sync with the restored step.
                    self.init_scheduler(last_epoch=self.step - 1)
            self.print(f"Finish loading checkpoint, current step: {self.step}")

        if self.configs.load_ema_checkpoint_path:
            _load_checkpoint(
                self.configs.load_ema_checkpoint_path,
                load_params_only=True,
            )
            self.ema_wrapper.register()

        if self.configs.load_checkpoint_path:
            _load_checkpoint(
                self.configs.load_checkpoint_path,
                self.configs.load_params_only,
                skip_load_optimizer=self.configs.skip_load_optimizer,
                skip_load_scheduler=self.configs.skip_load_scheduler,
                skip_load_step=self.configs.skip_load_step,
            )

    def print(self, msg: str):
        if DIST_WRAPPER.rank == 0:
            logging.info(msg)

    def model_forward(self, batch: dict, mode: str = "train") -> tuple[dict, dict]:
        assert mode in ["train", "eval"]
        batch["pred_dict"], batch["label_dict"], log_dict = self.model(
            input_feature_dict=batch["input_feature_dict"],
            label_dict=batch["label_dict"],
            label_full_dict=batch["label_full_dict"],
            mode=mode,
            current_step=self.step if mode == "train" else None,
            symmetric_permutation=self.symmetric_permutation,
        )
        return batch, log_dict

    def get_loss(
        self, batch: dict, mode: str = "train"
    ) -> tuple[torch.Tensor, dict, dict]:
        assert mode in ["train", "eval"]

        loss, loss_dict = autocasting_disable_decorator(self.configs.skip_amp.loss)(
            self.loss
        )(
            feat_dict=batch["input_feature_dict"],
            pred_dict=batch["pred_dict"],
            label_dict=batch["label_dict"],
            mode=mode,
        )
        return loss, loss_dict, batch

    @torch.no_grad()
    def get_metrics(self, batch: dict) -> dict:
        lddt_dict = self.lddt_metrics.compute_lddt(
            batch["pred_dict"], batch["label_dict"]
        )
        return lddt_dict

    @torch.no_grad()
    def aggregate_metrics(self, lddt_dict: dict, batch: dict) -> dict:
        simple_metrics, _ = self.lddt_metrics.aggregate_lddt(
            lddt_dict, batch["pred_dict"]["summary_confidence"]
        )
        return simple_metrics

    @torch.no_grad()
    def get_recovery(self, pred_code, gt_code):
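        """Bitwise recovery rate of a binary code from predicted logits.

        `pred_code` holds per-bit logits and `gt_code` the binary target code
        (used for the watermark code during training); a bit counts as
        recovered when sigmoid(logit) > 0.5 matches the label.
        """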
        pred_code = torch.clamp(pred_code, min=-10, max=10)
        epsilon = 1e-6
        predicted_classes = (torch.sigmoid(pred_code) > 0.5).float()
        recovery = (
            (predicted_classes == gt_code.float()).float().mean(dim=-1)
        ).clamp(min=epsilon)
        return recovery.mean()

    @torch.no_grad()
    def evaluate(self, mode: str = "eval"):
        if not self.configs.eval_only:
            pass
        if not self.configs.eval_ema_only:
            self._evaluate(mode=mode)
        if hasattr(self, "ema_wrapper"):
            # Evaluate the EMA weights as well, then restore the online weights.
            self.ema_wrapper.apply_shadow()
            self._evaluate(ema_suffix=f"ema{self.ema_wrapper.decay}_", mode=mode)
            self.ema_wrapper.restore()

    @torch.no_grad()
    def _evaluate(self, ema_suffix: str = "", mode: str = "eval"):
        simple_metric_wrapper = SimpleMetricAggregator(["avg"])
        eval_precision = {
            "fp32": torch.float32,
            "bf16": torch.bfloat16,
            "fp16": torch.float16,
        }[self.configs.dtype]
        enable_amp = (
            torch.autocast(device_type="cuda", dtype=eval_precision)
            if torch.cuda.is_available()
            else nullcontext()
        )
        self.model.eval()

        for test_name, test_dl in self.test_dls.items():
            self.print(f"Testing on {test_name}")
            evaluated_pids = []
            total_batch_num = len(test_dl)
            for index, batch in enumerate(tqdm(test_dl)):
                batch = to_device(batch, self.device)
                pid = batch["basic"]["pdb_id"]

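                # On the last batch of a multi-GPU run, distributed samplers may
                # repeat examples to equalize batch counts; skip ids that another
                # rank has already evaluated so metrics are not double counted.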
                if index + 1 == total_batch_num and DIST_WRAPPER.world_size > 1:
                    all_data_ids = DIST_WRAPPER.all_gather_object(evaluated_pids)
                    dedup_ids = set(sum(all_data_ids, []))
                    if pid in dedup_ids:
                        print(
                            f"Rank {DIST_WRAPPER.rank}: Drop data_id {pid} as it is already evaluated."
                        )
                        break
                evaluated_pids.append(pid)

                simple_metrics = {}
                with enable_amp:
                    batch, _ = self.model_forward(batch, mode=mode)
                    loss, loss_dict, batch = self.get_loss(batch, mode="eval")
                    lddt_dict = self.get_metrics(batch)
                    lddt_metrics = self.aggregate_metrics(lddt_dict, batch)
                    simple_metrics.update(
                        {k: v for k, v in lddt_metrics.items() if "diff" not in k}
                    )
                    simple_metrics.update(loss_dict)

                for key, value in simple_metrics.items():
                    simple_metric_wrapper.add(
                        f"{ema_suffix}{key}", value, namespace=test_name
                    )

                del batch, simple_metrics
                if index % 5 == 0:
                    torch.cuda.empty_cache()

            metrics = simple_metric_wrapper.calc()
            self.print(f"Step {self.step}, eval {test_name}: {metrics}")
            if self.configs.use_wandb and DIST_WRAPPER.rank == 0:
                wandb.log(metrics, step=self.step)

    def update(self):
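        """Pre-step gradient processing: currently only gradient-norm clipping."""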
        if self.configs.grad_clip_norm != 0.0:
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(), self.configs.grad_clip_norm
            )

    def train_step(self, batch: dict):
        self.model.train()

        train_precision = {
            "fp32": torch.float32,
            "bf16": torch.bfloat16,
            "fp16": torch.float16,
        }[self.configs.dtype]
        enable_amp = (
            torch.autocast(
                device_type="cuda", dtype=train_precision, cache_enabled=False
            )
            if torch.cuda.is_available()
            else nullcontext()
        )

        # Loss scaling is only needed for fp16; the scaler is a no-op otherwise.
        scaler = torch.GradScaler(
            device="cuda" if torch.cuda.is_available() else "cpu",
            enabled=(self.configs.dtype == "fp16"),
        )

        with enable_amp:
            batch, _ = self.model_forward(batch, mode="train")
            loss, loss_dict, _ = self.get_loss(batch, mode="train")
            recovery = self.get_recovery(
                batch["pred_dict"]["watermark"], batch["label_dict"]["watermark"]
            )

        if self.configs.dtype in ["bf16", "fp32"]:
            if is_loss_nan_check(loss):
                self.print(f"Skip iteration with NaN loss: {self.step} steps")
                loss = torch.tensor(0.0, device=loss.device, requires_grad=True)
        scaler.scale(loss / self.iters_to_accumulate).backward()

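        # Gradient accumulation: only unscale, clip and apply the optimizer /
        # scheduler step once every `iters_to_accumulate` micro-batches.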
        if (self.global_step + 1) % self.iters_to_accumulate == 0:
            self.print(
                f"self.step {self.step}, self.iters_to_accumulate: {self.iters_to_accumulate}"
            )
            scaler.unscale_(self.optimizer)
            self.update()
            scaler.step(self.optimizer)
            scaler.update()
            self.optimizer.zero_grad(set_to_none=True)
            self.lr_scheduler.step()
        for key, value in loss_dict.items():
            if "loss" not in key:
                continue
            self.train_metric_wrapper.add(key, value, namespace="train")
        self.train_metric_wrapper.add("recovery", recovery, namespace="train")
        torch.cuda.empty_cache()

    def progress_bar(self, desc: str = ""):
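        """Rank-0 tqdm bar over the micro-batches of the current eval interval."""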
        if DIST_WRAPPER.rank != 0:
            return
        if self.global_step % (
            self.configs.eval_interval * self.iters_to_accumulate
        ) == 0 or (not hasattr(self, "_ipbar")):
            self._pbar = tqdm(
                range(
                    self.global_step
                    % (self.iters_to_accumulate * self.configs.eval_interval),
                    self.iters_to_accumulate * self.configs.eval_interval,
                )
            )
            self._ipbar = iter(self._pbar)

        step = next(self._ipbar)
        self._pbar.set_description(
            f"[step {self.step}: {step}/{self.iters_to_accumulate * self.configs.eval_interval}] {desc}"
        )
        return

    def run(self):
        """
        Main entry for the AF3Trainer.

        This function handles the training process, evaluation, logging, and checkpoint saving.
        """
        if self.configs.eval_only or self.configs.eval_first:
            self.evaluate()
        if self.configs.eval_only:
            return
        use_ema = hasattr(self, "ema_wrapper")
        self.print(f"Using ema: {use_ema}")

        while True:
            for batch in self.train_dl:
                is_update_step = (self.global_step + 1) % self.iters_to_accumulate == 0
                is_last_step = (self.step + 1) == self.configs.max_steps
                step_need_log = (self.step + 1) % self.configs.log_interval == 0
                step_need_eval = (
                    self.configs.eval_interval > 0
                    and (self.step + 1) % self.configs.eval_interval == 0
                )
                step_need_save = (
                    self.configs.checkpoint_interval > 0
                    and (self.step + 1) % self.configs.checkpoint_interval == 0
                )

                is_last_step &= is_update_step
                step_need_log &= is_update_step
                step_need_eval &= is_update_step
                step_need_save &= is_update_step

                batch = to_device(batch, self.device)
                self.progress_bar()
                self.train_step(batch)
                if use_ema:
                    self.ema_wrapper.update()
                if step_need_log or is_last_step:
                    metrics = self.train_metric_wrapper.calc()
                    self.print(f"Step {self.step} train: {metrics}")
                    last_lr = self.lr_scheduler.get_last_lr()[0]
                    if DIST_WRAPPER.rank == 0:
                        if self.configs.use_wandb:
                            wandb.log(
                                {"train/lr": last_lr},
                                step=self.step,
                            )
                        self.print(f"Step {self.step}, lr: {last_lr}")
                    if self.configs.use_wandb and DIST_WRAPPER.rank == 0:
                        wandb.log(metrics, step=self.step)

                if step_need_save or is_last_step:
                    self.save_checkpoint()
                    if use_ema:
                        self.ema_wrapper.apply_shadow()
                        self.save_checkpoint(
                            ema_suffix=f"_ema_{self.ema_wrapper.decay}"
                        )
                        self.ema_wrapper.restore()

                if step_need_eval or is_last_step:
                    self.evaluate()
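
                # One logical training `step` = one optimizer update, i.e.
                # `iters_to_accumulate` micro-batches (tracked by `global_step`).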
                self.global_step += 1
                if self.global_step % self.iters_to_accumulate == 0:
                    self.step += 1
                if self.step >= self.configs.max_steps:
                    self.print(f"Finish training after {self.step} steps")
                    break
            if self.step >= self.configs.max_steps:
                break


def main():
    LOG_FORMAT = "%(asctime)s,%(msecs)-3d %(levelname)-8s [%(filename)s:%(lineno)s %(funcName)s] %(message)s"
    logging.basicConfig(
        format=LOG_FORMAT,
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
        filemode="w",
    )
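    # The DS4Sci Evoformer attention kernel is enabled only when this environment
    # variable is set to the literal string "true".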
configs_base["use_deepspeed_evo_attention"] = ( |
|
os.environ.get("USE_DEEPSPEED_EVO_ATTTENTION", False) == "true" |
|
) |
|
configs = {**configs_base, **{"data": data_configs}} |
|
configs = parse_configs( |
|
configs, |
|
parse_sys_args(), |
|
) |
|
|
|
print(configs.run_name) |
|
print(configs) |
|
trainer = AF3Trainer(configs) |
|
trainer.run() |
|
|
|
|
|
if __name__ == "__main__": |
|
main() |
|
|