import time
from typing import Any, Callable, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from protenix.model import sample_confidence
from protenix.model.generator import (
    InferenceNoiseScheduler,
    TrainingNoiseSampler,
    sample_diffusion,
    sample_diffusion_training,
    structure_predictor,
    watermark_decoder,
)
from protenix.model.utils import (
    centre_random_augmentation,
    random_sample_watermark,
    simple_merge_dict_list,
)
from protenix.openfold_local.model.primitives import LayerNorm
from protenix.utils.logger import get_logger
from protenix.utils.permutation.permutation import SymmetricPermutation
from protenix.utils.torch_utils import autocasting_disable_decorator

from .modules.confidence import ConfidenceHead
from .modules.diffusion import DiffusionModule, Struct_decoder, Struct_encoder
from .modules.embedders import InputFeatureEmbedder, RelativePositionEncoding
from .modules.head import DistogramHead
from .modules.pairformer import MSAModule, PairformerStack, TemplateEmbedder
from .modules.primitives import LinearNoBias

logger = get_logger(__name__)


class Protenix(nn.Module):
    """
    Implements Algorithm 1 [Main Inference/Train Loop] in AF3.
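
    A rough usage sketch (the config fields and feature dictionaries are
    assumptions inferred from this file, not a documented API)::

        model = Protenix(configs)
        pred_dict, label_dict, log_dict = model(
            input_feature_dict=input_feature_dict,
            label_full_dict=label_full_dict,
            label_dict=label_dict,
            mode="inference",
        )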
    """

    def __init__(self, configs) -> None:
        super().__init__()
        self.configs = configs

        self.N_cycle = self.configs.model.N_cycle
        self.N_model_seed = self.configs.model.N_model_seed
        self.train_confidence_only = configs.train_confidence_only
        if self.train_confidence_only:
            assert configs.loss.weight.alpha_diffusion == 0.0
            assert configs.loss.weight.alpha_distogram == 0.0

        self.train_noise_sampler = TrainingNoiseSampler(**configs.train_noise_sampler)
        self.inference_noise_scheduler = InferenceNoiseScheduler(
            **configs.inference_noise_scheduler
        )
        self.diffusion_batch_size = self.configs.diffusion_batch_size

        self.input_embedder = InputFeatureEmbedder(**configs.model.input_embedder)
        self.relative_position_encoding = RelativePositionEncoding(
            **configs.model.relative_position_encoding
        )
        self.template_embedder = TemplateEmbedder(**configs.model.template_embedder)
        self.msa_module = MSAModule(
            **configs.model.msa_module,
            msa_configs=configs.data.get("msa", {}),
        )
        self.pairformer_stack = PairformerStack(**configs.model.pairformer)
        self.diffusion_module = DiffusionModule(**configs.model.diffusion_module)
        self.distogram_head = DistogramHead(**configs.model.distogram_head)
        self.confidence_head = ConfidenceHead(**configs.model.confidence_head)

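        # Watermark-specific components: a separate Pairformer trunk conditions
        # the watermark decoder, a Struct_encoder embeds the watermark into
        # coordinates, and a Struct_decoder recovers per-token activations from
        # which the watermark bit is read out.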
        self.pairformer_stack_decoder = PairformerStack(
            **configs.model.pairformer_decoder
        )
        self.diffusion_module_encoder = Struct_encoder(
            **configs.model.diffusion_module_encoder_decoder
        )
        self.diffusion_module_decoder = Struct_decoder(
            **configs.model.diffusion_module_encoder_decoder
        )
        self.code_extractor = nn.Linear(configs.model.diffusion_module.c_token, 1)
        self.gating_layer = nn.Linear(configs.model.diffusion_module.c_token, 1)

        self.c_s, self.c_z, self.c_s_inputs, self.watermark = (
            configs.c_s,
            configs.c_z,
            configs.c_s_inputs,
            configs.watermark,
        )
        self.linear_no_bias_sinit = LinearNoBias(
            in_features=self.c_s_inputs, out_features=self.c_s
        )
        self.linear_no_bias_zinit1 = LinearNoBias(
            in_features=self.c_s, out_features=self.c_z
        )
        self.linear_no_bias_zinit2 = LinearNoBias(
            in_features=self.c_s, out_features=self.c_z
        )
        self.linear_no_bias_token_bond = LinearNoBias(
            in_features=1, out_features=self.c_z
        )
        self.linear_no_bias_z_cycle = LinearNoBias(
            in_features=self.c_z, out_features=self.c_z
        )
        self.linear_no_bias_s = LinearNoBias(
            in_features=self.c_s, out_features=self.c_s
        )
        self.layernorm_z_cycle = LayerNorm(self.c_z)
        self.layernorm_s = LayerNorm(self.c_s)
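
        # Zero-initialise the recycling projections so the first recycling
        # cycle effectively starts from the raw initial embeddings.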
        nn.init.zeros_(self.linear_no_bias_z_cycle.weight)
        nn.init.zeros_(self.linear_no_bias_s.weight)

    def get_pairformer_output(
        self,
        pairformer_stack: Callable,
        input_feature_dict: dict[str, Any],
        N_cycle: int,
        inplace_safe: bool = False,
        chunk_size: Optional[int] = None,
        use_msa: Optional[bool] = True,
    ) -> tuple[torch.Tensor, ...]:
        """
        The forward pass from the input features to the pairformer output.

        Args:
            pairformer_stack (Callable): The pairformer stack to run (the main
                trunk or the watermark-decoder trunk).
            input_feature_dict (dict[str, Any]): Input features.
            N_cycle (int): Number of recycling cycles.
            inplace_safe (bool): Whether it is safe to use inplace operations. Defaults to False.
            chunk_size (Optional[int]): Chunk size for memory-efficient operations. Defaults to None.
            use_msa (Optional[bool]): Whether to run the MSA module. Defaults to True.

        Returns:
            tuple[torch.Tensor, ...]: s_inputs, s, z
        """
        N_token = input_feature_dict["residue_index"].shape[-1]
        # The DeepSpeed Evoformer attention kernel is only used beyond 16 tokens.
        deepspeed_evo_attention_condition_satisfy = N_token > 16

        if self.train_confidence_only:
            # Freeze the trunk when only the confidence head is being trained.
            self.input_embedder.eval()
            self.template_embedder.eval()
            self.msa_module.eval()
            self.pairformer_stack.eval()

        s_inputs = self.input_embedder(
            input_feature_dict, inplace_safe=False, chunk_size=chunk_size
        )
        s_init = self.linear_no_bias_sinit(s_inputs)
        z_init = (
            self.linear_no_bias_zinit1(s_init)[..., None, :]
            + self.linear_no_bias_zinit2(s_init)[..., None, :, :]
        )
        if inplace_safe:
            z_init += self.relative_position_encoding(input_feature_dict)
            z_init += self.linear_no_bias_token_bond(
                input_feature_dict["token_bonds"].unsqueeze(dim=-1)
            )
        else:
            z_init = z_init + self.relative_position_encoding(input_feature_dict)
            z_init = z_init + self.linear_no_bias_token_bond(
                input_feature_dict["token_bonds"].unsqueeze(dim=-1)
            )

        z = torch.zeros_like(z_init)
        s = torch.zeros_like(s_init)

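        # Recycling: run the trunk N_cycle times, feeding the previous cycle's
        # (s, z) back in; gradients are kept only for the final cycle, and only
        # when the full model is being trained.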
        for cycle_no in range(N_cycle):
            with torch.set_grad_enabled(
                self.training
                and (not self.train_confidence_only)
                and cycle_no == (N_cycle - 1)
            ):
                z = z_init + self.linear_no_bias_z_cycle(self.layernorm_z_cycle(z))
                if inplace_safe:
                    if self.template_embedder.n_blocks > 0:
                        z += self.template_embedder(
                            input_feature_dict,
                            z,
                            use_memory_efficient_kernel=self.configs.use_memory_efficient_kernel,
                            use_deepspeed_evo_attention=self.configs.use_deepspeed_evo_attention
                            and deepspeed_evo_attention_condition_satisfy,
                            use_lma=self.configs.use_lma,
                            inplace_safe=inplace_safe,
                            chunk_size=chunk_size,
                        )
                    if use_msa:
                        z = self.msa_module(
                            input_feature_dict,
                            z,
                            s_inputs,
                            pair_mask=None,
                            use_memory_efficient_kernel=self.configs.use_memory_efficient_kernel,
                            use_deepspeed_evo_attention=self.configs.use_deepspeed_evo_attention
                            and deepspeed_evo_attention_condition_satisfy,
                            use_lma=self.configs.use_lma,
                            inplace_safe=inplace_safe,
                            chunk_size=chunk_size,
                        )
                else:
                    if self.template_embedder.n_blocks > 0:
                        z = z + self.template_embedder(
                            input_feature_dict,
                            z,
                            use_memory_efficient_kernel=self.configs.use_memory_efficient_kernel,
                            use_deepspeed_evo_attention=self.configs.use_deepspeed_evo_attention
                            and deepspeed_evo_attention_condition_satisfy,
                            use_lma=self.configs.use_lma,
                            inplace_safe=inplace_safe,
                            chunk_size=chunk_size,
                        )
                    if use_msa:
                        z = self.msa_module(
                            input_feature_dict,
                            z,
                            s_inputs,
                            pair_mask=None,
                            use_memory_efficient_kernel=self.configs.use_memory_efficient_kernel,
                            use_deepspeed_evo_attention=self.configs.use_deepspeed_evo_attention
                            and deepspeed_evo_attention_condition_satisfy,
                            use_lma=self.configs.use_lma,
                            inplace_safe=inplace_safe,
                            chunk_size=chunk_size,
                        )
                s = s_init + self.linear_no_bias_s(self.layernorm_s(s))
                # Run the requested pairformer stack (main or decoder trunk).
                s, z = pairformer_stack(
                    s,
                    z,
                    pair_mask=None,
                    use_memory_efficient_kernel=self.configs.use_memory_efficient_kernel,
                    use_deepspeed_evo_attention=self.configs.use_deepspeed_evo_attention
                    and deepspeed_evo_attention_condition_satisfy,
                    use_lma=self.configs.use_lma,
                    inplace_safe=inplace_safe,
                    chunk_size=chunk_size,
                )

        if self.train_confidence_only:
            self.input_embedder.train()
            self.template_embedder.train()
            self.msa_module.train()
            self.pairformer_stack.train()

        return s_inputs, s, z

    def sample_diffusion(self, **kwargs) -> torch.Tensor:
        """
        Samples the diffusion process based on the provided configurations.

        Returns:
            torch.Tensor: The result of the diffusion sampling process.
        """
        _configs = {
            key: self.configs.sample_diffusion.get(key)
            for key in [
                "gamma0",
                "gamma_min",
                "noise_scale_lambda",
                "step_scale_eta",
            ]
        }
        _configs.update(
            {
                "attn_chunk_size": (
                    self.configs.infer_setting.chunk_size if not self.training else None
                ),
                "diffusion_chunk_size": (
                    self.configs.infer_setting.sample_diffusion_chunk_size
                    if not self.training
                    else None
                ),
            }
        )
        return autocasting_disable_decorator(self.configs.skip_amp.sample_diffusion)(
            sample_diffusion
        )(**_configs, **kwargs)

    def run_confidence_head(self, *args, **kwargs):
        """
        Runs the confidence head with optional automatic mixed precision (AMP) disabled.

        Returns:
            Any: The output of the confidence head.
        """
        return autocasting_disable_decorator(self.configs.skip_amp.confidence_head)(
            self.confidence_head
        )(*args, **kwargs)

    def main_detection_loop(
        self,
        input_feature_dict: dict[str, Any],
        label_dict: dict[str, Any],
        N_cycle: int,
        mode: str,
        inplace_safe: bool = True,
        chunk_size: Optional[int] = 4,
        N_model_seed: int = 1,
        symmetric_permutation: SymmetricPermutation = None,
    ) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any], dict[str, Any]]:
        """
        Main watermark-detection loop (multiple model seeds).

        Args:
            input_feature_dict (dict[str, Any]): Input features dictionary.
            label_dict (dict[str, Any]): Label dictionary.
            N_cycle (int): Number of cycles.
            mode (str): Mode of operation (e.g., 'inference').
            inplace_safe (bool): Whether to use inplace operations safely. Defaults to True.
            chunk_size (Optional[int]): Chunk size for memory-efficient operations. Defaults to 4.
            N_model_seed (int): Number of model seeds. Defaults to 1.
            symmetric_permutation (SymmetricPermutation): Symmetric permutation object. Defaults to None.

        Returns:
            tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any], dict[str, Any]]:
                Prediction, label, log, and time dictionaries.
        """
        pred_dicts = []
        log_dicts = []
        time_trackers = []

        for _ in range(N_model_seed):
            pred_dict, log_dict, time_tracker = self._main_detection_loop(
                input_feature_dict=input_feature_dict,
                label_dict=label_dict,
                N_cycle=N_cycle,
                mode=mode,
                inplace_safe=inplace_safe,
                chunk_size=chunk_size,
                symmetric_permutation=symmetric_permutation,
            )
            pred_dicts.append(pred_dict)
            log_dicts.append(log_dict)
            time_trackers.append(time_tracker)

        def _cat(dict_list, key):
            return torch.cat([x[key] for x in dict_list], dim=0)

        all_pred_dict = {"watermark": _cat(pred_dicts, "watermark")}

        all_log_dict = simple_merge_dict_list(log_dicts)
        all_time_dict = simple_merge_dict_list(time_trackers)
        return all_pred_dict, label_dict, all_log_dict, all_time_dict

    def _main_detection_loop(
        self,
        input_feature_dict: dict[str, Any],
        label_dict: dict[str, Any],
        N_cycle: int,
        mode: str,
        inplace_safe: bool = True,
        chunk_size: Optional[int] = 4,
        symmetric_permutation: SymmetricPermutation = None,
    ) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
        """
        Main watermark-detection loop (single model seed).

        Returns:
            tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]: Prediction, log, and time dictionaries.
        """
        step_st = time.time()

        log_dict = {}
        pred_dict = {}
        time_tracker = {}

        # Detection uses the dedicated decoder trunk and runs without the MSA.
        s_inputs_clean, s_clean, z_clean = self.get_pairformer_output(
            pairformer_stack=self.pairformer_stack_decoder,
            input_feature_dict=input_feature_dict,
            N_cycle=N_cycle,
            inplace_safe=inplace_safe,
            chunk_size=chunk_size,
            use_msa=False,
        )

        if mode == "inference":
            # Free MSA/template features once the trunk has run.
            keys_to_delete = []
            for key in input_feature_dict.keys():
                if "template_" in key or key in [
                    "msa",
                    "has_deletion",
                    "deletion_value",
                    "profile",
                    "deletion_mean",
                    "token_bonds",
                ]:
                    keys_to_delete.append(key)

            for key in keys_to_delete:
                del input_feature_dict[key]
            torch.cuda.empty_cache()
        step_trunk = time.time()
        time_tracker.update({"pairformer": step_trunk - step_st})

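        # Run the watermark decoder on the input coordinates: one denoising
        # pass exposes per-token activations (a_token) from which the watermark
        # is read out below.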
        _, a_token, x_noise_level = autocasting_disable_decorator(
            self.configs.skip_amp.sample_diffusion_training
        )(watermark_decoder)(
            coordinate=label_dict["coordinate"].unsqueeze(0),
            denoise_net=self.diffusion_module_decoder,
            input_feature_dict=input_feature_dict,
            s_inputs=s_inputs_clean,
            s_trunk=s_clean,
            z_trunk=z_clean,
            N_sample=1,
            diffusion_chunk_size=self.configs.diffusion_chunk_size,
        )

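        # Attention-pooled readout: per-token gating scores are softmax-
        # normalised over the token axis and used to weight the per-token code
        # into a single watermark logit per sample.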
        scores = self.gating_layer(a_token)
        weights = F.softmax(scores, dim=-2)
        extracted = self.code_extractor(a_token)
        watermark = (extracted * weights).sum(dim=-2)

        pred_dict.update({"watermark": watermark})

        time_tracker.update({"model_forward": time.time() - step_st})

        return pred_dict, log_dict, time_tracker

    def main_inference_loop(
        self,
        input_feature_dict: dict[str, Any],
        label_dict: dict[str, Any],
        N_cycle: int,
        mode: str,
        inplace_safe: bool = True,
        chunk_size: Optional[int] = 4,
        N_model_seed: int = 1,
        symmetric_permutation: SymmetricPermutation = None,
        watermark: bool = False,
    ) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any], dict[str, Any]]:
        """
        Main inference loop (multiple model seeds) for the Alphafold3 model.

        Args:
            input_feature_dict (dict[str, Any]): Input features dictionary.
            label_dict (dict[str, Any]): Label dictionary.
            N_cycle (int): Number of cycles.
            mode (str): Mode of operation (e.g., 'inference').
            inplace_safe (bool): Whether to use inplace operations safely. Defaults to True.
            chunk_size (Optional[int]): Chunk size for memory-efficient operations. Defaults to 4.
            N_model_seed (int): Number of model seeds. Defaults to 1.
            symmetric_permutation (SymmetricPermutation): Symmetric permutation object. Defaults to None.
            watermark (bool): Whether to embed the watermark into the sampled structures. Defaults to False.

        Returns:
            tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any], dict[str, Any]]:
                Prediction, label, log, and time dictionaries.
        """
        pred_dicts = []
        log_dicts = []
        time_trackers = []

        N_sample = self.configs.sample_diffusion["N_sample"]
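        # In pure inference there are no labels; build a label dict holding the
        # detection target, a constant all-ones watermark code.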
        if label_dict is None:
            label_dict = {}
            label_dict["watermark"] = torch.ones(N_sample, 1).to(
                input_feature_dict["restype"].device
            )

        for _ in range(N_model_seed):
            pred_dict, log_dict, time_tracker = self._main_inference_loop(
                input_feature_dict=input_feature_dict,
                label_dict=label_dict,
                N_cycle=N_cycle,
                mode=mode,
                inplace_safe=inplace_safe,
                chunk_size=chunk_size,
                symmetric_permutation=symmetric_permutation,
                watermark=watermark,
            )
            pred_dicts.append(pred_dict)
            log_dicts.append(log_dict)
            time_trackers.append(time_tracker)

        def _cat(dict_list, key):
            return torch.cat([x[key] for x in dict_list], dim=0)

        def _list_join(dict_list, key):
            return sum([x[key] for x in dict_list], [])

        all_pred_dict = {
            "coordinate": _cat(pred_dicts, "coordinate"),
            "summary_confidence": _list_join(pred_dicts, "summary_confidence"),
            "full_data": _list_join(pred_dicts, "full_data"),
            "plddt": _cat(pred_dicts, "plddt"),
            "pae": _cat(pred_dicts, "pae"),
            "pde": _cat(pred_dicts, "pde"),
            "resolved": _cat(pred_dicts, "resolved"),
        }

        if "coordinate_orig" in pred_dicts[0]:
            all_pred_dict["coordinate_orig"] = _cat(pred_dicts, "coordinate_orig")

        all_log_dict = simple_merge_dict_list(log_dicts)
        all_time_dict = simple_merge_dict_list(time_trackers)
        return all_pred_dict, label_dict, all_log_dict, all_time_dict

    def _main_inference_loop(
        self,
        input_feature_dict: dict[str, Any],
        label_dict: dict[str, Any],
        N_cycle: int,
        mode: str,
        inplace_safe: bool = True,
        chunk_size: Optional[int] = 4,
        symmetric_permutation: SymmetricPermutation = None,
        watermark: bool = False,
    ) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
        """
        Main inference loop (single model seed) for the Alphafold3 model.

        Returns:
            tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]: Prediction, log, and time dictionaries.
        """
        step_st = time.time()
        N_token = input_feature_dict["residue_index"].shape[-1]
        # The DeepSpeed Evoformer attention kernel is only used beyond 16 tokens.
        deepspeed_evo_attention_condition_satisfy = N_token > 16

        log_dict = {}
        pred_dict = {}
        time_tracker = {}

        s_inputs, s, z = self.get_pairformer_output(
            pairformer_stack=self.pairformer_stack,
            input_feature_dict=input_feature_dict,
            N_cycle=N_cycle,
            inplace_safe=inplace_safe,
            chunk_size=chunk_size,
        )

        if mode == "inference":
            # Free MSA/template features once the trunk has run.
            keys_to_delete = []
            for key in input_feature_dict.keys():
                if "template_" in key or key in [
                    "msa",
                    "has_deletion",
                    "deletion_value",
                    "profile",
                    "deletion_mean",
                    "token_bonds",
                ]:
                    keys_to_delete.append(key)

            for key in keys_to_delete:
                del input_feature_dict[key]
            torch.cuda.empty_cache()
        step_trunk = time.time()
        time_tracker.update({"pairformer": step_trunk - step_st})

        N_sample = self.configs.sample_diffusion["N_sample"]
        N_step = self.configs.sample_diffusion["N_step"]

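        # Reverse-diffusion sampling: start from noise at the highest level of
        # the schedule and denoise for N_step steps.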
        noise_schedule = self.inference_noise_scheduler(
            N_step=N_step, device=s_inputs.device, dtype=s_inputs.dtype
        )
        pred_dict["coordinate"] = self.sample_diffusion(
            denoise_net=self.diffusion_module,
            input_feature_dict=input_feature_dict,
            s_inputs=s_inputs,
            s_trunk=s,
            z_trunk=z,
            N_sample=N_sample,
            noise_schedule=noise_schedule,
            inplace_safe=inplace_safe,
        )

        step_diffusion = time.time()
        time_tracker.update({"diffusion": step_diffusion - step_trunk})
        if mode == "inference" and N_token > 2000:
            torch.cuda.empty_cache()

        pred_dict["contact_probs"] = sample_confidence.compute_contact_prob(
            distogram_logits=self.distogram_head(z),
            **sample_confidence.get_bin_params(self.configs.loss.distogram),
        )

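        # Optional watermark embedding: run the structure encoder on the
        # sampled coordinates so the released structures carry the watermark;
        # the unmarked coordinates are kept under "coordinate_orig".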
        if watermark:
            x_denoised, x_noise_level = autocasting_disable_decorator(
                self.configs.skip_amp.sample_diffusion_training
            )(structure_predictor)(
                coordinate=pred_dict["coordinate"],
                denoise_net=self.diffusion_module_encoder,
                label_dict=label_dict,
                input_feature_dict=input_feature_dict,
                s_inputs=s_inputs,
                s_trunk=s,
                z_trunk=z,
                N_sample=N_sample,
                diffusion_chunk_size=self.configs.diffusion_chunk_size,
            )
            pred_dict["coordinate_orig"] = pred_dict["coordinate"]
            pred_dict["coordinate"] = x_denoised

        (
            pred_dict["plddt"],
            pred_dict["pae"],
            pred_dict["pde"],
            pred_dict["resolved"],
        ) = self.run_confidence_head(
            input_feature_dict=input_feature_dict,
            s_inputs=s_inputs,
            s_trunk=s,
            z_trunk=z,
            pair_mask=None,
            x_pred_coords=pred_dict["coordinate"],
            use_memory_efficient_kernel=self.configs.use_memory_efficient_kernel,
            use_deepspeed_evo_attention=self.configs.use_deepspeed_evo_attention
            and deepspeed_evo_attention_condition_satisfy,
            use_lma=self.configs.use_lma,
            inplace_safe=inplace_safe,
            chunk_size=chunk_size,
        )

        step_confidence = time.time()
        time_tracker.update({"confidence": step_confidence - step_diffusion})
        time_tracker.update({"model_forward": time.time() - step_st})

        if label_dict is not None and symmetric_permutation is not None:
            pred_dict, log_dict = symmetric_permutation.permute_inference_pred_dict(
                input_feature_dict=input_feature_dict,
                pred_dict=pred_dict,
                label_dict=label_dict,
                permute_by_pocket=("pocket_mask" in label_dict)
                and ("interested_ligand_mask" in label_dict),
            )
            time_tracker.update({"permutation": time.time() - step_confidence})

        if label_dict is None:
            interested_atom_mask = None
        else:
            interested_atom_mask = label_dict.get("interested_ligand_mask", None)
        pred_dict["summary_confidence"], pred_dict["full_data"] = (
            sample_confidence.compute_full_data_and_summary(
                configs=self.configs,
                pae_logits=pred_dict["pae"],
                plddt_logits=pred_dict["plddt"],
                pde_logits=pred_dict["pde"],
                contact_probs=pred_dict.get(
                    "per_sample_contact_probs", pred_dict["contact_probs"]
                ),
                token_asym_id=input_feature_dict["asym_id"],
                token_has_frame=input_feature_dict["has_frame"],
                atom_coordinate=pred_dict["coordinate"],
                atom_to_token_idx=input_feature_dict["atom_to_token_idx"],
                atom_is_polymer=1 - input_feature_dict["is_ligand"],
                N_recycle=N_cycle,
                interested_atom_mask=interested_atom_mask,
                return_full_data=True,
                mol_id=(input_feature_dict["mol_id"] if mode != "inference" else None),
                elements_one_hot=(
                    input_feature_dict["ref_element"] if mode != "inference" else None
                ),
            )
        )

        return pred_dict, log_dict, time_tracker

    def main_train_loop(
        self,
        input_feature_dict: dict[str, Any],
        label_full_dict: dict[str, Any],
        label_dict: dict,
        N_cycle: int,
        symmetric_permutation: SymmetricPermutation,
        inplace_safe: bool = False,
        chunk_size: Optional[int] = None,
    ) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
        """
        Main training loop for the Alphafold3 model.

        Args:
            input_feature_dict (dict[str, Any]): Input features dictionary.
            label_full_dict (dict[str, Any]): Full label dictionary (uncropped).
            label_dict (dict): Label dictionary (cropped).
            N_cycle (int): Number of cycles.
            symmetric_permutation (SymmetricPermutation): Symmetric permutation object.
            inplace_safe (bool): Whether to use inplace operations safely. Defaults to False.
            chunk_size (Optional[int]): Chunk size for memory-efficient operations. Defaults to None.

        Returns:
            tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
                Prediction, updated label, and log dictionaries.
        """
        N_token = input_feature_dict["residue_index"].shape[-1]
        # The DeepSpeed Evoformer attention kernel is only used beyond 16 tokens.
        deepspeed_evo_attention_condition_satisfy = N_token > 16

        s_inputs, s, z = self.get_pairformer_output(
            pairformer_stack=self.pairformer_stack,
            input_feature_dict=input_feature_dict,
            N_cycle=N_cycle,
            inplace_safe=inplace_safe,
            chunk_size=chunk_size,
        )

        log_dict = {}
        pred_dict = {}

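        # Mini-rollout: a short, gradient-free diffusion rollout produces
        # coordinates that are used to permute symmetric labels and to train
        # the confidence head.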
        with torch.no_grad():
            N_sample_mini_rollout = self.configs.sample_diffusion[
                "N_sample_mini_rollout"
            ]
            N_step_mini_rollout = self.configs.sample_diffusion["N_step_mini_rollout"]

            coordinate_mini = self.sample_diffusion(
                denoise_net=self.diffusion_module,
                input_feature_dict=input_feature_dict,
                s_inputs=s_inputs.detach(),
                s_trunk=s.detach(),
                z_trunk=z.detach(),
                N_sample=N_sample_mini_rollout,
                noise_schedule=self.inference_noise_scheduler(
                    N_step=N_step_mini_rollout,
                    device=s_inputs.device,
                    dtype=s_inputs.dtype,
                ),
            )
            coordinate_mini.detach_()
            pred_dict["coordinate_mini"] = coordinate_mini

            label_dict, perm_log_dict = (
                symmetric_permutation.permute_label_to_match_mini_rollout(
                    coordinate_mini,
                    input_feature_dict,
                    label_dict,
                    label_full_dict,
                )
            )
            log_dict.update(perm_log_dict)

        plddt_pred, pae_pred, pde_pred, resolved_pred = self.run_confidence_head(
            input_feature_dict=input_feature_dict,
            s_inputs=s_inputs,
            s_trunk=s,
            z_trunk=z,
            pair_mask=None,
            x_pred_coords=coordinate_mini,
            use_memory_efficient_kernel=self.configs.use_memory_efficient_kernel,
            use_deepspeed_evo_attention=self.configs.use_deepspeed_evo_attention
            and deepspeed_evo_attention_condition_satisfy,
            use_lma=self.configs.use_lma,
            inplace_safe=inplace_safe,
            chunk_size=chunk_size,
        )
        pred_dict.update(
            {
                "plddt": plddt_pred,
                "pae": pae_pred,
                "pde": pde_pred,
                "resolved": resolved_pred,
            }
        )

        if self.train_confidence_only:
            return pred_dict, label_dict, log_dict

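        # Diffusion training: noise the permuted ground truth at sampled noise
        # levels and denoise it in a single step.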
        N_sample = self.diffusion_batch_size
        _, x_denoised, x_noise_level = autocasting_disable_decorator(
            self.configs.skip_amp.sample_diffusion_training
        )(sample_diffusion_training)(
            noise_sampler=self.train_noise_sampler,
            denoise_net=self.diffusion_module,
            label_dict=label_dict,
            input_feature_dict=input_feature_dict,
            s_inputs=s_inputs,
            s_trunk=s,
            z_trunk=z,
            N_sample=N_sample,
            diffusion_chunk_size=self.configs.diffusion_chunk_size,
        )
        pred_dict.update(
            {
                "distogram": self.distogram_head(z),
                "coordinate": x_denoised,
                "noise_level": x_noise_level,
            }
        )

        pred_dict, perm_log_dict, _, _ = (
            symmetric_permutation.permute_diffusion_sample_to_match_label(
                input_feature_dict, pred_dict, label_dict, stage="train"
            )
        )
        log_dict.update(perm_log_dict)

        return pred_dict, label_dict, log_dict

    def ED_train_loop(
        self,
        input_feature_dict: dict[str, Any],
        label_full_dict: dict[str, Any],
        label_dict: dict,
        N_cycle: int,
        symmetric_permutation: SymmetricPermutation,
        inplace_safe: bool = False,
        chunk_size: Optional[int] = None,
    ) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
        """
        Encoder-decoder training loop for the watermark modules: the main trunk
        runs without gradients, while the Struct_encoder/Struct_decoder pair and
        the watermark readout are trained.

        Args:
            input_feature_dict (dict[str, Any]): Input features dictionary.
            label_full_dict (dict[str, Any]): Full label dictionary (uncropped).
            label_dict (dict): Label dictionary (cropped).
            N_cycle (int): Number of cycles.
            symmetric_permutation (SymmetricPermutation): Symmetric permutation object.
            inplace_safe (bool): Whether to use inplace operations safely. Defaults to False.
            chunk_size (Optional[int]): Chunk size for memory-efficient operations. Defaults to None.

        Returns:
            tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
                Prediction, updated label, and log dictionaries.
        """
        N_sample = self.diffusion_batch_size

        with torch.no_grad():
            s_inputs, s, z = self.get_pairformer_output(
                pairformer_stack=self.pairformer_stack,
                input_feature_dict=input_feature_dict,
                N_cycle=N_cycle,
                inplace_safe=inplace_safe,
                chunk_size=chunk_size,
            )

        log_dict = {}
        pred_dict = {}

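        # Apply random global rotations/translations to the ground-truth
        # coordinates so the encoder cannot rely on a fixed reference frame.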
        x_gt_augment = centre_random_augmentation(
            x_input_coords=label_dict["coordinate"],
            N_sample=N_sample,
            mask=label_dict["coordinate_mask"],
            centre_only=False,
        ).to(label_dict["coordinate"].dtype)
        label_dict["coordinate_augment"] = x_gt_augment

        x_denoised, x_noise_level = autocasting_disable_decorator(
            self.configs.skip_amp.sample_diffusion_training
        )(structure_predictor)(
            coordinate=x_gt_augment,
            denoise_net=self.diffusion_module_encoder,
            label_dict=label_dict,
            input_feature_dict=input_feature_dict,
            s_inputs=s_inputs,
            s_trunk=s,
            z_trunk=z,
            N_sample=N_sample,
            diffusion_chunk_size=self.configs.diffusion_chunk_size,
        )

        pred_dict.update(
            {
                "distogram": self.distogram_head(z),
                "coordinate": x_denoised,
                "noise_level": x_noise_level,
            }
        )

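        # Randomly mix watermarked and clean samples so the decoder learns to
        # emit the correct bit for both; watermark_label records which is which.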
        x_denoised, watermark_label = random_sample_watermark(
            x_denoised, x_gt_augment, N_sample
        )
        label_dict["watermark"] = watermark_label[..., None]
        s_inputs_clean, s_clean, z_clean = self.get_pairformer_output(
            pairformer_stack=self.pairformer_stack_decoder,
            input_feature_dict=input_feature_dict,
            N_cycle=N_cycle,
            inplace_safe=inplace_safe,
            chunk_size=chunk_size,
            use_msa=False,
        )

        _, a_token, x_noise_level = autocasting_disable_decorator(
            self.configs.skip_amp.sample_diffusion_training
        )(watermark_decoder)(
            coordinate=x_denoised,
            denoise_net=self.diffusion_module_decoder,
            input_feature_dict=input_feature_dict,
            s_inputs=s_inputs_clean,
            s_trunk=s_clean,
            z_trunk=z_clean,
            N_sample=N_sample,
            diffusion_chunk_size=self.configs.diffusion_chunk_size,
        )

        # Attention-pooled readout, as in _main_detection_loop.
        scores = self.gating_layer(a_token)
        weights = F.softmax(scores, dim=-2)
        extracted = self.code_extractor(a_token)
        watermark = (extracted * weights).sum(dim=-2)

        pred_dict.update({"watermark": watermark})

        return pred_dict, label_dict, log_dict

    def forward(
        self,
        input_feature_dict: dict[str, Any],
        label_full_dict: dict[str, Any],
        label_dict: dict[str, Any],
        mode: str = "inference",
        current_step: Optional[int] = None,
        symmetric_permutation: SymmetricPermutation = None,
        detect: Optional[bool] = False,
        watermark: Optional[bool] = False,
    ) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
        """
        Forward pass of the Alphafold3 model.

        Args:
            input_feature_dict (dict[str, Any]): Input features dictionary.
            label_full_dict (dict[str, Any]): Full label dictionary (uncropped).
            label_dict (dict[str, Any]): Label dictionary (cropped).
            mode (str): Mode of operation ('train', 'inference', 'eval'). Defaults to 'inference'.
            current_step (Optional[int]): Current training step. Defaults to None.
            symmetric_permutation (SymmetricPermutation): Symmetric permutation object. Defaults to None.
            detect (Optional[bool]): In 'inference' mode, run watermark detection
                instead of structure generation. Defaults to False.
            watermark (Optional[bool]): In 'inference' mode, embed the watermark
                into the sampled structures. Defaults to False.

        Returns:
            tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
                Prediction, updated label, and log dictionaries.
        """
        assert mode in ["train", "inference", "eval"]
        inplace_safe = not (self.training or torch.is_grad_enabled())
        chunk_size = self.configs.infer_setting.chunk_size if inplace_safe else None

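        # Dispatch: "train" runs the encoder-decoder watermark loop, "inference"
        # runs generation (optionally watermarked) or watermark detection, and
        # "eval" runs generation with label permutation for metrics.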
        if mode == "train":
            # Randomise the number of recycling cycles per step, seeded by the
            # global step for reproducibility.
            nc_rng = np.random.RandomState(current_step)
            N_cycle = nc_rng.randint(1, self.N_cycle + 1)
            assert self.training
            assert label_dict is not None
            assert symmetric_permutation is not None

            pred_dict, label_dict, log_dict = self.ED_train_loop(
                input_feature_dict=input_feature_dict,
                label_full_dict=label_full_dict,
                label_dict=label_dict,
                N_cycle=N_cycle,
                symmetric_permutation=symmetric_permutation,
                inplace_safe=inplace_safe,
                chunk_size=chunk_size,
            )
        elif mode == "inference":
            if not detect:
                pred_dict, label_dict, log_dict, time_tracker = self.main_inference_loop(
                    input_feature_dict=input_feature_dict,
                    label_dict=None,
                    N_cycle=self.N_cycle,
                    mode=mode,
                    inplace_safe=inplace_safe,
                    chunk_size=chunk_size,
                    N_model_seed=self.N_model_seed,
                    symmetric_permutation=None,
                    watermark=watermark,
                )
            else:
                pred_dict, label_dict, log_dict, time_tracker = self.main_detection_loop(
                    input_feature_dict=input_feature_dict,
                    label_dict=label_dict,
                    N_cycle=self.N_cycle,
                    mode=mode,
                    inplace_safe=inplace_safe,
                    chunk_size=chunk_size,
                    N_model_seed=self.N_model_seed,
                    symmetric_permutation=None,
                )
            log_dict.update({"time": time_tracker})
        elif mode == "eval":
            if label_dict is not None:
                assert (
                    label_dict["coordinate"].size()
                    == label_full_dict["coordinate"].size()
                )
                label_dict.update(label_full_dict)

            pred_dict, label_dict, log_dict, time_tracker = self.main_inference_loop(
                input_feature_dict=input_feature_dict,
                label_dict=label_dict,
                N_cycle=self.N_cycle,
                mode=mode,
                inplace_safe=inplace_safe,
                chunk_size=chunk_size,
                N_model_seed=self.N_model_seed,
                symmetric_permutation=symmetric_permutation,
            )
            log_dict.update({"time": time_tracker})

        return pred_dict, label_dict, log_dict