# Copyright 2024 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from typing import Any, Optional, Callable
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from protenix.model import sample_confidence
from protenix.model.generator import (
InferenceNoiseScheduler,
TrainingNoiseSampler,
sample_diffusion,
sample_diffusion_training,
structure_predictor,
watermark_decoder,
)
from protenix.model.utils import simple_merge_dict_list, random_sample_watermark, centre_random_augmentation
from protenix.openfold_local.model.primitives import LayerNorm
from protenix.utils.logger import get_logger
from protenix.utils.permutation.permutation import SymmetricPermutation
from protenix.utils.torch_utils import autocasting_disable_decorator
from .modules.confidence import ConfidenceHead
from .modules.diffusion import DiffusionModule, Struct_decoder, Struct_encoder
from .modules.embedders import InputFeatureEmbedder, RelativePositionEncoding
from .modules.head import DistogramHead
from .modules.pairformer import MSAModule, PairformerStack, TemplateEmbedder
from .modules.primitives import LinearNoBias
logger = get_logger(__name__)
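# Note: several calls below follow the pattern
#     autocasting_disable_decorator(flag)(fn)(*args, **kwargs)
# where the wrapped fn runs with torch autocast disabled when `flag` is True,
# so numerically sensitive steps (diffusion sampling, confidence head) can run
# in full precision. A minimal sketch of the idea (illustrative only; the real
# implementation lives in protenix.utils.torch_utils and may differ):
#
#     def autocasting_disable_decorator(disable_casting):
#         def decorator(fn):
#             def wrapper(*args, **kwargs):
#                 if disable_casting:
#                     with torch.autocast("cuda", enabled=False):
#                         return fn(*args, **kwargs)
#                 return fn(*args, **kwargs)
#             return wrapper
#         return decorator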
class Protenix(nn.Module):
"""
    Implements Algorithm 1 [Main Inference/Train Loop] in AF3, extended with the FoldMark watermark encoder/decoder modules
"""
def __init__(self, configs) -> None:
super(Protenix, self).__init__()
self.configs = configs
# Some constants
self.N_cycle = self.configs.model.N_cycle
self.N_model_seed = self.configs.model.N_model_seed
self.train_confidence_only = configs.train_confidence_only
if self.train_confidence_only: # the final finetune stage
assert configs.loss.weight.alpha_diffusion == 0.0
assert configs.loss.weight.alpha_distogram == 0.0
# Diffusion scheduler
self.train_noise_sampler = TrainingNoiseSampler(**configs.train_noise_sampler)
self.inference_noise_scheduler = InferenceNoiseScheduler(
**configs.inference_noise_scheduler
)
self.diffusion_batch_size = self.configs.diffusion_batch_size
# Model
self.input_embedder = InputFeatureEmbedder(**configs.model.input_embedder)
self.relative_position_encoding = RelativePositionEncoding(
**configs.model.relative_position_encoding
)
self.template_embedder = TemplateEmbedder(**configs.model.template_embedder)
self.msa_module = MSAModule(
**configs.model.msa_module,
msa_configs=configs.data.get("msa", {}),
)
self.pairformer_stack = PairformerStack(**configs.model.pairformer)
self.diffusion_module = DiffusionModule(**configs.model.diffusion_module)
self.distogram_head = DistogramHead(**configs.model.distogram_head)
self.confidence_head = ConfidenceHead(**configs.model.confidence_head)
        self.pairformer_stack_decoder = PairformerStack(**configs.model.pairformer_decoder)  # pairformer stack for decoding, with fewer blocks
self.diffusion_module_encoder = Struct_encoder(**configs.model.diffusion_module_encoder_decoder)
self.diffusion_module_decoder = Struct_decoder(**configs.model.diffusion_module_encoder_decoder)
self.code_extractor = nn.Linear(configs.model.diffusion_module.c_token, 1)
self.gating_layer = nn.Linear(configs.model.diffusion_module.c_token, 1)
self.c_s, self.c_z, self.c_s_inputs, self.watermark = (
configs.c_s,
configs.c_z,
configs.c_s_inputs,
configs.watermark,
)
self.linear_no_bias_sinit = LinearNoBias(
in_features=self.c_s_inputs, out_features=self.c_s
)
self.linear_no_bias_zinit1 = LinearNoBias(
in_features=self.c_s, out_features=self.c_z
)
self.linear_no_bias_zinit2 = LinearNoBias(
in_features=self.c_s, out_features=self.c_z
)
self.linear_no_bias_token_bond = LinearNoBias(
in_features=1, out_features=self.c_z
)
self.linear_no_bias_z_cycle = LinearNoBias(
in_features=self.c_z, out_features=self.c_z
)
self.linear_no_bias_s = LinearNoBias(
in_features=self.c_s, out_features=self.c_s
)
self.layernorm_z_cycle = LayerNorm(self.c_z)
self.layernorm_s = LayerNorm(self.c_s)
# Zero init the recycling layer
nn.init.zeros_(self.linear_no_bias_z_cycle.weight)
nn.init.zeros_(self.linear_no_bias_s.weight)
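    # Shape sketch for the watermark readout heads defined above (illustrative):
    #     a_token:   [..., N_token, c_token] decoder token features
    #     scores    = gating_layer(a_token)          -> [..., N_token, 1]
    #     weights   = softmax(scores, dim=-2)        -> sums to 1 over tokens
    #     extracted = code_extractor(a_token)        -> [..., N_token, 1]
    #     watermark = (extracted * weights).sum(-2)  -> [..., 1] pooled code logit
    # i.e. a learned attention pooling over tokens; see _main_detection_loop
    # and ED_train_loop for the actual usage.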
def get_pairformer_output(
self,
pairformer_stack: Callable,
input_feature_dict: dict[str, Any],
N_cycle: int,
inplace_safe: bool = False,
chunk_size: Optional[int] = None,
use_msa: Optional[bool] = True,
) -> tuple[torch.Tensor, ...]:
"""
The forward pass from the input to pairformer output
Args:
input_feature_dict (dict[str, Any]): input features
N_cycle (int): number of cycles
inplace_safe (bool): Whether it is safe to use inplace operations. Defaults to False.
chunk_size (Optional[int]): Chunk size for memory-efficient operations. Defaults to None.
Returns:
Tuple[torch.Tensor, ...]: s_inputs, s, z
"""
        N_token = input_feature_dict["residue_index"].shape[-1]
        # The DeepSpeed evo attention kernel does not support N_token <= 16
        deepspeed_evo_attention_condition_satisfy = N_token > 16
if self.train_confidence_only:
self.input_embedder.eval()
self.template_embedder.eval()
self.msa_module.eval()
self.pairformer_stack.eval()
# Line 1-5
s_inputs = self.input_embedder(
input_feature_dict, inplace_safe=False, chunk_size=chunk_size
) # [..., N_token, 449]
s_init = self.linear_no_bias_sinit(s_inputs) # [..., N_token, c_s]
z_init = (
self.linear_no_bias_zinit1(s_init)[..., None, :]
+ self.linear_no_bias_zinit2(s_init)[..., None, :, :]
) # [..., N_token, N_token, c_z]
if inplace_safe:
z_init += self.relative_position_encoding(input_feature_dict)
z_init += self.linear_no_bias_token_bond(
input_feature_dict["token_bonds"].unsqueeze(dim=-1)
)
else:
z_init = z_init + self.relative_position_encoding(input_feature_dict)
z_init = z_init + self.linear_no_bias_token_bond(
input_feature_dict["token_bonds"].unsqueeze(dim=-1)
)
# Line 6
z = torch.zeros_like(z_init)
s = torch.zeros_like(s_init)
# Line 7-13 recycling
for cycle_no in range(N_cycle):
with torch.set_grad_enabled(
self.training
and (not self.train_confidence_only)
and cycle_no == (N_cycle - 1)
):
z = z_init + self.linear_no_bias_z_cycle(self.layernorm_z_cycle(z))
if inplace_safe:
if self.template_embedder.n_blocks > 0:
z += self.template_embedder(
input_feature_dict,
z,
use_memory_efficient_kernel=self.configs.use_memory_efficient_kernel,
use_deepspeed_evo_attention=self.configs.use_deepspeed_evo_attention
and deepspeed_evo_attention_condition_satisfy,
use_lma=self.configs.use_lma,
inplace_safe=inplace_safe,
chunk_size=chunk_size,
)
if use_msa:
z = self.msa_module(
input_feature_dict,
z,
s_inputs,
pair_mask=None,
use_memory_efficient_kernel=self.configs.use_memory_efficient_kernel,
use_deepspeed_evo_attention=self.configs.use_deepspeed_evo_attention
and deepspeed_evo_attention_condition_satisfy,
use_lma=self.configs.use_lma,
inplace_safe=inplace_safe,
chunk_size=chunk_size,
)
else:
if self.template_embedder.n_blocks > 0:
z = z + self.template_embedder(
input_feature_dict,
z,
use_memory_efficient_kernel=self.configs.use_memory_efficient_kernel,
use_deepspeed_evo_attention=self.configs.use_deepspeed_evo_attention
and deepspeed_evo_attention_condition_satisfy,
use_lma=self.configs.use_lma,
inplace_safe=inplace_safe,
chunk_size=chunk_size,
)
if use_msa:
z = self.msa_module(
input_feature_dict,
z,
s_inputs,
pair_mask=None,
use_memory_efficient_kernel=self.configs.use_memory_efficient_kernel,
use_deepspeed_evo_attention=self.configs.use_deepspeed_evo_attention
and deepspeed_evo_attention_condition_satisfy,
use_lma=self.configs.use_lma,
inplace_safe=inplace_safe,
chunk_size=chunk_size,
)
s = s_init + self.linear_no_bias_s(self.layernorm_s(s))
                s, z = pairformer_stack(  # run the stack passed in (full stack or decoder stack)
s,
z,
pair_mask=None,
use_memory_efficient_kernel=self.configs.use_memory_efficient_kernel,
use_deepspeed_evo_attention=self.configs.use_deepspeed_evo_attention
and deepspeed_evo_attention_condition_satisfy,
use_lma=self.configs.use_lma,
inplace_safe=inplace_safe,
chunk_size=chunk_size,
)
if self.train_confidence_only:
self.input_embedder.train()
self.template_embedder.train()
self.msa_module.train()
self.pairformer_stack.train()
return s_inputs, s, z
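    # Recycling note: torch.set_grad_enabled above keeps the autograd tape only
    # for the final cycle, so activation memory stays flat in N_cycle. A minimal
    # standalone sketch of the same pattern (illustrative, not the model code):
    #
    #     z = torch.zeros_like(z_init)
    #     for cycle_no in range(N_cycle):
    #         with torch.set_grad_enabled(training and cycle_no == N_cycle - 1):
    #             z = z_init + refine(z)  # earlier cycles leave no gradient tape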
def sample_diffusion(self, **kwargs) -> torch.Tensor:
"""
Samples diffusion process based on the provided configurations.
Returns:
torch.Tensor: The result of the diffusion sampling process.
"""
_configs = {
key: self.configs.sample_diffusion.get(key)
for key in [
"gamma0",
"gamma_min",
"noise_scale_lambda",
"step_scale_eta",
]
}
_configs.update(
{
"attn_chunk_size": (
self.configs.infer_setting.chunk_size if not self.training else None
),
"diffusion_chunk_size": (
self.configs.infer_setting.sample_diffusion_chunk_size
if not self.training
else None
),
}
)
return autocasting_disable_decorator(self.configs.skip_amp.sample_diffusion)(
sample_diffusion
)(**_configs, **kwargs)
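    # Typical call, mirroring the usage in _main_inference_loop below
    # (illustrative):
    #
    #     noise_schedule = self.inference_noise_scheduler(
    #         N_step=N_step, device=s_inputs.device, dtype=s_inputs.dtype
    #     )
    #     coords = self.sample_diffusion(
    #         denoise_net=self.diffusion_module,
    #         input_feature_dict=input_feature_dict,
    #         s_inputs=s_inputs, s_trunk=s, z_trunk=z,
    #         N_sample=N_sample, noise_schedule=noise_schedule,
    #     )  # -> [..., N_sample, N_atom, 3]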
def run_confidence_head(self, *args, **kwargs):
"""
Runs the confidence head with optional automatic mixed precision (AMP) disabled.
Returns:
Any: The output of the confidence head.
"""
return autocasting_disable_decorator(self.configs.skip_amp.confidence_head)(
self.confidence_head
)(*args, **kwargs)
def main_detection_loop(
self,
input_feature_dict: dict[str, Any],
label_dict: dict[str, Any],
N_cycle: int,
mode: str,
inplace_safe: bool = True,
chunk_size: Optional[int] = 4,
N_model_seed: int = 1,
symmetric_permutation: SymmetricPermutation = None,
    ) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any], dict[str, Any]]:
"""
        Main watermark-detection loop (multiple model seeds): recovers the embedded code from the given coordinates.
Args:
input_feature_dict (dict[str, Any]): Input features dictionary.
label_dict (dict[str, Any]): Label dictionary.
N_cycle (int): Number of cycles.
mode (str): Mode of operation (e.g., 'inference').
inplace_safe (bool): Whether to use inplace operations safely. Defaults to True.
chunk_size (Optional[int]): Chunk size for memory-efficient operations. Defaults to 4.
N_model_seed (int): Number of model seeds. Defaults to 1.
symmetric_permutation (SymmetricPermutation): Symmetric permutation object. Defaults to None.
        Returns:
            tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any], dict[str, Any]]: Prediction, label, log, and time dictionaries.
"""
pred_dicts = []
log_dicts = []
time_trackers = []
        N_sample = self.configs.sample_diffusion["N_sample"]
        # label_dict is expected to carry the ground-truth watermark bits; uncomment
        # the next line to force an all-ones code instead.
        # label_dict['watermark'] = torch.ones(1, 1).to(input_feature_dict["restype"].device)
for _ in range(N_model_seed):
pred_dict, log_dict, time_tracker = self._main_detection_loop(
input_feature_dict=input_feature_dict,
label_dict=label_dict,
N_cycle=N_cycle,
mode=mode,
inplace_safe=inplace_safe,
chunk_size=chunk_size,
symmetric_permutation=symmetric_permutation,
)
pred_dicts.append(pred_dict)
log_dicts.append(log_dict)
time_trackers.append(time_tracker)
# Combine outputs of multiple models
def _cat(dict_list, key):
return torch.cat([x[key] for x in dict_list], dim=0)
def _list_join(dict_list, key):
return sum([x[key] for x in dict_list], [])
all_pred_dict = {
"watermark": _cat(pred_dicts, "watermark")
}
all_log_dict = simple_merge_dict_list(log_dicts)
all_time_dict = simple_merge_dict_list(time_trackers)
return all_pred_dict, label_dict, all_log_dict, all_time_dict
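    # Example: with N_model_seed=2 and a per-seed watermark logit of shape
    # [1, 1], _cat stacks along dim 0 and all_pred_dict["watermark"] is [2, 1].
    # A detection decision can then be thresholded, e.g.
    #     detected = torch.sigmoid(all_pred_dict["watermark"]) > 0.5
    # (illustrative; no threshold is defined in this file).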
def _main_detection_loop(
self,
input_feature_dict: dict[str, Any],
label_dict: dict[str, Any],
N_cycle: int,
mode: str,
inplace_safe: bool = True,
chunk_size: Optional[int] = 4,
symmetric_permutation: SymmetricPermutation = None,
) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
"""
        Watermark-detection loop (single model seed): decodes the code from label_dict["coordinate"].
Returns:
tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]: Prediction, log, and time dictionaries.
"""
step_st = time.time()
N_token = input_feature_dict["residue_index"].shape[-1]
        deepspeed_evo_attention_condition_satisfy = N_token > 16
log_dict = {}
pred_dict = {}
time_tracker = {}
# Watermark detection
s_inputs_clean, s_clean, z_clean = self.get_pairformer_output(
pairformer_stack=self.pairformer_stack_decoder,
input_feature_dict=input_feature_dict,
N_cycle=N_cycle,
inplace_safe=inplace_safe,
chunk_size=chunk_size,
use_msa=False,
)
if mode == "inference":
keys_to_delete = []
for key in input_feature_dict.keys():
if "template_" in key or key in [
"msa",
"has_deletion",
"deletion_value",
"profile",
"deletion_mean",
"token_bonds",
]:
keys_to_delete.append(key)
for key in keys_to_delete:
del input_feature_dict[key]
torch.cuda.empty_cache()
step_trunk = time.time()
time_tracker.update({"pairformer": step_trunk - step_st})
        # Watermark decoding: run the structure decoder on the given coordinates
        # to recover per-token features a_token, [..., N_token, c_token]
_, a_token, x_noise_level = autocasting_disable_decorator(
self.configs.skip_amp.sample_diffusion_training
)(watermark_decoder)(
coordinate=label_dict['coordinate'].unsqueeze(0),
denoise_net=self.diffusion_module_decoder,
input_feature_dict=input_feature_dict,
s_inputs=s_inputs_clean,
s_trunk=s_clean,
z_trunk=z_clean,
N_sample=1,
diffusion_chunk_size=self.configs.diffusion_chunk_size,
)
        scores = self.gating_layer(a_token)  # [..., N_token, 1] gate logits
        weights = F.softmax(scores, dim=-2)  # normalize over tokens
        extracted = self.code_extractor(a_token)  # [..., N_token, 1] per-token bit logits
        watermark = (extracted * weights).sum(dim=-2)  # attention-pooled code logit
pred_dict.update(
{
"watermark": watermark,
}
)
time_tracker.update({"model_forward": time.time() - step_st})
return pred_dict, log_dict, time_tracker
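    # A plausible detection objective for the pooled logit above, assuming the
    # label is a 0/1 bit tensor of matching shape (hypothetical sketch; the
    # actual loss is defined outside this file):
    #
    #     loss = F.binary_cross_entropy_with_logits(
    #         pred_dict["watermark"], label_dict["watermark"].float()
    #     )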
def main_inference_loop(
self,
input_feature_dict: dict[str, Any],
label_dict: dict[str, Any],
N_cycle: int,
mode: str,
inplace_safe: bool = True,
chunk_size: Optional[int] = 4,
N_model_seed: int = 1,
symmetric_permutation: SymmetricPermutation = None,
        watermark: bool = False,
    ) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any], dict[str, Any]]:
"""
Main inference loop (multiple model seeds) for the Alphafold3 model.
Args:
input_feature_dict (dict[str, Any]): Input features dictionary.
label_dict (dict[str, Any]): Label dictionary.
N_cycle (int): Number of cycles.
mode (str): Mode of operation (e.g., 'inference').
inplace_safe (bool): Whether to use inplace operations safely. Defaults to True.
chunk_size (Optional[int]): Chunk size for memory-efficient operations. Defaults to 4.
N_model_seed (int): Number of model seeds. Defaults to 1.
symmetric_permutation (SymmetricPermutation): Symmetric permutation object. Defaults to None.
        Returns:
            tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any], dict[str, Any]]: Prediction, label, log, and time dictionaries.
"""
pred_dicts = []
log_dicts = []
time_trackers = []
        N_sample = self.configs.sample_diffusion["N_sample"]
        if label_dict is None:
            # No code is supplied at inference time: use an all-ones watermark
            # as the target code for every diffusion sample.
            label_dict = {}
            label_dict["watermark"] = torch.ones(N_sample, 1).to(input_feature_dict["restype"].device)
for _ in range(N_model_seed):
pred_dict, log_dict, time_tracker = self._main_inference_loop(
input_feature_dict=input_feature_dict,
label_dict=label_dict,
N_cycle=N_cycle,
mode=mode,
inplace_safe=inplace_safe,
chunk_size=chunk_size,
symmetric_permutation=symmetric_permutation,
watermark=watermark
)
pred_dicts.append(pred_dict)
log_dicts.append(log_dict)
time_trackers.append(time_tracker)
# Combine outputs of multiple models
def _cat(dict_list, key):
return torch.cat([x[key] for x in dict_list], dim=0)
def _list_join(dict_list, key):
return sum([x[key] for x in dict_list], [])
all_pred_dict = {
"coordinate": _cat(pred_dicts, "coordinate"),
"summary_confidence": _list_join(pred_dicts, "summary_confidence"),
"full_data": _list_join(pred_dicts, "full_data"),
"plddt": _cat(pred_dicts, "plddt"),
"pae": _cat(pred_dicts, "pae"),
"pde": _cat(pred_dicts, "pde"),
"resolved": _cat(pred_dicts, "resolved"),
#"watermark": _cat(pred_dicts, "watermark")
}
if "coordinate_orig" in pred_dicts[0]:
all_pred_dict['coordinate_orig'] = _cat(pred_dicts, "coordinate_orig")
all_log_dict = simple_merge_dict_list(log_dicts)
all_time_dict = simple_merge_dict_list(time_trackers)
return all_pred_dict, label_dict, all_log_dict, all_time_dict
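    # With N_model_seed > 1, per-seed tensors ("coordinate", "plddt", ...) are
    # concatenated along dim 0 and per-sample lists ("summary_confidence",
    # "full_data") are joined end to end, so downstream code sees one flat
    # batch across seeds.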
def _main_inference_loop(
self,
input_feature_dict: dict[str, Any],
label_dict: dict[str, Any],
N_cycle: int,
mode: str,
inplace_safe: bool = True,
chunk_size: Optional[int] = 4,
symmetric_permutation: SymmetricPermutation = None,
        watermark: bool = False,
    ) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
"""
Main inference loop (single model seed) for the Alphafold3 model.
Returns:
tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]: Prediction, log, and time dictionaries.
"""
step_st = time.time()
N_token = input_feature_dict["residue_index"].shape[-1]
        deepspeed_evo_attention_condition_satisfy = N_token > 16
log_dict = {}
pred_dict = {}
time_tracker = {}
s_inputs, s, z = self.get_pairformer_output(
            pairformer_stack=self.pairformer_stack,
input_feature_dict=input_feature_dict,
N_cycle=N_cycle,
inplace_safe=inplace_safe,
chunk_size=chunk_size,
)
if mode == "inference":
keys_to_delete = []
for key in input_feature_dict.keys():
if "template_" in key or key in [
"msa",
"has_deletion",
"deletion_value",
"profile",
"deletion_mean",
"token_bonds",
]:
keys_to_delete.append(key)
for key in keys_to_delete:
del input_feature_dict[key]
torch.cuda.empty_cache()
step_trunk = time.time()
time_tracker.update({"pairformer": step_trunk - step_st})
# Sample diffusion
# [..., N_sample, N_atom, 3]
N_sample = self.configs.sample_diffusion["N_sample"]
N_step = self.configs.sample_diffusion["N_step"]
noise_schedule = self.inference_noise_scheduler(
N_step=N_step, device=s_inputs.device, dtype=s_inputs.dtype
)
pred_dict["coordinate"] = self.sample_diffusion(
denoise_net=self.diffusion_module,
input_feature_dict=input_feature_dict,
s_inputs=s_inputs,
s_trunk=s,
z_trunk=z,
N_sample=N_sample,
noise_schedule=noise_schedule,
inplace_safe=inplace_safe,
)
step_diffusion = time.time()
time_tracker.update({"diffusion": step_diffusion - step_trunk})
if mode == "inference" and N_token > 2000:
torch.cuda.empty_cache()
        # Distogram logits: keep only contact_probs, to reduce the stored dimension
pred_dict["contact_probs"] = sample_confidence.compute_contact_prob(
distogram_logits=self.distogram_head(z),
**sample_confidence.get_bin_params(self.configs.loss.distogram),
) # [N_token, N_token]
        # Watermark embedding: pass the predicted coordinates through the
        # structure encoder so the target code is written into the structure
        if watermark:
x_denoised, x_noise_level = autocasting_disable_decorator(
self.configs.skip_amp.sample_diffusion_training
)(structure_predictor)(
coordinate=pred_dict["coordinate"],
denoise_net=self.diffusion_module_encoder,
label_dict=label_dict,
input_feature_dict=input_feature_dict,
s_inputs=s_inputs,
s_trunk=s,
z_trunk=z,
N_sample=N_sample,
diffusion_chunk_size=self.configs.diffusion_chunk_size,
)
pred_dict["coordinate_orig"] = pred_dict["coordinate"]
pred_dict["coordinate"] = x_denoised
# Confidence logits
(
pred_dict["plddt"],
pred_dict["pae"],
pred_dict["pde"],
pred_dict["resolved"],
) = self.run_confidence_head(
input_feature_dict=input_feature_dict,
s_inputs=s_inputs,
s_trunk=s,
z_trunk=z,
pair_mask=None,
x_pred_coords=pred_dict["coordinate"],
use_memory_efficient_kernel=self.configs.use_memory_efficient_kernel,
use_deepspeed_evo_attention=self.configs.use_deepspeed_evo_attention
and deepspeed_evo_attention_condition_satisfy,
use_lma=self.configs.use_lma,
inplace_safe=inplace_safe,
chunk_size=chunk_size,
)
step_confidence = time.time()
time_tracker.update({"confidence": step_confidence - step_diffusion})
time_tracker.update({"model_forward": time.time() - step_st})
# Permutation: when label is given, permute coordinates and other heads
if label_dict is not None and symmetric_permutation is not None:
pred_dict, log_dict = symmetric_permutation.permute_inference_pred_dict(
input_feature_dict=input_feature_dict,
pred_dict=pred_dict,
label_dict=label_dict,
permute_by_pocket=("pocket_mask" in label_dict)
and ("interested_ligand_mask" in label_dict),
)
last_step_seconds = step_confidence
time_tracker.update({"permutation": time.time() - last_step_seconds})
# Summary Confidence & Full Data
# Computed after coordinates and logits are permuted
if label_dict is None:
interested_atom_mask = None
else:
interested_atom_mask = label_dict.get("interested_ligand_mask", None)
pred_dict["summary_confidence"], pred_dict["full_data"] = (
sample_confidence.compute_full_data_and_summary(
configs=self.configs,
pae_logits=pred_dict["pae"],
plddt_logits=pred_dict["plddt"],
pde_logits=pred_dict["pde"],
contact_probs=pred_dict.get(
"per_sample_contact_probs", pred_dict["contact_probs"]
),
token_asym_id=input_feature_dict["asym_id"],
token_has_frame=input_feature_dict["has_frame"],
atom_coordinate=pred_dict["coordinate"],
atom_to_token_idx=input_feature_dict["atom_to_token_idx"],
atom_is_polymer=1 - input_feature_dict["is_ligand"],
N_recycle=N_cycle,
interested_atom_mask=interested_atom_mask,
return_full_data=True,
mol_id=(input_feature_dict["mol_id"] if mode != "inference" else None),
elements_one_hot=(
input_feature_dict["ref_element"] if mode != "inference" else None
),
)
)
return pred_dict, log_dict, time_tracker
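    # When watermark=True, both the watermarked coordinates ("coordinate") and
    # the unwatermarked ones ("coordinate_orig") are returned, so the
    # structural cost of embedding can be measured. A hedged sketch (assumes
    # matched atom order; no superposition is applied):
    #
    #     delta = pred_dict["coordinate"] - pred_dict["coordinate_orig"]
    #     rmsd = delta.pow(2).sum(dim=-1).mean(dim=-1).sqrt()  # [..., N_sample]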
def main_train_loop(
self,
input_feature_dict: dict[str, Any],
label_full_dict: dict[str, Any],
label_dict: dict,
N_cycle: int,
symmetric_permutation: SymmetricPermutation,
inplace_safe: bool = False,
chunk_size: Optional[int] = None,
) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
"""
Main training loop for the Alphafold3 model.
Args:
input_feature_dict (dict[str, Any]): Input features dictionary.
label_full_dict (dict[str, Any]): Full label dictionary (uncropped).
label_dict (dict): Label dictionary (cropped).
N_cycle (int): Number of cycles.
symmetric_permutation (SymmetricPermutation): Symmetric permutation object.
inplace_safe (bool): Whether to use inplace operations safely. Defaults to False.
chunk_size (Optional[int]): Chunk size for memory-efficient operations. Defaults to None.
Returns:
tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
Prediction, updated label, and log dictionaries.
"""
N_token = input_feature_dict["residue_index"].shape[-1]
        deepspeed_evo_attention_condition_satisfy = N_token > 16
        s_inputs, s, z = self.get_pairformer_output(
            pairformer_stack=self.pairformer_stack,
            input_feature_dict=input_feature_dict,
N_cycle=N_cycle,
inplace_safe=inplace_safe,
chunk_size=chunk_size,
)
log_dict = {}
pred_dict = {}
# Mini-rollout: used for confidence and label permutation
with torch.no_grad():
# [..., 1, N_atom, 3]
N_sample_mini_rollout = self.configs.sample_diffusion[
"N_sample_mini_rollout"
] # =1
N_step_mini_rollout = self.configs.sample_diffusion["N_step_mini_rollout"]
coordinate_mini = self.sample_diffusion(
denoise_net=self.diffusion_module,
input_feature_dict=input_feature_dict,
s_inputs=s_inputs.detach(),
s_trunk=s.detach(),
z_trunk=z.detach(),
N_sample=N_sample_mini_rollout,
noise_schedule=self.inference_noise_scheduler(
N_step=N_step_mini_rollout,
device=s_inputs.device,
dtype=s_inputs.dtype,
),
)
coordinate_mini.detach_()
pred_dict["coordinate_mini"] = coordinate_mini
# Permute ground truth to match mini-rollout prediction
label_dict, perm_log_dict = (
symmetric_permutation.permute_label_to_match_mini_rollout(
coordinate_mini,
input_feature_dict,
label_dict,
label_full_dict,
)
)
log_dict.update(perm_log_dict)
# Confidence: use mini-rollout prediction, and detach token embeddings
plddt_pred, pae_pred, pde_pred, resolved_pred = self.run_confidence_head(
input_feature_dict=input_feature_dict,
s_inputs=s_inputs,
s_trunk=s,
z_trunk=z,
pair_mask=None,
x_pred_coords=coordinate_mini,
use_memory_efficient_kernel=self.configs.use_memory_efficient_kernel,
use_deepspeed_evo_attention=self.configs.use_deepspeed_evo_attention
and deepspeed_evo_attention_condition_satisfy,
use_lma=self.configs.use_lma,
inplace_safe=inplace_safe,
chunk_size=chunk_size,
)
pred_dict.update(
{
"plddt": plddt_pred,
"pae": pae_pred,
"pde": pde_pred,
"resolved": resolved_pred,
}
)
if self.train_confidence_only:
# Skip diffusion loss and distogram loss. Return now.
return pred_dict, label_dict, log_dict
# Denoising: use permuted coords to generate noisy samples and perform denoising
# x_denoised: [..., N_sample, N_atom, 3]
# x_noise_level: [..., N_sample]
N_sample = self.diffusion_batch_size
_, x_denoised, x_noise_level = autocasting_disable_decorator(
self.configs.skip_amp.sample_diffusion_training
)(sample_diffusion_training)(
noise_sampler=self.train_noise_sampler,
denoise_net=self.diffusion_module,
label_dict=label_dict,
input_feature_dict=input_feature_dict,
s_inputs=s_inputs,
s_trunk=s,
z_trunk=z,
N_sample=N_sample,
diffusion_chunk_size=self.configs.diffusion_chunk_size,
)
pred_dict.update(
{
"distogram": self.distogram_head(z),
# [..., N_sample=48, N_atom, 3]: diffusion loss
"coordinate": x_denoised,
"noise_level": x_noise_level,
}
)
# Permute symmetric atom/chain in each sample to match true structure
# Note: currently chains cannot be permuted since label is cropped
pred_dict, perm_log_dict, _, _ = (
symmetric_permutation.permute_diffusion_sample_to_match_label(
input_feature_dict, pred_dict, label_dict, stage="train"
)
)
log_dict.update(perm_log_dict)
return pred_dict, label_dict, log_dict
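    # Note: coordinate_mini is produced under torch.no_grad() and detached, so
    # the confidence head is trained against a frozen mini-rollout structure
    # while the trunk embeddings (s_inputs, s, z) still carry gradients from
    # the last recycling iteration.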
def ED_train_loop(
self,
input_feature_dict: dict[str, Any],
label_full_dict: dict[str, Any],
label_dict: dict,
N_cycle: int,
symmetric_permutation: SymmetricPermutation,
inplace_safe: bool = False,
chunk_size: Optional[int] = None,
) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
"""
        Encoder-decoder (watermark) training loop: embeds a code into the predicted structure and trains the decoder to recover it.
Args:
input_feature_dict (dict[str, Any]): Input features dictionary.
label_full_dict (dict[str, Any]): Full label dictionary (uncropped).
label_dict (dict): Label dictionary (cropped).
N_cycle (int): Number of cycles.
symmetric_permutation (SymmetricPermutation): Symmetric permutation object.
inplace_safe (bool): Whether to use inplace operations safely. Defaults to False.
chunk_size (Optional[int]): Chunk size for memory-efficient operations. Defaults to None.
Returns:
tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
Prediction, updated label, and log dictionaries.
"""
N_sample = self.diffusion_batch_size
N_token = input_feature_dict["residue_index"].shape[-1]
        deepspeed_evo_attention_condition_satisfy = N_token > 16
with torch.no_grad():
s_inputs, s, z = self.get_pairformer_output(
                pairformer_stack=self.pairformer_stack,
input_feature_dict=input_feature_dict,
N_cycle=N_cycle,
inplace_safe=inplace_safe,
chunk_size=chunk_size,
)
log_dict = {}
pred_dict = {}
x_gt_augment = centre_random_augmentation(
x_input_coords=label_dict["coordinate"],
N_sample=N_sample,
mask=label_dict["coordinate_mask"],
centre_only=False,
).to(
label_dict["coordinate"].dtype
) # [..., N_sample, N_atom, 3]
        label_dict["coordinate_augment"] = x_gt_augment
x_denoised, x_noise_level = autocasting_disable_decorator(
self.configs.skip_amp.sample_diffusion_training
)(structure_predictor)(
coordinate=x_gt_augment,
denoise_net=self.diffusion_module_encoder,
label_dict=label_dict,
input_feature_dict=input_feature_dict,
s_inputs=s_inputs,
s_trunk=s,
z_trunk=z,
N_sample=N_sample,
diffusion_chunk_size=self.configs.diffusion_chunk_size,
)
pred_dict.update(
{
"distogram": self.distogram_head(z),
# [..., N_sample, N_atom, 3]: diffusion loss
"coordinate": x_denoised,
"noise_level": x_noise_level,
}
)
        # Randomly choose watermarked vs. clean coordinates per sample; the
        # returned bits become the decoder's target labels
        x_denoised, watermark_label = random_sample_watermark(x_denoised, x_gt_augment, N_sample)
        label_dict["watermark"] = watermark_label[..., None]
s_inputs_clean, s_clean, z_clean = self.get_pairformer_output(
pairformer_stack=self.pairformer_stack_decoder,
input_feature_dict=input_feature_dict,
N_cycle=N_cycle,
inplace_safe=inplace_safe,
chunk_size=chunk_size,
use_msa=False
)
_, a_token, x_noise_level = autocasting_disable_decorator(
self.configs.skip_amp.sample_diffusion_training
)(watermark_decoder)(
coordinate=x_denoised,
denoise_net=self.diffusion_module_decoder,
input_feature_dict=input_feature_dict,
s_inputs=s_inputs_clean,
s_trunk=s_clean,
z_trunk=z_clean,
N_sample=N_sample,
diffusion_chunk_size=self.configs.diffusion_chunk_size,
)
        scores = self.gating_layer(a_token)  # [..., N_token, 1] gate logits
        weights = F.softmax(scores, dim=-2)  # normalize over tokens
        extracted = self.code_extractor(a_token)  # [..., N_token, 1] per-token bit logits
        watermark = (extracted * weights).sum(dim=-2)  # attention-pooled code logit
        # simple mean-pooling alternative:
        # watermark = self.code_extractor(a_token).mean(dim=-2)
pred_dict.update(
{
"watermark": watermark,
}
)
return pred_dict, label_dict, log_dict
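    # random_sample_watermark is imported from protenix.model.utils; a hedged
    # sketch of the behavior implied by its usage above (a per-sample coin flip
    # between the watermarked prediction and the clean augmented ground truth,
    # returning the chosen coordinates plus 0/1 bit labels):
    #
    #     def random_sample_watermark(x_wm, x_clean, N_sample):
    #         bits = torch.randint(0, 2, (N_sample,), device=x_wm.device)
    #         mix = torch.where(bits[..., None, None].bool(), x_wm, x_clean)
    #         return mix, bits.float()
    #
    # The real implementation may differ; see protenix/model/utils.py.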
def forward(
self,
input_feature_dict: dict[str, Any],
label_full_dict: dict[str, Any],
label_dict: dict[str, Any],
mode: str = "inference",
current_step: Optional[int] = None,
symmetric_permutation: SymmetricPermutation = None,
detect: Optional[bool] = False,
watermark: Optional[bool] = False,
) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
"""
Forward pass of the Alphafold3 model.
Args:
input_feature_dict (dict[str, Any]): Input features dictionary.
label_full_dict (dict[str, Any]): Full label dictionary (uncropped).
label_dict (dict[str, Any]): Label dictionary (cropped).
mode (str): Mode of operation ('train', 'inference', 'eval'). Defaults to 'inference'.
current_step (Optional[int]): Current training step. Defaults to None.
            symmetric_permutation (SymmetricPermutation): Symmetric permutation object. Defaults to None.
            detect (Optional[bool]): If True, run watermark detection instead of structure prediction. Defaults to False.
            watermark (Optional[bool]): If True, embed a watermark into the predicted structure. Defaults to False.
Returns:
tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
Prediction, updated label, and log dictionaries.
"""
assert mode in ["train", "inference", "eval"]
inplace_safe = not (self.training or torch.is_grad_enabled())
chunk_size = self.configs.infer_setting.chunk_size if inplace_safe else None
if mode == "train":
nc_rng = np.random.RandomState(current_step)
N_cycle = nc_rng.randint(1, self.N_cycle + 1)
assert self.training
assert label_dict is not None
assert symmetric_permutation is not None
pred_dict, label_dict, log_dict = self.ED_train_loop(
input_feature_dict=input_feature_dict,
label_full_dict=label_full_dict,
label_dict=label_dict,
N_cycle=N_cycle,
symmetric_permutation=symmetric_permutation,
inplace_safe=inplace_safe,
chunk_size=chunk_size,
)
elif mode == "inference":
if not detect:
pred_dict, label_dict, log_dict, time_tracker = self.main_inference_loop(
input_feature_dict=input_feature_dict,
label_dict=None,
N_cycle=self.N_cycle,
mode=mode,
inplace_safe=inplace_safe,
chunk_size=chunk_size,
N_model_seed=self.N_model_seed,
symmetric_permutation=None,
watermark=watermark,
)
else:
pred_dict, label_dict, log_dict, time_tracker = self.main_detection_loop(
input_feature_dict=input_feature_dict,
label_dict=label_dict,
N_cycle=self.N_cycle,
mode=mode,
inplace_safe=inplace_safe,
chunk_size=chunk_size,
N_model_seed=self.N_model_seed,
symmetric_permutation=None,
)
log_dict.update({"time": time_tracker})
elif mode == "eval":
if label_dict is not None:
assert (
label_dict["coordinate"].size()
== label_full_dict["coordinate"].size()
)
label_dict.update(label_full_dict)
            pred_dict, label_dict, log_dict, time_tracker = self.main_inference_loop(
input_feature_dict=input_feature_dict,
label_dict=label_dict,
N_cycle=self.N_cycle,
mode=mode,
inplace_safe=inplace_safe,
chunk_size=chunk_size,
N_model_seed=self.N_model_seed,
symmetric_permutation=symmetric_permutation,
)
log_dict.update({"time": time_tracker})
return pred_dict, label_dict, log_dict
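if __name__ == "__main__":
    # Self-contained demo of the gated watermark readout used in
    # _main_detection_loop / ED_train_loop, with random stand-in tensors and
    # hypothetical sizes (it does not build the full model or its configs).
    c_token, n_token = 384, 32
    gating_layer = nn.Linear(c_token, 1)
    code_extractor = nn.Linear(c_token, 1)
    a_token = torch.randn(2, n_token, c_token)  # [N_sample, N_token, c_token]
    weights = F.softmax(gating_layer(a_token), dim=-2)  # pooling weights over tokens
    watermark_logit = (code_extractor(a_token) * weights).sum(dim=-2)
    print(watermark_logit.shape)  # torch.Size([2, 1])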