# Copyright 2024 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Union

import torch
import torch.nn as nn

from protenix.model.modules.embedders import FourierEmbedding, RelativePositionEncoding
from protenix.model.modules.primitives import LinearNoBias, Transition
from protenix.model.modules.transformer import (
    AtomAttentionDecoder,
    AtomAttentionEncoder,
    DiffusionTransformer,
)
from protenix.model.utils import expand_at_dim
from protenix.openfold_local.model.primitives import LayerNorm
from protenix.openfold_local.utils.checkpointing import get_checkpoint_fn


class DiffusionConditioning(nn.Module):
    """
    Implements Algorithm 21 in AF3
    """

    def __init__(
        self,
        sigma_data: float = 16.0,
        c_z: int = 128,
        c_s: int = 384,
        c_s_inputs: int = 449,
        c_noise_embedding: int = 256,
    ) -> None:
        """
        Args:
            sigma_data (float, optional): the standard deviation of the data. Defaults to 16.0.
            c_z (int, optional): hidden dim [for pair embedding]. Defaults to 128.
            c_s (int, optional): hidden dim [for single embedding]. Defaults to 384.
            c_s_inputs (int, optional): input embedding dim from InputEmbedder. Defaults to 449.
            c_noise_embedding (int, optional): noise embedding dim. Defaults to 256.
        """
        super(DiffusionConditioning, self).__init__()
        self.sigma_data = sigma_data
        self.c_z = c_z
        self.c_s = c_s
        self.c_s_inputs = c_s_inputs

        # Line1-Line3:
        self.relpe = RelativePositionEncoding(c_z=c_z)
        self.layernorm_z = LayerNorm(2 * self.c_z)
        self.linear_no_bias_z = LinearNoBias(
            in_features=2 * self.c_z, out_features=self.c_z
        )
        # Line3-Line5:
        self.transition_z1 = Transition(c_in=self.c_z, n=2)
        self.transition_z2 = Transition(c_in=self.c_z, n=2)
        # Line6-Line7
        self.layernorm_s = LayerNorm(self.c_s + self.c_s_inputs)
        self.linear_no_bias_s = LinearNoBias(
            in_features=self.c_s + self.c_s_inputs, out_features=self.c_s
        )
        # Line8-Line9
        self.fourier_embedding = FourierEmbedding(c=c_noise_embedding)
        self.layernorm_n = LayerNorm(c_noise_embedding)
        self.linear_no_bias_n = LinearNoBias(
            in_features=c_noise_embedding, out_features=self.c_s
        )
        # Line10-Line12
        self.transition_s1 = Transition(c_in=self.c_s, n=2)
        self.transition_s2 = Transition(c_in=self.c_s, n=2)

        print(f"DiffusionConditioning: sigma_data = {self.sigma_data}")

    def forward(
        self,
        t_hat_noise_level: torch.Tensor,
        input_feature_dict: dict[str, Union[torch.Tensor, int, float, dict]],
        s_inputs: torch.Tensor,
        s_trunk: torch.Tensor,
        z_trunk: torch.Tensor,
        inplace_safe: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            t_hat_noise_level (torch.Tensor): the noise level
                [..., N_sample]
            input_feature_dict (dict[str, Union[torch.Tensor, int, float, dict]]): input meta feature dict
            s_inputs (torch.Tensor): single embedding from InputFeatureEmbedder
                [..., N_tokens, c_s_inputs]
            s_trunk (torch.Tensor): single feature embedding from PairFormer (Alg17)
                [..., N_tokens, c_s]
            z_trunk (torch.Tensor): pair feature embedding from PairFormer (Alg17)
                [..., N_tokens, N_tokens, c_z]
            inplace_safe (bool): Whether it is safe to use inplace operations. Defaults to False.
Returns: tuple[torch.Tensor, torch.Tensor]: embeddings s and z - s (torch.Tensor): [..., N_sample, N_tokens, c_s] - z (torch.Tensor): [..., N_tokens, N_tokens, c_z] """ # Pair conditioning pair_z = torch.cat( tensors=[z_trunk, self.relpe(input_feature_dict)], dim=-1 ) # [..., N_tokens, N_tokens, 2*c_z] pair_z = self.linear_no_bias_z(self.layernorm_z(pair_z)) if inplace_safe: pair_z += self.transition_z1(pair_z) pair_z += self.transition_z2(pair_z) else: pair_z = pair_z + self.transition_z1(pair_z) pair_z = pair_z + self.transition_z2(pair_z) # Single conditioning single_s = torch.cat( tensors=[s_trunk, s_inputs], dim=-1 ) # [..., N_tokens, c_s + c_s_inputs] single_s = self.linear_no_bias_s(self.layernorm_s(single_s)) noise_n = self.fourier_embedding( t_hat_noise_level=torch.log(input=t_hat_noise_level / self.sigma_data) / 4 ).to( single_s.dtype ) # [..., N_sample, c_in] single_s = single_s.unsqueeze(dim=-3) + self.linear_no_bias_n( self.layernorm_n(noise_n) ).unsqueeze( dim=-2 ) # [..., N_sample, N_tokens, c_s] if inplace_safe: single_s += self.transition_s1(single_s) single_s += self.transition_s2(single_s) else: single_s = single_s + self.transition_s1(single_s) single_s = single_s + self.transition_s2(single_s) if not self.training and pair_z.shape[-2] > 2000: torch.cuda.empty_cache() return single_s, pair_z class DiffusionSchedule: def __init__( self, sigma_data: float = 16.0, s_max: float = 160.0, s_min: float = 4e-4, p: float = 7.0, dt: float = 1 / 200, p_mean: float = -1.2, p_std: float = 1.5, ) -> None: """ Args: sigma_data (float, optional): The standard deviation of the data. Defaults to 16.0. s_max (float, optional): The maximum noise level. Defaults to 160.0. s_min (float, optional): The minimum noise level. Defaults to 4e-4. p (float, optional): The exponent for the noise schedule. Defaults to 7.0. dt (float, optional): The time step size. Defaults to 1/200. p_mean (float, optional): The mean of the log-normal distribution for noise level sampling. Defaults to -1.2. p_std (float, optional): The standard deviation of the log-normal distribution for noise level sampling. Defaults to 1.5. """ self.sigma_data = sigma_data self.s_max = s_max self.s_min = s_min self.p = p self.dt = dt self.p_mean = p_mean self.p_std = p_std # self.T self.T = int(1 / dt) + 1 # 201 def get_train_noise_schedule(self) -> torch.Tensor: return self.sigma_data * torch.exp(self.p_mean + self.p_std * torch.randn(1)) def get_inference_noise_schedule(self) -> torch.Tensor: time_step_lists = torch.arange(start=0, end=1 + 1e-10, step=self.dt) inference_noise_schedule = ( self.sigma_data * ( self.s_max ** (1 / self.p) + time_step_lists * (self.s_min ** (1 / self.p) - self.s_max ** (1 / self.p)) ) ** self.p ) return inference_noise_schedule class DiffusionModule(nn.Module): """ Implements Algorithm 20 in AF3 """ def __init__( self, sigma_data: float = 16.0, c_atom: int = 128, c_atompair: int = 16, c_token: int = 768, c_s: int = 384, c_z: int = 128, c_s_inputs: int = 449, atom_encoder: dict[str, int] = {"n_blocks": 3, "n_heads": 4}, transformer: dict[str, int] = {"n_blocks": 24, "n_heads": 16}, atom_decoder: dict[str, int] = {"n_blocks": 3, "n_heads": 4}, blocks_per_ckpt: Optional[int] = None, use_fine_grained_checkpoint: bool = False, initialization: Optional[dict[str, Union[str, float, bool]]] = None, ) -> None: """ Args: sigma_data (torch.float, optional): the standard deviation of the data. Defaults to 16.0. c_atom (int, optional): embedding dim for atom feature. Defaults to 128. 
            c_atompair (int, optional): embedding dim for atompair feature. Defaults to 16.
            c_token (int, optional): feature channel of token (single a). Defaults to 768.
            c_s (int, optional): hidden dim [for single embedding]. Defaults to 384.
            c_z (int, optional): hidden dim [for pair embedding]. Defaults to 128.
            c_s_inputs (int, optional): hidden dim [for single input embedding]. Defaults to 449.
            atom_encoder (dict[str, int], optional): configs in AtomAttentionEncoder. Defaults to {"n_blocks": 3, "n_heads": 4}.
            transformer (dict[str, int], optional): configs in DiffusionTransformer. Defaults to {"n_blocks": 24, "n_heads": 16}.
            atom_decoder (dict[str, int], optional): configs in AtomAttentionDecoder. Defaults to {"n_blocks": 3, "n_heads": 4}.
            blocks_per_ckpt: number of atom_encoder/transformer/atom_decoder blocks in each
                activation checkpoint chunk. A higher value corresponds to fewer checkpoints
                and trades memory for speed. If None, no checkpointing is performed.
            use_fine_grained_checkpoint: whether to use fine-grained checkpointing for
                finetuning stage 2; only effective if blocks_per_ckpt is not None.
            initialization: initialize the diffusion module according to the initialization config.
        """
        super(DiffusionModule, self).__init__()
        self.sigma_data = sigma_data
        self.c_atom = c_atom
        self.c_atompair = c_atompair
        self.c_token = c_token
        self.c_s_inputs = c_s_inputs
        self.c_s = c_s
        self.c_z = c_z

        # Grad checkpoint setting
        self.blocks_per_ckpt = blocks_per_ckpt
        self.use_fine_grained_checkpoint = use_fine_grained_checkpoint

        self.diffusion_conditioning = DiffusionConditioning(
            sigma_data=self.sigma_data, c_z=c_z, c_s=c_s, c_s_inputs=c_s_inputs
        )
        self.atom_attention_encoder = AtomAttentionEncoder(
            **atom_encoder,
            c_atom=c_atom,
            c_atompair=c_atompair,
            c_token=c_token,
            has_coords=True,
            c_s=c_s,
            c_z=c_z,
            blocks_per_ckpt=blocks_per_ckpt,
        )
        # Alg20: line4
        self.layernorm_s = LayerNorm(c_s)
        self.linear_no_bias_s = LinearNoBias(in_features=c_s, out_features=c_token)
        self.diffusion_transformer = DiffusionTransformer(
            **transformer,
            c_a=c_token,
            c_s=c_s,
            c_z=c_z,
            blocks_per_ckpt=blocks_per_ckpt,
        )
        self.layernorm_a = LayerNorm(c_token)
        self.atom_attention_decoder = AtomAttentionDecoder(
            **atom_decoder,
            c_token=c_token,
            c_atom=c_atom,
            c_atompair=c_atompair,
            blocks_per_ckpt=blocks_per_ckpt,
        )

        self.init_parameters(initialization)

    def init_parameters(self, initialization: dict):
        """
        Initializes the parameters of the diffusion module according to the provided
        initialization configuration.

        Args:
            initialization (dict): A dictionary containing initialization settings.
""" if initialization.get("zero_init_condition_transition", False): self.diffusion_conditioning.transition_z1.zero_init() self.diffusion_conditioning.transition_z2.zero_init() self.diffusion_conditioning.transition_s1.zero_init() self.diffusion_conditioning.transition_s2.zero_init() self.atom_attention_encoder.linear_init( zero_init_atom_encoder_residual_linear=initialization.get( "zero_init_atom_encoder_residual_linear", False ), he_normal_init_atom_encoder_small_mlp=initialization.get( "he_normal_init_atom_encoder_small_mlp", False ), he_normal_init_atom_encoder_output=initialization.get( "he_normal_init_atom_encoder_output", False ), ) if initialization.get("glorot_init_self_attention", False): for ( block ) in ( self.atom_attention_encoder.atom_transformer.diffusion_transformer.blocks ): block.attention_pair_bias.glorot_init() for block in self.diffusion_transformer.blocks: if initialization.get("zero_init_adaln", False): block.attention_pair_bias.layernorm_a.zero_init() block.conditioned_transition_block.adaln.zero_init() if initialization.get("zero_init_residual_condition_transition", False): nn.init.zeros_( block.conditioned_transition_block.linear_nobias_b.weight ) if initialization.get("zero_init_atom_decoder_linear", False): nn.init.zeros_(self.atom_attention_decoder.linear_no_bias_a.weight) if initialization.get("zero_init_dit_output", False): nn.init.zeros_(self.atom_attention_decoder.linear_no_bias_out.weight) def f_forward( self, r_noisy: torch.Tensor, t_hat_noise_level: torch.Tensor, input_feature_dict: dict[str, Union[torch.Tensor, int, float, dict]], s_inputs: torch.Tensor, s_trunk: torch.Tensor, z_trunk: torch.Tensor, inplace_safe: bool = False, chunk_size: Optional[int] = None, ) -> torch.Tensor: """The raw network to be trained. As in EDM equation (7), this is F_theta(c_in * x, c_noise(sigma)). Here, c_noise(sigma) is computed in Conditioning module. Args: r_noisy (torch.Tensor): scaled x_noisy (i.e., c_in * x) [..., N_sample, N_atom, 3] t_hat_noise_level (torch.Tensor): the noise level, as well as the time step t [..., N_sample] input_feature_dict (dict[str, Union[torch.Tensor, int, float, dict]]): input feature s_inputs (torch.Tensor): single embedding from InputFeatureEmbedder [..., N_tokens, c_s_inputs] s_trunk (torch.Tensor): single feature embedding from PairFormer (Alg17) [..., N_tokens, c_s] z_trunk (torch.Tensor): pair feature embedding from PairFormer (Alg17) [..., N_tokens, N_tokens, c_z] inplace_safe (bool): Whether it is safe to use inplace operations. Defaults to False. chunk_size (Optional[int]): Chunk size for memory-efficient operations. Defaults to None. Returns: torch.Tensor: coordinates update [..., N_sample, N_atom, 3] """ N_sample = r_noisy.size(-3) assert t_hat_noise_level.size(-1) == N_sample blocks_per_ckpt = self.blocks_per_ckpt if not torch.is_grad_enabled(): blocks_per_ckpt = None # Conditioning, shared across difference samples # Diffusion_conditioning consumes 7-8G when token num is 768, # use checkpoint here if blocks_per_ckpt is not None. 
if blocks_per_ckpt: checkpoint_fn = get_checkpoint_fn() s_single, z_pair = checkpoint_fn( self.diffusion_conditioning, t_hat_noise_level, input_feature_dict, s_inputs, s_trunk, z_trunk, inplace_safe, ) else: s_single, z_pair = self.diffusion_conditioning( t_hat_noise_level=t_hat_noise_level, input_feature_dict=input_feature_dict, s_inputs=s_inputs, s_trunk=s_trunk, z_trunk=z_trunk, inplace_safe=inplace_safe, ) # [..., N_sample, N_token, c_s], [..., N_token, N_token, c_z] # Expand embeddings to match N_sample s_trunk = expand_at_dim( s_trunk, dim=-3, n=N_sample ) # [..., N_sample, N_token, c_s] z_pair = expand_at_dim( z_pair, dim=-4, n=N_sample ) # [..., N_sample, N_token, N_token, c_z] # Fine-grained checkpoint for finetuning stage 2 (token num: 768) for avoiding OOM if blocks_per_ckpt and self.use_fine_grained_checkpoint: checkpoint_fn = get_checkpoint_fn() a_token, q_skip, c_skip, p_skip = checkpoint_fn( self.atom_attention_encoder, input_feature_dict, r_noisy, s_trunk, z_pair, inplace_safe, chunk_size, ) else: # Sequence-local Atom Attention and aggregation to coarse-grained tokens a_token, q_skip, c_skip, p_skip = self.atom_attention_encoder( input_feature_dict=input_feature_dict, r_l=r_noisy, s=s_trunk, z=z_pair, inplace_safe=inplace_safe, chunk_size=chunk_size, ) # Full self-attention on token level. if inplace_safe: a_token += self.linear_no_bias_s( self.layernorm_s(s_single) ) # [..., N_sample, N_token, c_token] else: a_token = a_token + self.linear_no_bias_s( self.layernorm_s(s_single) ) # [..., N_sample, N_token, c_token] a_token = self.diffusion_transformer( a=a_token, s=s_single, z=z_pair, inplace_safe=inplace_safe, chunk_size=chunk_size, ) a_token = self.layernorm_a(a_token) # Fine-grained checkpoint for finetuning stage 2 (token num: 768) for avoiding OOM if blocks_per_ckpt and self.use_fine_grained_checkpoint: checkpoint_fn = get_checkpoint_fn() r_update = checkpoint_fn( self.atom_attention_decoder, input_feature_dict, a_token, q_skip, c_skip, p_skip, inplace_safe, chunk_size, ) else: # Broadcast token activations to atoms and run Sequence-local Atom Attention r_update = self.atom_attention_decoder( input_feature_dict=input_feature_dict, a=a_token, q_skip=q_skip, c_skip=c_skip, p_skip=p_skip, inplace_safe=inplace_safe, chunk_size=chunk_size, ) return r_update def forward( self, x_noisy: torch.Tensor, t_hat_noise_level: torch.Tensor, input_feature_dict: dict[str, Union[torch.Tensor, int, float, dict]], s_inputs: torch.Tensor, s_trunk: torch.Tensor, z_trunk: torch.Tensor, inplace_safe: bool = False, chunk_size: Optional[int] = None, ) -> torch.Tensor: """One step denoise: x_noisy, noise_level -> x_denoised Args: x_noisy (torch.Tensor): the noisy version of the input atom coords [..., N_sample, N_atom,3] t_hat_noise_level (torch.Tensor): the noise level, as well as the time step t [..., N_sample] input_feature_dict (dict[str, Union[torch.Tensor, int, float, dict]]): input meta feature dict s_inputs (torch.Tensor): single embedding from InputFeatureEmbedder [..., N_tokens, c_s_inputs] s_trunk (torch.Tensor): single feature embedding from PairFormer (Alg17) [..., N_tokens, c_s] z_trunk (torch.Tensor): pair feature embedding from PairFormer (Alg17) [..., N_tokens, N_tokens, c_z] inplace_safe (bool): Whether it is safe to use inplace operations. Defaults to False. chunk_size (Optional[int]): Chunk size for memory-efficient operations. Defaults to None. 
        Returns:
            torch.Tensor: the denoised coordinates of x
                [..., N_sample, N_atom, 3]
        """
        # Scale positions to dimensionless vectors with approximately unit variance
        # As in EDM:
        #     r_noisy = c_in * x_noisy
        #     where c_in = 1 / sqrt(sigma_data^2 + sigma^2)
        r_noisy = (
            x_noisy
            / torch.sqrt(self.sigma_data**2 + t_hat_noise_level**2)[..., None, None]
        )

        # Compute the update given r_noisy (the scaled x_noisy)
        # As in EDM:
        #     r_update = F(r_noisy, c_noise(sigma))
        r_update = self.f_forward(
            r_noisy=r_noisy,
            t_hat_noise_level=t_hat_noise_level,
            input_feature_dict=input_feature_dict,
            s_inputs=s_inputs,
            s_trunk=s_trunk,
            z_trunk=z_trunk,
            inplace_safe=inplace_safe,
            chunk_size=chunk_size,
        )

        # Rescale updates to positions and combine with input positions
        # As in EDM:
        #     D = c_skip * x_noisy + c_out * r_update
        #     c_skip = sigma_data^2 / (sigma_data^2 + sigma^2)
        #     c_out = (sigma_data * sigma) / sqrt(sigma_data^2 + sigma^2)
        # With s_ratio = sigma / sigma_data, these simplify to:
        #     c_skip = 1 / (1 + s_ratio^2)
        #     c_out = sigma / sqrt(1 + s_ratio^2)
        s_ratio = (t_hat_noise_level / self.sigma_data)[..., None, None].to(
            r_update.dtype
        )
        x_denoised = (
            1 / (1 + s_ratio**2) * x_noisy
            + t_hat_noise_level[..., None, None] / torch.sqrt(1 + s_ratio**2) * r_update
        ).to(r_update.dtype)

        return x_denoised


class Struct_encoder(nn.Module):
    """
    Implements Algorithm 20 in AF3
    """

    def __init__(
        self,
        sigma_data: float = 16.0,
        c_atom: int = 128,
        c_atompair: int = 16,
        c_token: int = 768,
        c_s: int = 384,
        c_z: int = 128,
        c_s_inputs: int = 449,
        watermark: int = 32,
        atom_encoder: dict[str, int] = {"n_blocks": 3, "n_heads": 4},
        transformer: dict[str, int] = {"n_blocks": 6, "n_heads": 16},
        atom_decoder: dict[str, int] = {"n_blocks": 3, "n_heads": 4},
        blocks_per_ckpt: Optional[int] = None,
        use_fine_grained_checkpoint: bool = False,
        initialization: Optional[dict[str, Union[str, float, bool]]] = None,
    ) -> None:
        """
        Args:
            sigma_data (float, optional): the standard deviation of the data. Defaults to 16.0.
            c_atom (int, optional): embedding dim for atom feature. Defaults to 128.
            c_atompair (int, optional): embedding dim for atompair feature. Defaults to 16.
            c_token (int, optional): feature channel of token (single a). Defaults to 768.
            c_s (int, optional): hidden dim [for single embedding]. Defaults to 384.
            c_z (int, optional): hidden dim [for pair embedding]. Defaults to 128.
            c_s_inputs (int, optional): hidden dim [for single input embedding]. Defaults to 449.
            watermark (int, optional): watermark dimension. Defaults to 32.
            atom_encoder (dict[str, int], optional): configs in AtomAttentionEncoder. Defaults to {"n_blocks": 3, "n_heads": 4}.
            transformer (dict[str, int], optional): configs in DiffusionTransformer. Defaults to {"n_blocks": 6, "n_heads": 16}.
            atom_decoder (dict[str, int], optional): configs in AtomAttentionDecoder. Defaults to {"n_blocks": 3, "n_heads": 4}.
            blocks_per_ckpt: number of atom_encoder/transformer/atom_decoder blocks in each
                activation checkpoint chunk. A higher value corresponds to fewer checkpoints
                and trades memory for speed. If None, no checkpointing is performed.
            use_fine_grained_checkpoint: whether to use fine-grained checkpointing for
                finetuning stage 2; only effective if blocks_per_ckpt is not None.
            initialization: initialize the diffusion module according to the initialization config.
""" super(Struct_encoder, self).__init__() self.sigma_data = sigma_data self.c_atom = c_atom self.c_atompair = c_atompair self.c_token = c_token self.c_s_inputs = c_s_inputs self.c_s = c_s self.c_z = c_z self.watermark = watermark # Grad checkpoint setting self.blocks_per_ckpt = blocks_per_ckpt self.use_fine_grained_checkpoint = use_fine_grained_checkpoint self.diffusion_conditioning = DiffusionConditioning( sigma_data=self.sigma_data, c_z=c_z, c_s=c_s, c_s_inputs=c_s_inputs ) self.atom_attention_encoder = AtomAttentionEncoder( **atom_encoder, c_atom=c_atom, c_atompair=c_atompair, c_token=c_token, has_coords=True, c_s=c_s, c_z=c_z, blocks_per_ckpt=blocks_per_ckpt, ) # Alg20: line4 self.layernorm_s = LayerNorm(c_s) self.linear_no_bias_s = LinearNoBias(in_features=c_s, out_features=c_token) self.diffusion_transformer = DiffusionTransformer( **transformer, c_a=c_token, c_s=c_s, c_z=c_z, blocks_per_ckpt=blocks_per_ckpt, ) self.layernorm_a = LayerNorm(c_token) self.atom_attention_decoder = AtomAttentionDecoder( **atom_decoder, c_token=c_token, c_atom=c_atom, c_atompair=c_atompair, blocks_per_ckpt=blocks_per_ckpt, ) self.init_parameters(initialization) def init_parameters(self, initialization: dict): """ Initializes the parameters of the diffusion module according to the provided initialization configuration. Args: initialization (dict): A dictionary containing initialization settings. """ if initialization.get("zero_init_condition_transition", False): self.diffusion_conditioning.transition_z1.zero_init() self.diffusion_conditioning.transition_z2.zero_init() self.diffusion_conditioning.transition_s1.zero_init() self.diffusion_conditioning.transition_s2.zero_init() self.atom_attention_encoder.linear_init( zero_init_atom_encoder_residual_linear=initialization.get( "zero_init_atom_encoder_residual_linear", False ), he_normal_init_atom_encoder_small_mlp=initialization.get( "he_normal_init_atom_encoder_small_mlp", False ), he_normal_init_atom_encoder_output=initialization.get( "he_normal_init_atom_encoder_output", False ), ) if initialization.get("glorot_init_self_attention", False): for ( block ) in ( self.atom_attention_encoder.atom_transformer.diffusion_transformer.blocks ): block.attention_pair_bias.glorot_init() for block in self.diffusion_transformer.blocks: if initialization.get("zero_init_adaln", False): block.attention_pair_bias.layernorm_a.zero_init() block.conditioned_transition_block.adaln.zero_init() if initialization.get("zero_init_residual_condition_transition", False): nn.init.zeros_( block.conditioned_transition_block.linear_nobias_b.weight ) if initialization.get("zero_init_atom_decoder_linear", False): nn.init.zeros_(self.atom_attention_decoder.linear_no_bias_a.weight) if initialization.get("zero_init_dit_output", False): nn.init.zeros_(self.atom_attention_decoder.linear_no_bias_out.weight) def f_forward( self, r_noisy: torch.Tensor, t_hat_noise_level: torch.Tensor, input_feature_dict: dict[str, Union[torch.Tensor, int, float, dict]], s_inputs: torch.Tensor, s_trunk: torch.Tensor, z_trunk: torch.Tensor, inplace_safe: bool = False, chunk_size: Optional[int] = None, ) -> torch.Tensor: """The raw network to be trained. As in EDM equation (7), this is F_theta(c_in * x, c_noise(sigma)). Here, c_noise(sigma) is computed in Conditioning module. 
        Args:
            r_noisy (torch.Tensor): scaled x_noisy (i.e., c_in * x)
                [..., N_sample, N_atom, 3]
            t_hat_noise_level (torch.Tensor): the noise level, as well as the time step t
                [..., N_sample]
            input_feature_dict (dict[str, Union[torch.Tensor, int, float, dict]]): input feature
            s_inputs (torch.Tensor): single embedding from InputFeatureEmbedder
                [..., N_tokens, c_s_inputs]
            s_trunk (torch.Tensor): single feature embedding from PairFormer (Alg17)
                [..., N_tokens, c_s]
            z_trunk (torch.Tensor): pair feature embedding from PairFormer (Alg17)
                [..., N_tokens, N_tokens, c_z]
            inplace_safe (bool): Whether it is safe to use inplace operations. Defaults to False.
            chunk_size (Optional[int]): Chunk size for memory-efficient operations. Defaults to None.

        Returns:
            torch.Tensor: coordinates update
                [..., N_sample, N_atom, 3]
        """
        N_sample = r_noisy.size(-3)
        assert t_hat_noise_level.size(-1) == N_sample

        blocks_per_ckpt = self.blocks_per_ckpt
        if not torch.is_grad_enabled():
            blocks_per_ckpt = None

        # Conditioning, shared across different samples.
        # DiffusionConditioning consumes 7-8 GB when the token count is 768;
        # use checkpointing here if blocks_per_ckpt is not None.
        if blocks_per_ckpt:
            checkpoint_fn = get_checkpoint_fn()
            s_single, z_pair = checkpoint_fn(
                self.diffusion_conditioning,
                t_hat_noise_level,
                input_feature_dict,
                s_inputs,
                s_trunk,
                z_trunk,
                inplace_safe,
            )
        else:
            s_single, z_pair = self.diffusion_conditioning(
                t_hat_noise_level=t_hat_noise_level,
                input_feature_dict=input_feature_dict,
                s_inputs=s_inputs,
                s_trunk=s_trunk,
                z_trunk=z_trunk,
                inplace_safe=inplace_safe,
            )  # [..., N_sample, N_token, c_s], [..., N_token, N_token, c_z]

        # Expand embeddings to match N_sample
        s_trunk = expand_at_dim(
            s_trunk, dim=-3, n=N_sample
        )  # [..., N_sample, N_token, c_s]
        z_pair = expand_at_dim(
            z_pair, dim=-4, n=N_sample
        )  # [..., N_sample, N_token, N_token, c_z]

        # Fine-grained checkpoint for finetuning stage 2 (token num: 768) to avoid OOM
        if blocks_per_ckpt and self.use_fine_grained_checkpoint:
            checkpoint_fn = get_checkpoint_fn()
            a_token, q_skip, c_skip, p_skip = checkpoint_fn(
                self.atom_attention_encoder,
                input_feature_dict,
                r_noisy,
                s_trunk,
                z_pair,
                inplace_safe,
                chunk_size,
            )
        else:
            # Sequence-local Atom Attention and aggregation to coarse-grained tokens
            a_token, q_skip, c_skip, p_skip = self.atom_attention_encoder(
                input_feature_dict=input_feature_dict,
                r_l=r_noisy,
                s=s_trunk,
                z=z_pair,
                inplace_safe=inplace_safe,
                chunk_size=chunk_size,
            )

        # Full self-attention on token level.
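        # Alg20 line 4: project the noise-conditioned single embedding s_single to
        # c_token channels and add it onto the aggregated token activations before
        # the token-level DiffusionTransformer.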
if inplace_safe: a_token += self.linear_no_bias_s( self.layernorm_s(s_single) ) else: a_token = a_token + self.linear_no_bias_s( self.layernorm_s(s_single) ) a_token = self.diffusion_transformer( a=a_token, s=s_single, z=z_pair, inplace_safe=inplace_safe, chunk_size=chunk_size, ) a_token = self.layernorm_a(a_token) # Fine-grained checkpoint for finetuning stage 2 (token num: 768) for avoiding OOM if blocks_per_ckpt and self.use_fine_grained_checkpoint: checkpoint_fn = get_checkpoint_fn() r_update = checkpoint_fn( self.atom_attention_decoder, input_feature_dict, a_token, q_skip, c_skip, p_skip, inplace_safe, chunk_size, ) else: # Broadcast token activations to atoms and run Sequence-local Atom Attention r_update = self.atom_attention_decoder( input_feature_dict=input_feature_dict, a=a_token, q_skip=q_skip, c_skip=c_skip, p_skip=p_skip, inplace_safe=inplace_safe, chunk_size=chunk_size, ) return r_update def forward( self, x_noisy: torch.Tensor, t_hat_noise_level: torch.Tensor, input_feature_dict: dict[str, Union[torch.Tensor, int, float, dict]], s_inputs: torch.Tensor, s_trunk: torch.Tensor, z_trunk: torch.Tensor, inplace_safe: bool = False, chunk_size: Optional[int] = None, ) -> torch.Tensor: """One step denoise: x_noisy, noise_level -> x_denoised Args: x_noisy (torch.Tensor): the noisy version of the input atom coords [..., N_sample, N_atom,3] t_hat_noise_level (torch.Tensor): the noise level, as well as the time step t [..., N_sample] input_feature_dict (dict[str, Union[torch.Tensor, int, float, dict]]): input meta feature dict s_inputs (torch.Tensor): single embedding from InputFeatureEmbedder [..., N_tokens, c_s_inputs] s_trunk (torch.Tensor): single feature embedding from PairFormer (Alg17) [..., N_tokens, c_s] z_trunk (torch.Tensor): pair feature embedding from PairFormer (Alg17) [..., N_tokens, N_tokens, c_z] inplace_safe (bool): Whether it is safe to use inplace operations. Defaults to False. chunk_size (Optional[int]): Chunk size for memory-efficient operations. Defaults to None. 
        Returns:
            torch.Tensor: the denoised coordinates of x
                [..., N_sample, N_atom, 3]
        """
        # Scale positions to dimensionless vectors with approximately unit variance
        # As in EDM:
        #     r_noisy = c_in * x_noisy
        #     where c_in = 1 / sqrt(sigma_data^2 + sigma^2)
        r_noisy = (
            x_noisy
            / torch.sqrt(self.sigma_data**2 + t_hat_noise_level**2)[..., None, None]
        )

        # Compute the update given r_noisy (the scaled x_noisy)
        # As in EDM:
        #     r_update = F(r_noisy, c_noise(sigma))
        r_update = self.f_forward(
            r_noisy=r_noisy,
            t_hat_noise_level=t_hat_noise_level,
            input_feature_dict=input_feature_dict,
            s_inputs=s_inputs,
            s_trunk=s_trunk,
            z_trunk=z_trunk,
            inplace_safe=inplace_safe,
            chunk_size=chunk_size,
        )

        s_ratio = (t_hat_noise_level / self.sigma_data)[..., None, None].to(
            r_update.dtype
        )
        x_denoised = (
            1 / (1 + s_ratio**2) * x_noisy
            + t_hat_noise_level[..., None, None] / torch.sqrt(1 + s_ratio**2) * r_update
        ).to(r_update.dtype)

        return x_denoised


class Struct_decoder(nn.Module):
    """
    Implements Algorithm 20 in AF3
    """

    def __init__(
        self,
        sigma_data: float = 16.0,
        c_atom: int = 128,
        c_atompair: int = 16,
        c_token: int = 768,
        c_s: int = 384,
        c_z: int = 128,
        c_s_inputs: int = 449,
        watermark: int = 32,
        atom_encoder: dict[str, int] = {"n_blocks": 3, "n_heads": 4},
        transformer: dict[str, int] = {"n_blocks": 6, "n_heads": 16},
        atom_decoder: dict[str, int] = {"n_blocks": 3, "n_heads": 4},
        blocks_per_ckpt: Optional[int] = None,
        use_fine_grained_checkpoint: bool = False,
        initialization: Optional[dict[str, Union[str, float, bool]]] = None,
    ) -> None:
        """
        Args:
            sigma_data (float, optional): the standard deviation of the data. Defaults to 16.0.
            c_atom (int, optional): embedding dim for atom feature. Defaults to 128.
            c_atompair (int, optional): embedding dim for atompair feature. Defaults to 16.
            c_token (int, optional): feature channel of token (single a). Defaults to 768.
            c_s (int, optional): hidden dim [for single embedding]. Defaults to 384.
            c_z (int, optional): hidden dim [for pair embedding]. Defaults to 128.
            c_s_inputs (int, optional): hidden dim [for single input embedding]. Defaults to 449.
            watermark (int, optional): watermark dimension. Defaults to 32.
            atom_encoder (dict[str, int], optional): configs in AtomAttentionEncoder. Defaults to {"n_blocks": 3, "n_heads": 4}.
            transformer (dict[str, int], optional): configs in DiffusionTransformer. Defaults to {"n_blocks": 6, "n_heads": 16}.
            atom_decoder (dict[str, int], optional): configs in AtomAttentionDecoder. Defaults to {"n_blocks": 3, "n_heads": 4}.
            blocks_per_ckpt: number of atom_encoder/transformer/atom_decoder blocks in each
                activation checkpoint chunk. A higher value corresponds to fewer checkpoints
                and trades memory for speed. If None, no checkpointing is performed.
            use_fine_grained_checkpoint: whether to use fine-grained checkpointing for
                finetuning stage 2; only effective if blocks_per_ckpt is not None.
            initialization: initialize the diffusion module according to the initialization config.
""" super(Struct_decoder, self).__init__() self.sigma_data = sigma_data self.c_atom = c_atom self.c_atompair = c_atompair self.c_token = c_token self.c_s_inputs = c_s_inputs self.c_s = c_s self.c_z = c_z # Grad checkpoint setting self.blocks_per_ckpt = blocks_per_ckpt self.use_fine_grained_checkpoint = use_fine_grained_checkpoint self.diffusion_conditioning = DiffusionConditioning( sigma_data=self.sigma_data, c_z=c_z, c_s=c_s, c_s_inputs=c_s_inputs ) self.atom_attention_encoder = AtomAttentionEncoder( **atom_encoder, c_atom=c_atom, c_atompair=c_atompair, c_token=c_token, has_coords=True, c_s=c_s, c_z=c_z, blocks_per_ckpt=blocks_per_ckpt, ) # Alg20: line4 self.layernorm_s = LayerNorm(c_s) self.linear_no_bias_s = LinearNoBias(in_features=c_s, out_features=c_token) self.diffusion_transformer = DiffusionTransformer( **transformer, c_a=c_token, c_s=c_s, c_z=c_z, blocks_per_ckpt=blocks_per_ckpt, ) self.layernorm_a = LayerNorm(c_token) self.init_parameters(initialization) def init_parameters(self, initialization: dict): """ Initializes the parameters of the diffusion module according to the provided initialization configuration. Args: initialization (dict): A dictionary containing initialization settings. """ if initialization.get("zero_init_condition_transition", False): self.diffusion_conditioning.transition_z1.zero_init() self.diffusion_conditioning.transition_z2.zero_init() self.diffusion_conditioning.transition_s1.zero_init() self.diffusion_conditioning.transition_s2.zero_init() self.atom_attention_encoder.linear_init( zero_init_atom_encoder_residual_linear=initialization.get( "zero_init_atom_encoder_residual_linear", False ), he_normal_init_atom_encoder_small_mlp=initialization.get( "he_normal_init_atom_encoder_small_mlp", False ), he_normal_init_atom_encoder_output=initialization.get( "he_normal_init_atom_encoder_output", False ), ) if initialization.get("glorot_init_self_attention", False): for ( block ) in ( self.atom_attention_encoder.atom_transformer.diffusion_transformer.blocks ): block.attention_pair_bias.glorot_init() for block in self.diffusion_transformer.blocks: if initialization.get("zero_init_adaln", False): block.attention_pair_bias.layernorm_a.zero_init() block.conditioned_transition_block.adaln.zero_init() if initialization.get("zero_init_residual_condition_transition", False): nn.init.zeros_( block.conditioned_transition_block.linear_nobias_b.weight ) def f_forward( self, r_noisy: torch.Tensor, t_hat_noise_level: torch.Tensor, input_feature_dict: dict[str, Union[torch.Tensor, int, float, dict]], s_inputs: torch.Tensor, s_trunk: torch.Tensor, z_trunk: torch.Tensor, inplace_safe: bool = False, chunk_size: Optional[int] = None, ) -> torch.Tensor: """The raw network to be trained. As in EDM equation (7), this is F_theta(c_in * x, c_noise(sigma)). Here, c_noise(sigma) is computed in Conditioning module. Args: r_noisy (torch.Tensor): scaled x_noisy (i.e., c_in * x) [..., N_sample, N_atom, 3] t_hat_noise_level (torch.Tensor): the noise level, as well as the time step t [..., N_sample] input_feature_dict (dict[str, Union[torch.Tensor, int, float, dict]]): input feature s_inputs (torch.Tensor): single embedding from InputFeatureEmbedder [..., N_tokens, c_s_inputs] s_trunk (torch.Tensor): single feature embedding from PairFormer (Alg17) [..., N_tokens, c_s] z_trunk (torch.Tensor): pair feature embedding from PairFormer (Alg17) [..., N_tokens, N_tokens, c_z] inplace_safe (bool): Whether it is safe to use inplace operations. Defaults to False. 
            chunk_size (Optional[int]): Chunk size for memory-efficient operations. Defaults to None.

        Returns:
            torch.Tensor: token-level activations a (this class performs no
                atom-level decoding back to coordinates)
                [..., N_sample, N_token, c_token]
        """
        N_sample = r_noisy.size(-3)
        assert t_hat_noise_level.size(-1) == N_sample

        blocks_per_ckpt = self.blocks_per_ckpt
        if not torch.is_grad_enabled():
            blocks_per_ckpt = None

        # Conditioning, shared across different samples.
        # DiffusionConditioning consumes 7-8 GB when the token count is 768;
        # use checkpointing here if blocks_per_ckpt is not None.
        if blocks_per_ckpt:
            checkpoint_fn = get_checkpoint_fn()
            s_single, z_pair = checkpoint_fn(
                self.diffusion_conditioning,
                t_hat_noise_level,
                input_feature_dict,
                s_inputs,
                s_trunk,
                z_trunk,
                inplace_safe,
            )
        else:
            s_single, z_pair = self.diffusion_conditioning(
                t_hat_noise_level=t_hat_noise_level,
                input_feature_dict=input_feature_dict,
                s_inputs=s_inputs,
                s_trunk=s_trunk,
                z_trunk=z_trunk,
                inplace_safe=inplace_safe,
            )  # [..., N_sample, N_token, c_s], [..., N_token, N_token, c_z]

        # Expand embeddings to match N_sample
        s_trunk = expand_at_dim(
            s_trunk, dim=-3, n=N_sample
        )  # [..., N_sample, N_token, c_s]
        z_pair = expand_at_dim(
            z_pair, dim=-4, n=N_sample
        )  # [..., N_sample, N_token, N_token, c_z]

        # Fine-grained checkpoint for finetuning stage 2 (token num: 768) to avoid OOM
        if blocks_per_ckpt and self.use_fine_grained_checkpoint:
            checkpoint_fn = get_checkpoint_fn()
            a_token, q_skip, c_skip, p_skip = checkpoint_fn(
                self.atom_attention_encoder,
                input_feature_dict,
                r_noisy,
                s_trunk,
                z_pair,
                inplace_safe,
                chunk_size,
            )
        else:
            # Sequence-local Atom Attention and aggregation to coarse-grained tokens
            a_token, q_skip, c_skip, p_skip = self.atom_attention_encoder(
                input_feature_dict=input_feature_dict,
                r_l=r_noisy,
                s=s_trunk,
                z=z_pair,
                inplace_safe=inplace_safe,
                chunk_size=chunk_size,
            )

        # Full self-attention on token level.
        if inplace_safe:
            a_token += self.linear_no_bias_s(
                self.layernorm_s(s_single)
            )  # [..., N_sample, N_token, c_token]
        else:
            a_token = a_token + self.linear_no_bias_s(
                self.layernorm_s(s_single)
            )  # [..., N_sample, N_token, c_token]
        a_token = self.diffusion_transformer(
            a=a_token,
            s=s_single,
            z=z_pair,
            inplace_safe=inplace_safe,
            chunk_size=chunk_size,
        )
        # The final LayerNorm (self.layernorm_a) is currently disabled:
        # a_token = self.layernorm_a(a_token)

        return a_token

    def forward(
        self,
        x_noisy: torch.Tensor,
        t_hat_noise_level: torch.Tensor,
        input_feature_dict: dict[str, Union[torch.Tensor, int, float, dict]],
        s_inputs: torch.Tensor,
        s_trunk: torch.Tensor,
        z_trunk: torch.Tensor,
        inplace_safe: bool = False,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        """Encode x_noisy at a given noise level into token-level activations:
        x_noisy, noise_level -> a_token

        Args:
            x_noisy (torch.Tensor): the noisy version of the input atom coords
                [..., N_sample, N_atom, 3]
            t_hat_noise_level (torch.Tensor): the noise level, as well as the time step t
                [..., N_sample]
            input_feature_dict (dict[str, Union[torch.Tensor, int, float, dict]]): input meta feature dict
            s_inputs (torch.Tensor): single embedding from InputFeatureEmbedder
                [..., N_tokens, c_s_inputs]
            s_trunk (torch.Tensor): single feature embedding from PairFormer (Alg17)
                [..., N_tokens, c_s]
            z_trunk (torch.Tensor): pair feature embedding from PairFormer (Alg17)
                [..., N_tokens, N_tokens, c_z]
            inplace_safe (bool): Whether it is safe to use inplace operations. Defaults to False.
            chunk_size (Optional[int]): Chunk size for memory-efficient operations. Defaults to None.
        Returns:
            torch.Tensor: token-level activations a
                [..., N_sample, N_token, c_token]
        """
        # Scale positions to dimensionless vectors with approximately unit variance
        # As in EDM:
        #     r_noisy = c_in * x_noisy
        #     where c_in = 1 / sqrt(sigma_data^2 + sigma^2)
        r_noisy = (
            x_noisy
            / torch.sqrt(self.sigma_data**2 + t_hat_noise_level**2)[..., None, None]
        )

        a_token = self.f_forward(
            r_noisy=r_noisy,
            t_hat_noise_level=t_hat_noise_level,
            input_feature_dict=input_feature_dict,
            s_inputs=s_inputs,
            s_trunk=s_trunk,
            z_trunk=z_trunk,
            inplace_safe=inplace_safe,
            chunk_size=chunk_size,
        )

        return a_token
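

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the library API: it exercises only
    # DiffusionSchedule, since the modules above require full Protenix input
    # features and trunk embeddings to run. All values below are the defaults
    # documented in DiffusionSchedule.__init__.
    schedule = DiffusionSchedule()

    # Training-time noise level: sigma_data * exp(p_mean + p_std * N(0, 1)),
    # i.e., a single log-normally distributed sample.
    sigma_train = schedule.get_train_noise_schedule()
    print(f"train noise level: {sigma_train.item():.4f}")

    # Inference-time schedule: T = int(1 / dt) + 1 = 201 noise levels, decreasing
    # monotonically from sigma_data * s_max (2560.0) to sigma_data * s_min (0.0064).
    sigmas = schedule.get_inference_noise_schedule()
    print(
        f"inference schedule: {sigmas.numel()} steps, "
        f"first={sigmas[0].item():.4f}, last={sigmas[-1].item():.6f}"
    )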