# Copyright 2024 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union
import torch
import torch.nn as nn
from protenix.model.modules.embedders import FourierEmbedding, RelativePositionEncoding
from protenix.model.modules.primitives import LinearNoBias, Transition
from protenix.model.modules.transformer import (
AtomAttentionDecoder,
AtomAttentionEncoder,
DiffusionTransformer,
)
from protenix.model.utils import expand_at_dim
from protenix.openfold_local.model.primitives import LayerNorm
from protenix.openfold_local.utils.checkpointing import get_checkpoint_fn
class DiffusionConditioning(nn.Module):
"""
Implements Algorithm 21 in AF3
"""
def __init__(
self,
sigma_data: float = 16.0,
c_z: int = 128,
c_s: int = 384,
c_s_inputs: int = 449,
c_noise_embedding: int = 256,
) -> None:
"""
Args:
            sigma_data (float, optional): the standard deviation of the data. Defaults to 16.0.
c_z (int, optional): hidden dim [for pair embedding]. Defaults to 128.
c_s (int, optional): hidden dim [for single embedding]. Defaults to 384.
c_s_inputs (int, optional): input embedding dim from InputEmbedder. Defaults to 449.
c_noise_embedding (int, optional): noise embedding dim. Defaults to 256.
"""
super(DiffusionConditioning, self).__init__()
self.sigma_data = sigma_data
self.c_z = c_z
self.c_s = c_s
self.c_s_inputs = c_s_inputs
        # Alg21, lines 1-3:
self.relpe = RelativePositionEncoding(c_z=c_z)
self.layernorm_z = LayerNorm(2 * self.c_z)
self.linear_no_bias_z = LinearNoBias(
in_features=2 * self.c_z, out_features=self.c_z
)
        # Alg21, lines 3-5:
self.transition_z1 = Transition(c_in=self.c_z, n=2)
self.transition_z2 = Transition(c_in=self.c_z, n=2)
        # Alg21, lines 6-7:
self.layernorm_s = LayerNorm(self.c_s + self.c_s_inputs)
self.linear_no_bias_s = LinearNoBias(
in_features=self.c_s + self.c_s_inputs, out_features=self.c_s
)
        # Alg21, lines 8-9:
self.fourier_embedding = FourierEmbedding(c=c_noise_embedding)
self.layernorm_n = LayerNorm(c_noise_embedding)
self.linear_no_bias_n = LinearNoBias(
in_features=c_noise_embedding, out_features=self.c_s
)
        # Alg21, lines 10-12:
self.transition_s1 = Transition(c_in=self.c_s, n=2)
self.transition_s2 = Transition(c_in=self.c_s, n=2)
print(f"Diffusion Module has {self.sigma_data}")
def forward(
self,
t_hat_noise_level: torch.Tensor,
input_feature_dict: dict[str, Union[torch.Tensor, int, float, dict]],
s_inputs: torch.Tensor,
s_trunk: torch.Tensor,
z_trunk: torch.Tensor,
inplace_safe: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Args:
t_hat_noise_level (torch.Tensor): the noise level
[..., N_sample]
input_feature_dict (dict[str, Union[torch.Tensor, int, float, dict]]): input meta feature dict
s_inputs (torch.Tensor): single embedding from InputFeatureEmbedder
[..., N_tokens, c_s_inputs]
s_trunk (torch.Tensor): single feature embedding from PairFormer (Alg17)
[..., N_tokens, c_s]
z_trunk (torch.Tensor): pair feature embedding from PairFormer (Alg17)
[..., N_tokens, N_tokens, c_z]
inplace_safe (bool): Whether it is safe to use inplace operations.
Returns:
tuple[torch.Tensor, torch.Tensor]: embeddings s and z
- s (torch.Tensor): [..., N_sample, N_tokens, c_s]
- z (torch.Tensor): [..., N_tokens, N_tokens, c_z]
"""
# Pair conditioning
pair_z = torch.cat(
tensors=[z_trunk, self.relpe(input_feature_dict)], dim=-1
) # [..., N_tokens, N_tokens, 2*c_z]
pair_z = self.linear_no_bias_z(self.layernorm_z(pair_z))
if inplace_safe:
pair_z += self.transition_z1(pair_z)
pair_z += self.transition_z2(pair_z)
else:
pair_z = pair_z + self.transition_z1(pair_z)
pair_z = pair_z + self.transition_z2(pair_z)
# Single conditioning
single_s = torch.cat(
tensors=[s_trunk, s_inputs], dim=-1
) # [..., N_tokens, c_s + c_s_inputs]
single_s = self.linear_no_bias_s(self.layernorm_s(single_s))
noise_n = self.fourier_embedding(
t_hat_noise_level=torch.log(input=t_hat_noise_level / self.sigma_data) / 4
).to(
single_s.dtype
        ) # [..., N_sample, c_noise_embedding]
single_s = single_s.unsqueeze(dim=-3) + self.linear_no_bias_n(
self.layernorm_n(noise_n)
).unsqueeze(
dim=-2
) # [..., N_sample, N_tokens, c_s]
if inplace_safe:
single_s += self.transition_s1(single_s)
single_s += self.transition_s2(single_s)
else:
single_s = single_s + self.transition_s1(single_s)
single_s = single_s + self.transition_s2(single_s)
if not self.training and pair_z.shape[-2] > 2000:
torch.cuda.empty_cache()
return single_s, pair_z
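
# A minimal, self-contained sketch (illustrative only; not invoked by this module)
# of the c_noise transform applied in DiffusionConditioning.forward before
# FourierEmbedding: noise levels map to log(sigma / sigma_data) / 4, as in EDM.
def _demo_c_noise(sigma_data: float = 16.0) -> torch.Tensor:
    # Example noise levels spanning the inference range (values are arbitrary).
    t_hat = torch.tensor([4e-4 * sigma_data, 1.0, sigma_data, 160.0 * sigma_data])
    return torch.log(t_hat / sigma_data) / 4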
class DiffusionSchedule:
def __init__(
self,
sigma_data: float = 16.0,
s_max: float = 160.0,
s_min: float = 4e-4,
p: float = 7.0,
dt: float = 1 / 200,
p_mean: float = -1.2,
p_std: float = 1.5,
) -> None:
"""
Args:
sigma_data (float, optional): The standard deviation of the data. Defaults to 16.0.
s_max (float, optional): The maximum noise level. Defaults to 160.0.
s_min (float, optional): The minimum noise level. Defaults to 4e-4.
p (float, optional): The exponent for the noise schedule. Defaults to 7.0.
dt (float, optional): The time step size. Defaults to 1/200.
p_mean (float, optional): The mean of the log-normal distribution for noise level sampling. Defaults to -1.2.
p_std (float, optional): The standard deviation of the log-normal distribution for noise level sampling. Defaults to 1.5.
"""
self.sigma_data = sigma_data
self.s_max = s_max
self.s_min = s_min
self.p = p
self.dt = dt
self.p_mean = p_mean
self.p_std = p_std
        # Number of inference time steps: T = 1/dt + 1 (201 when dt = 1/200)
        self.T = int(1 / dt) + 1
    def get_train_noise_schedule(self) -> torch.Tensor:
        """Sample one training noise level: sigma_data * exp(p_mean + p_std * N(0, 1))."""
        return self.sigma_data * torch.exp(self.p_mean + self.p_std * torch.randn(1))
    def get_inference_noise_schedule(self) -> torch.Tensor:
        """Return the descending inference noise schedule of length self.T."""
        time_step_lists = torch.arange(start=0, end=1 + 1e-10, step=self.dt)
inference_noise_schedule = (
self.sigma_data
* (
self.s_max ** (1 / self.p)
+ time_step_lists
* (self.s_min ** (1 / self.p) - self.s_max ** (1 / self.p))
)
** self.p
)
return inference_noise_schedule
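
# A hedged usage sketch for DiffusionSchedule (not part of the original module):
# training draws one log-normal noise level per call, while inference uses a
# fixed, monotonically decreasing sigma sequence of length T.
def _demo_diffusion_schedule() -> None:
    schedule = DiffusionSchedule()
    sigma_train = schedule.get_train_noise_schedule()  # [1]
    assert sigma_train.item() > 0
    sigmas = schedule.get_inference_noise_schedule()  # [T] == [201]
    assert sigmas.numel() == schedule.T
    # Decreases from sigma_data * s_max (2560.0) to sigma_data * s_min (0.0064).
    assert torch.all(sigmas[1:] < sigmas[:-1])
    assert torch.isclose(sigmas[0], torch.tensor(schedule.sigma_data * schedule.s_max))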
class DiffusionModule(nn.Module):
"""
Implements Algorithm 20 in AF3
"""
def __init__(
self,
sigma_data: float = 16.0,
c_atom: int = 128,
c_atompair: int = 16,
c_token: int = 768,
c_s: int = 384,
c_z: int = 128,
c_s_inputs: int = 449,
atom_encoder: dict[str, int] = {"n_blocks": 3, "n_heads": 4},
transformer: dict[str, int] = {"n_blocks": 24, "n_heads": 16},
atom_decoder: dict[str, int] = {"n_blocks": 3, "n_heads": 4},
blocks_per_ckpt: Optional[int] = None,
use_fine_grained_checkpoint: bool = False,
initialization: Optional[dict[str, Union[str, float, bool]]] = None,
) -> None:
"""
Args:
            sigma_data (float, optional): the standard deviation of the data. Defaults to 16.0.
c_atom (int, optional): embedding dim for atom feature. Defaults to 128.
c_atompair (int, optional): embedding dim for atompair feature. Defaults to 16.
c_token (int, optional): feature channel of token (single a). Defaults to 768.
c_s (int, optional): hidden dim [for single embedding]. Defaults to 384.
c_z (int, optional): hidden dim [for pair embedding]. Defaults to 128.
c_s_inputs (int, optional): hidden dim [for single input embedding]. Defaults to 449.
atom_encoder (dict[str, int], optional): configs in AtomAttentionEncoder. Defaults to {"n_blocks": 3, "n_heads": 4}.
transformer (dict[str, int], optional): configs in DiffusionTransformer. Defaults to {"n_blocks": 24, "n_heads": 16}.
atom_decoder (dict[str, int], optional): configs in AtomAttentionDecoder. Defaults to {"n_blocks": 3, "n_heads": 4}.
            blocks_per_ckpt: number of atom_encoder/transformer/atom_decoder blocks in each
                activation checkpoint. A higher value corresponds to fewer checkpoints and
                trades memory for speed. If None, no checkpointing is performed.
            use_fine_grained_checkpoint: whether to use fine-grained checkpointing for
                finetuning stage 2; only effective if blocks_per_ckpt is not None.
            initialization: initialization config for the diffusion module; None disables
                all optional re-initialization.
"""
super(DiffusionModule, self).__init__()
self.sigma_data = sigma_data
self.c_atom = c_atom
self.c_atompair = c_atompair
self.c_token = c_token
self.c_s_inputs = c_s_inputs
self.c_s = c_s
self.c_z = c_z
# Grad checkpoint setting
self.blocks_per_ckpt = blocks_per_ckpt
self.use_fine_grained_checkpoint = use_fine_grained_checkpoint
self.diffusion_conditioning = DiffusionConditioning(
sigma_data=self.sigma_data, c_z=c_z, c_s=c_s, c_s_inputs=c_s_inputs
)
self.atom_attention_encoder = AtomAttentionEncoder(
**atom_encoder,
c_atom=c_atom,
c_atompair=c_atompair,
c_token=c_token,
has_coords=True,
c_s=c_s,
c_z=c_z,
blocks_per_ckpt=blocks_per_ckpt,
)
# Alg20: line4
self.layernorm_s = LayerNorm(c_s)
self.linear_no_bias_s = LinearNoBias(in_features=c_s, out_features=c_token)
self.diffusion_transformer = DiffusionTransformer(
**transformer,
c_a=c_token,
c_s=c_s,
c_z=c_z,
blocks_per_ckpt=blocks_per_ckpt,
)
self.layernorm_a = LayerNorm(c_token)
self.atom_attention_decoder = AtomAttentionDecoder(
**atom_decoder,
c_token=c_token,
c_atom=c_atom,
c_atompair=c_atompair,
blocks_per_ckpt=blocks_per_ckpt,
)
self.init_parameters(initialization)
    def init_parameters(self, initialization: Optional[dict]):
        """
        Initializes the parameters of the diffusion module according to the provided initialization configuration.
        Args:
            initialization (Optional[dict]): A dictionary containing initialization settings; None is treated as empty.
        """
        if initialization is None:
            initialization = {}
        if initialization.get("zero_init_condition_transition", False):
self.diffusion_conditioning.transition_z1.zero_init()
self.diffusion_conditioning.transition_z2.zero_init()
self.diffusion_conditioning.transition_s1.zero_init()
self.diffusion_conditioning.transition_s2.zero_init()
self.atom_attention_encoder.linear_init(
zero_init_atom_encoder_residual_linear=initialization.get(
"zero_init_atom_encoder_residual_linear", False
),
he_normal_init_atom_encoder_small_mlp=initialization.get(
"he_normal_init_atom_encoder_small_mlp", False
),
he_normal_init_atom_encoder_output=initialization.get(
"he_normal_init_atom_encoder_output", False
),
)
if initialization.get("glorot_init_self_attention", False):
for (
block
) in (
self.atom_attention_encoder.atom_transformer.diffusion_transformer.blocks
):
block.attention_pair_bias.glorot_init()
for block in self.diffusion_transformer.blocks:
if initialization.get("zero_init_adaln", False):
block.attention_pair_bias.layernorm_a.zero_init()
block.conditioned_transition_block.adaln.zero_init()
if initialization.get("zero_init_residual_condition_transition", False):
nn.init.zeros_(
block.conditioned_transition_block.linear_nobias_b.weight
)
if initialization.get("zero_init_atom_decoder_linear", False):
nn.init.zeros_(self.atom_attention_decoder.linear_no_bias_a.weight)
if initialization.get("zero_init_dit_output", False):
nn.init.zeros_(self.atom_attention_decoder.linear_no_bias_out.weight)
def f_forward(
self,
r_noisy: torch.Tensor,
t_hat_noise_level: torch.Tensor,
input_feature_dict: dict[str, Union[torch.Tensor, int, float, dict]],
s_inputs: torch.Tensor,
s_trunk: torch.Tensor,
z_trunk: torch.Tensor,
inplace_safe: bool = False,
chunk_size: Optional[int] = None,
) -> torch.Tensor:
"""The raw network to be trained.
As in EDM equation (7), this is F_theta(c_in * x, c_noise(sigma)).
Here, c_noise(sigma) is computed in Conditioning module.
Args:
r_noisy (torch.Tensor): scaled x_noisy (i.e., c_in * x)
[..., N_sample, N_atom, 3]
t_hat_noise_level (torch.Tensor): the noise level, as well as the time step t
[..., N_sample]
input_feature_dict (dict[str, Union[torch.Tensor, int, float, dict]]): input feature
s_inputs (torch.Tensor): single embedding from InputFeatureEmbedder
[..., N_tokens, c_s_inputs]
s_trunk (torch.Tensor): single feature embedding from PairFormer (Alg17)
[..., N_tokens, c_s]
z_trunk (torch.Tensor): pair feature embedding from PairFormer (Alg17)
[..., N_tokens, N_tokens, c_z]
inplace_safe (bool): Whether it is safe to use inplace operations. Defaults to False.
chunk_size (Optional[int]): Chunk size for memory-efficient operations. Defaults to None.
Returns:
torch.Tensor: coordinates update
[..., N_sample, N_atom, 3]
"""
N_sample = r_noisy.size(-3)
assert t_hat_noise_level.size(-1) == N_sample
blocks_per_ckpt = self.blocks_per_ckpt
if not torch.is_grad_enabled():
blocks_per_ckpt = None
        # Conditioning, shared across different samples.
        # diffusion_conditioning consumes 7-8 GB when the token count is 768,
        # so checkpoint it here if blocks_per_ckpt is not None.
if blocks_per_ckpt:
checkpoint_fn = get_checkpoint_fn()
s_single, z_pair = checkpoint_fn(
self.diffusion_conditioning,
t_hat_noise_level,
input_feature_dict,
s_inputs,
s_trunk,
z_trunk,
inplace_safe,
)
else:
s_single, z_pair = self.diffusion_conditioning(
t_hat_noise_level=t_hat_noise_level,
input_feature_dict=input_feature_dict,
s_inputs=s_inputs,
s_trunk=s_trunk,
z_trunk=z_trunk,
inplace_safe=inplace_safe,
) # [..., N_sample, N_token, c_s], [..., N_token, N_token, c_z]
# Expand embeddings to match N_sample
s_trunk = expand_at_dim(
s_trunk, dim=-3, n=N_sample
) # [..., N_sample, N_token, c_s]
z_pair = expand_at_dim(
z_pair, dim=-4, n=N_sample
) # [..., N_sample, N_token, N_token, c_z]
        # Fine-grained checkpoint for finetuning stage 2 (token count 768) to avoid OOM
if blocks_per_ckpt and self.use_fine_grained_checkpoint:
checkpoint_fn = get_checkpoint_fn()
a_token, q_skip, c_skip, p_skip = checkpoint_fn(
self.atom_attention_encoder,
input_feature_dict,
r_noisy,
s_trunk,
z_pair,
inplace_safe,
chunk_size,
)
else:
# Sequence-local Atom Attention and aggregation to coarse-grained tokens
a_token, q_skip, c_skip, p_skip = self.atom_attention_encoder(
input_feature_dict=input_feature_dict,
r_l=r_noisy,
s=s_trunk,
z=z_pair,
inplace_safe=inplace_safe,
chunk_size=chunk_size,
)
# Full self-attention on token level.
if inplace_safe:
a_token += self.linear_no_bias_s(
self.layernorm_s(s_single)
) # [..., N_sample, N_token, c_token]
else:
a_token = a_token + self.linear_no_bias_s(
self.layernorm_s(s_single)
) # [..., N_sample, N_token, c_token]
a_token = self.diffusion_transformer(
a=a_token,
s=s_single,
z=z_pair,
inplace_safe=inplace_safe,
chunk_size=chunk_size,
)
a_token = self.layernorm_a(a_token)
        # Fine-grained checkpoint for finetuning stage 2 (token count 768) to avoid OOM
if blocks_per_ckpt and self.use_fine_grained_checkpoint:
checkpoint_fn = get_checkpoint_fn()
r_update = checkpoint_fn(
self.atom_attention_decoder,
input_feature_dict,
a_token,
q_skip,
c_skip,
p_skip,
inplace_safe,
chunk_size,
)
else:
# Broadcast token activations to atoms and run Sequence-local Atom Attention
r_update = self.atom_attention_decoder(
input_feature_dict=input_feature_dict,
a=a_token,
q_skip=q_skip,
c_skip=c_skip,
p_skip=p_skip,
inplace_safe=inplace_safe,
chunk_size=chunk_size,
)
return r_update
def forward(
self,
x_noisy: torch.Tensor,
t_hat_noise_level: torch.Tensor,
input_feature_dict: dict[str, Union[torch.Tensor, int, float, dict]],
s_inputs: torch.Tensor,
s_trunk: torch.Tensor,
z_trunk: torch.Tensor,
inplace_safe: bool = False,
chunk_size: Optional[int] = None,
) -> torch.Tensor:
"""One step denoise: x_noisy, noise_level -> x_denoised
Args:
x_noisy (torch.Tensor): the noisy version of the input atom coords
                [..., N_sample, N_atom, 3]
t_hat_noise_level (torch.Tensor): the noise level, as well as the time step t
[..., N_sample]
input_feature_dict (dict[str, Union[torch.Tensor, int, float, dict]]): input meta feature dict
s_inputs (torch.Tensor): single embedding from InputFeatureEmbedder
[..., N_tokens, c_s_inputs]
s_trunk (torch.Tensor): single feature embedding from PairFormer (Alg17)
[..., N_tokens, c_s]
z_trunk (torch.Tensor): pair feature embedding from PairFormer (Alg17)
[..., N_tokens, N_tokens, c_z]
inplace_safe (bool): Whether it is safe to use inplace operations. Defaults to False.
chunk_size (Optional[int]): Chunk size for memory-efficient operations. Defaults to None.
Returns:
torch.Tensor: the denoised coordinates of x
                [..., N_sample, N_atom, 3]
"""
# Scale positions to dimensionless vectors with approximately unit variance
# As in EDM:
# r_noisy = (c_in * x_noisy)
# where c_in = 1 / sqrt(sigma_data^2 + sigma^2)
r_noisy = (
x_noisy
/ torch.sqrt(self.sigma_data**2 + t_hat_noise_level**2)[..., None, None]
)
# Compute the update given r_noisy (the scaled x_noisy)
# As in EDM:
# r_update = F(r_noisy, c_noise(sigma))
r_update = self.f_forward(
r_noisy=r_noisy,
t_hat_noise_level=t_hat_noise_level,
input_feature_dict=input_feature_dict,
s_inputs=s_inputs,
s_trunk=s_trunk,
z_trunk=z_trunk,
inplace_safe=inplace_safe,
chunk_size=chunk_size,
)
# Rescale updates to positions and combine with input positions
# As in EDM:
# D = c_skip * x_noisy + c_out * r_update
# c_skip = sigma_data^2 / (sigma_data^2 + sigma^2)
# c_out = (sigma_data * sigma) / sqrt(sigma_data^2 + sigma^2)
# s_ratio = sigma / sigma_data
# c_skip = 1 / (1 + s_ratio^2)
# c_out = sigma / sqrt(1 + s_ratio^2)
s_ratio = (t_hat_noise_level / self.sigma_data)[..., None, None].to(
r_update.dtype
)
x_denoised = (
1 / (1 + s_ratio**2) * x_noisy
+ t_hat_noise_level[..., None, None] / torch.sqrt(1 + s_ratio**2) * r_update
).to(r_update.dtype)
return x_denoised
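
# A self-contained sanity sketch (not called anywhere) verifying that the two
# algebraic forms of the EDM scalings used in DiffusionModule.forward agree:
# c_skip = sigma_data^2 / (sigma_data^2 + sigma^2) = 1 / (1 + s_ratio^2)
# c_out = sigma_data * sigma / sqrt(sigma_data^2 + sigma^2) = sigma / sqrt(1 + s_ratio^2)
def _check_edm_scalings(sigma_data: float = 16.0) -> None:
    sigma = torch.logspace(-3, 3, steps=13)
    c_skip = sigma_data**2 / (sigma_data**2 + sigma**2)
    c_out = sigma_data * sigma / torch.sqrt(sigma_data**2 + sigma**2)
    s_ratio = sigma / sigma_data
    assert torch.allclose(c_skip, 1 / (1 + s_ratio**2))
    assert torch.allclose(c_out, sigma / torch.sqrt(1 + s_ratio**2))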
class Struct_encoder(nn.Module):
    """
    Structure encoder: a variant of the AF3 diffusion module (Algorithm 20) with a
    smaller DiffusionTransformer and an additional watermark dimension parameter.
    """
def __init__(
self,
sigma_data: float = 16.0,
c_atom: int = 128,
c_atompair: int = 16,
c_token: int = 768,
c_s: int = 384,
c_z: int = 128,
c_s_inputs: int = 449,
watermark: int = 32,
atom_encoder: dict[str, int] = {"n_blocks": 3, "n_heads": 4},
transformer: dict[str, int] = {"n_blocks": 6, "n_heads": 16},
atom_decoder: dict[str, int] = {"n_blocks": 3, "n_heads": 4},
blocks_per_ckpt: Optional[int] = None,
use_fine_grained_checkpoint: bool = False,
initialization: Optional[dict[str, Union[str, float, bool]]] = None,
) -> None:
"""
Args:
            sigma_data (float, optional): the standard deviation of the data. Defaults to 16.0.
c_atom (int, optional): embedding dim for atom feature. Defaults to 128.
c_atompair (int, optional): embedding dim for atompair feature. Defaults to 16.
c_token (int, optional): feature channel of token (single a). Defaults to 768.
c_s (int, optional): hidden dim [for single embedding]. Defaults to 384.
c_z (int, optional): hidden dim [for pair embedding]. Defaults to 128.
            c_s_inputs (int, optional): hidden dim [for single input embedding]. Defaults to 449.
            watermark (int, optional): watermark dimension; stored but not used within this module. Defaults to 32.
            atom_encoder (dict[str, int], optional): configs in AtomAttentionEncoder. Defaults to {"n_blocks": 3, "n_heads": 4}.
            transformer (dict[str, int], optional): configs in DiffusionTransformer. Defaults to {"n_blocks": 6, "n_heads": 16}.
atom_decoder (dict[str, int], optional): configs in AtomAttentionDecoder. Defaults to {"n_blocks": 3, "n_heads": 4}.
            blocks_per_ckpt: number of atom_encoder/transformer/atom_decoder blocks in each
                activation checkpoint. A higher value corresponds to fewer checkpoints and
                trades memory for speed. If None, no checkpointing is performed.
            use_fine_grained_checkpoint: whether to use fine-grained checkpointing for
                finetuning stage 2; only effective if blocks_per_ckpt is not None.
            initialization: initialization config for the diffusion module; None disables
                all optional re-initialization.
"""
super(Struct_encoder, self).__init__()
self.sigma_data = sigma_data
self.c_atom = c_atom
self.c_atompair = c_atompair
self.c_token = c_token
self.c_s_inputs = c_s_inputs
self.c_s = c_s
self.c_z = c_z
self.watermark = watermark
# Grad checkpoint setting
self.blocks_per_ckpt = blocks_per_ckpt
self.use_fine_grained_checkpoint = use_fine_grained_checkpoint
self.diffusion_conditioning = DiffusionConditioning(
sigma_data=self.sigma_data, c_z=c_z, c_s=c_s, c_s_inputs=c_s_inputs
)
self.atom_attention_encoder = AtomAttentionEncoder(
**atom_encoder,
c_atom=c_atom,
c_atompair=c_atompair,
c_token=c_token,
has_coords=True,
c_s=c_s,
c_z=c_z,
blocks_per_ckpt=blocks_per_ckpt,
)
# Alg20: line4
self.layernorm_s = LayerNorm(c_s)
self.linear_no_bias_s = LinearNoBias(in_features=c_s, out_features=c_token)
self.diffusion_transformer = DiffusionTransformer(
**transformer,
c_a=c_token,
c_s=c_s,
c_z=c_z,
blocks_per_ckpt=blocks_per_ckpt,
)
self.layernorm_a = LayerNorm(c_token)
self.atom_attention_decoder = AtomAttentionDecoder(
**atom_decoder,
c_token=c_token,
c_atom=c_atom,
c_atompair=c_atompair,
blocks_per_ckpt=blocks_per_ckpt,
)
self.init_parameters(initialization)
    def init_parameters(self, initialization: Optional[dict]):
        """
        Initializes the parameters of the diffusion module according to the provided initialization configuration.
        Args:
            initialization (Optional[dict]): A dictionary containing initialization settings; None is treated as empty.
        """
        if initialization is None:
            initialization = {}
        if initialization.get("zero_init_condition_transition", False):
self.diffusion_conditioning.transition_z1.zero_init()
self.diffusion_conditioning.transition_z2.zero_init()
self.diffusion_conditioning.transition_s1.zero_init()
self.diffusion_conditioning.transition_s2.zero_init()
self.atom_attention_encoder.linear_init(
zero_init_atom_encoder_residual_linear=initialization.get(
"zero_init_atom_encoder_residual_linear", False
),
he_normal_init_atom_encoder_small_mlp=initialization.get(
"he_normal_init_atom_encoder_small_mlp", False
),
he_normal_init_atom_encoder_output=initialization.get(
"he_normal_init_atom_encoder_output", False
),
)
if initialization.get("glorot_init_self_attention", False):
for (
block
) in (
self.atom_attention_encoder.atom_transformer.diffusion_transformer.blocks
):
block.attention_pair_bias.glorot_init()
for block in self.diffusion_transformer.blocks:
if initialization.get("zero_init_adaln", False):
block.attention_pair_bias.layernorm_a.zero_init()
block.conditioned_transition_block.adaln.zero_init()
if initialization.get("zero_init_residual_condition_transition", False):
nn.init.zeros_(
block.conditioned_transition_block.linear_nobias_b.weight
)
if initialization.get("zero_init_atom_decoder_linear", False):
nn.init.zeros_(self.atom_attention_decoder.linear_no_bias_a.weight)
if initialization.get("zero_init_dit_output", False):
nn.init.zeros_(self.atom_attention_decoder.linear_no_bias_out.weight)
def f_forward(
self,
r_noisy: torch.Tensor,
t_hat_noise_level: torch.Tensor,
input_feature_dict: dict[str, Union[torch.Tensor, int, float, dict]],
s_inputs: torch.Tensor,
s_trunk: torch.Tensor,
z_trunk: torch.Tensor,
inplace_safe: bool = False,
chunk_size: Optional[int] = None,
) -> torch.Tensor:
"""The raw network to be trained.
As in EDM equation (7), this is F_theta(c_in * x, c_noise(sigma)).
Here, c_noise(sigma) is computed in Conditioning module.
Args:
r_noisy (torch.Tensor): scaled x_noisy (i.e., c_in * x)
[..., N_sample, N_atom, 3]
t_hat_noise_level (torch.Tensor): the noise level, as well as the time step t
[..., N_sample]
input_feature_dict (dict[str, Union[torch.Tensor, int, float, dict]]): input feature
s_inputs (torch.Tensor): single embedding from InputFeatureEmbedder
[..., N_tokens, c_s_inputs]
s_trunk (torch.Tensor): single feature embedding from PairFormer (Alg17)
[..., N_tokens, c_s]
z_trunk (torch.Tensor): pair feature embedding from PairFormer (Alg17)
[..., N_tokens, N_tokens, c_z]
inplace_safe (bool): Whether it is safe to use inplace operations. Defaults to False.
chunk_size (Optional[int]): Chunk size for memory-efficient operations. Defaults to None.
Returns:
torch.Tensor: coordinates update
[..., N_sample, N_atom, 3]
"""
N_sample = r_noisy.size(-3)
assert t_hat_noise_level.size(-1) == N_sample
blocks_per_ckpt = self.blocks_per_ckpt
if not torch.is_grad_enabled():
blocks_per_ckpt = None
        # Conditioning, shared across different samples.
        # diffusion_conditioning consumes 7-8 GB when the token count is 768,
        # so checkpoint it here if blocks_per_ckpt is not None.
if blocks_per_ckpt:
checkpoint_fn = get_checkpoint_fn()
s_single, z_pair = checkpoint_fn(
self.diffusion_conditioning,
t_hat_noise_level,
input_feature_dict,
s_inputs,
s_trunk,
z_trunk,
inplace_safe,
)
else:
s_single, z_pair = self.diffusion_conditioning(
t_hat_noise_level=t_hat_noise_level,
input_feature_dict=input_feature_dict,
s_inputs=s_inputs,
s_trunk=s_trunk,
z_trunk=z_trunk,
inplace_safe=inplace_safe,
) # [..., N_sample, N_token, c_s], [..., N_token, N_token, c_z]
# Expand embeddings to match N_sample
s_trunk = expand_at_dim(
s_trunk, dim=-3, n=N_sample
) # [..., N_sample, N_token, c_s]
z_pair = expand_at_dim(
z_pair, dim=-4, n=N_sample
) # [..., N_sample, N_token, N_token, c_z]
        # Fine-grained checkpoint for finetuning stage 2 (token count 768) to avoid OOM
if blocks_per_ckpt and self.use_fine_grained_checkpoint:
checkpoint_fn = get_checkpoint_fn()
a_token, q_skip, c_skip, p_skip = checkpoint_fn(
self.atom_attention_encoder,
input_feature_dict,
r_noisy,
s_trunk,
z_pair,
inplace_safe,
chunk_size,
)
else:
# Sequence-local Atom Attention and aggregation to coarse-grained tokens
a_token, q_skip, c_skip, p_skip = self.atom_attention_encoder(
input_feature_dict=input_feature_dict,
r_l=r_noisy,
s=s_trunk,
z=z_pair,
inplace_safe=inplace_safe,
chunk_size=chunk_size,
)
# Full self-attention on token level.
if inplace_safe:
a_token += self.linear_no_bias_s(
self.layernorm_s(s_single)
)
else:
a_token = a_token + self.linear_no_bias_s(
self.layernorm_s(s_single)
)
a_token = self.diffusion_transformer(
a=a_token,
s=s_single,
z=z_pair,
inplace_safe=inplace_safe,
chunk_size=chunk_size,
)
a_token = self.layernorm_a(a_token)
        # Fine-grained checkpoint for finetuning stage 2 (token count 768) to avoid OOM
if blocks_per_ckpt and self.use_fine_grained_checkpoint:
checkpoint_fn = get_checkpoint_fn()
r_update = checkpoint_fn(
self.atom_attention_decoder,
input_feature_dict,
a_token,
q_skip,
c_skip,
p_skip,
inplace_safe,
chunk_size,
)
else:
# Broadcast token activations to atoms and run Sequence-local Atom Attention
r_update = self.atom_attention_decoder(
input_feature_dict=input_feature_dict,
a=a_token,
q_skip=q_skip,
c_skip=c_skip,
p_skip=p_skip,
inplace_safe=inplace_safe,
chunk_size=chunk_size,
)
return r_update
def forward(
self,
x_noisy: torch.Tensor,
t_hat_noise_level: torch.Tensor,
input_feature_dict: dict[str, Union[torch.Tensor, int, float, dict]],
s_inputs: torch.Tensor,
s_trunk: torch.Tensor,
z_trunk: torch.Tensor,
inplace_safe: bool = False,
chunk_size: Optional[int] = None,
) -> torch.Tensor:
"""One step denoise: x_noisy, noise_level -> x_denoised
Args:
x_noisy (torch.Tensor): the noisy version of the input atom coords
                [..., N_sample, N_atom, 3]
t_hat_noise_level (torch.Tensor): the noise level, as well as the time step t
[..., N_sample]
input_feature_dict (dict[str, Union[torch.Tensor, int, float, dict]]): input meta feature dict
s_inputs (torch.Tensor): single embedding from InputFeatureEmbedder
[..., N_tokens, c_s_inputs]
s_trunk (torch.Tensor): single feature embedding from PairFormer (Alg17)
[..., N_tokens, c_s]
z_trunk (torch.Tensor): pair feature embedding from PairFormer (Alg17)
[..., N_tokens, N_tokens, c_z]
inplace_safe (bool): Whether it is safe to use inplace operations. Defaults to False.
chunk_size (Optional[int]): Chunk size for memory-efficient operations. Defaults to None.
Returns:
torch.Tensor: the denoised coordinates of x
                [..., N_sample, N_atom, 3]
"""
# Scale positions to dimensionless vectors with approximately unit variance
# As in EDM:
# r_noisy = (c_in * x_noisy)
# where c_in = 1 / sqrt(sigma_data^2 + sigma^2)
r_noisy = (
x_noisy
/ torch.sqrt(self.sigma_data**2 + t_hat_noise_level**2)[..., None, None]
)
# Compute the update given r_noisy (the scaled x_noisy)
# As in EDM:
# r_update = F(r_noisy, c_noise(sigma))
r_update = self.f_forward(
r_noisy=r_noisy,
t_hat_noise_level=t_hat_noise_level,
input_feature_dict=input_feature_dict,
s_inputs=s_inputs,
s_trunk=s_trunk,
z_trunk=z_trunk,
inplace_safe=inplace_safe,
chunk_size=chunk_size,
)
s_ratio = (t_hat_noise_level / self.sigma_data)[..., None, None].to(
r_update.dtype
)
x_denoised = (
1 / (1 + s_ratio**2) * x_noisy
+ t_hat_noise_level[..., None, None] / torch.sqrt(1 + s_ratio**2) * r_update
).to(r_update.dtype)
return x_denoised
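
# A sketch of an `initialization` config consistent with the keys read by the
# init_parameters methods above; every flag defaults to False when absent, so an
# empty dict (or None, given the guard above) is also valid.
_EXAMPLE_INITIALIZATION: dict[str, bool] = {
    "zero_init_condition_transition": True,
    "zero_init_atom_encoder_residual_linear": True,
    "he_normal_init_atom_encoder_small_mlp": True,
    "he_normal_init_atom_encoder_output": True,
    "glorot_init_self_attention": True,
    "zero_init_adaln": True,
    "zero_init_residual_condition_transition": True,
    "zero_init_atom_decoder_linear": True,
    "zero_init_dit_output": True,
}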
class Struct_decoder(nn.Module):
    """
    Structure decoder: a variant of the AF3 diffusion module (Algorithm 20) that omits
    the atom decoder and returns token-level activations instead of coordinate updates.
    """
def __init__(
self,
sigma_data: float = 16.0,
c_atom: int = 128,
c_atompair: int = 16,
c_token: int = 768,
c_s: int = 384,
c_z: int = 128,
c_s_inputs: int = 449,
watermark: int = 32,
atom_encoder: dict[str, int] = {"n_blocks": 3, "n_heads": 4},
transformer: dict[str, int] = {"n_blocks": 6, "n_heads": 16},
atom_decoder: dict[str, int] = {"n_blocks": 3, "n_heads": 4},
blocks_per_ckpt: Optional[int] = None,
use_fine_grained_checkpoint: bool = False,
initialization: Optional[dict[str, Union[str, float, bool]]] = None,
) -> None:
"""
Args:
            sigma_data (float, optional): the standard deviation of the data. Defaults to 16.0.
c_atom (int, optional): embedding dim for atom feature. Defaults to 128.
c_atompair (int, optional): embedding dim for atompair feature. Defaults to 16.
c_token (int, optional): feature channel of token (single a). Defaults to 768.
c_s (int, optional): hidden dim [for single embedding]. Defaults to 384.
c_z (int, optional): hidden dim [for pair embedding]. Defaults to 128.
            c_s_inputs (int, optional): hidden dim [for single input embedding]. Defaults to 449.
            watermark (int, optional): watermark dimension; accepted but unused in this module. Defaults to 32.
            atom_encoder (dict[str, int], optional): configs in AtomAttentionEncoder. Defaults to {"n_blocks": 3, "n_heads": 4}.
            transformer (dict[str, int], optional): configs in DiffusionTransformer. Defaults to {"n_blocks": 6, "n_heads": 16}.
atom_decoder (dict[str, int], optional): configs in AtomAttentionDecoder. Defaults to {"n_blocks": 3, "n_heads": 4}.
            blocks_per_ckpt: number of atom_encoder/transformer/atom_decoder blocks in each
                activation checkpoint. A higher value corresponds to fewer checkpoints and
                trades memory for speed. If None, no checkpointing is performed.
            use_fine_grained_checkpoint: whether to use fine-grained checkpointing for
                finetuning stage 2; only effective if blocks_per_ckpt is not None.
            initialization: initialization config for the diffusion module; None disables
                all optional re-initialization.
"""
super(Struct_decoder, self).__init__()
self.sigma_data = sigma_data
self.c_atom = c_atom
self.c_atompair = c_atompair
self.c_token = c_token
self.c_s_inputs = c_s_inputs
self.c_s = c_s
self.c_z = c_z
# Grad checkpoint setting
self.blocks_per_ckpt = blocks_per_ckpt
self.use_fine_grained_checkpoint = use_fine_grained_checkpoint
self.diffusion_conditioning = DiffusionConditioning(
sigma_data=self.sigma_data, c_z=c_z, c_s=c_s, c_s_inputs=c_s_inputs
)
self.atom_attention_encoder = AtomAttentionEncoder(
**atom_encoder,
c_atom=c_atom,
c_atompair=c_atompair,
c_token=c_token,
has_coords=True,
c_s=c_s,
c_z=c_z,
blocks_per_ckpt=blocks_per_ckpt,
)
# Alg20: line4
self.layernorm_s = LayerNorm(c_s)
self.linear_no_bias_s = LinearNoBias(in_features=c_s, out_features=c_token)
self.diffusion_transformer = DiffusionTransformer(
**transformer,
c_a=c_token,
c_s=c_s,
c_z=c_z,
blocks_per_ckpt=blocks_per_ckpt,
)
self.layernorm_a = LayerNorm(c_token)
self.init_parameters(initialization)
    def init_parameters(self, initialization: Optional[dict]):
        """
        Initializes the parameters of the diffusion module according to the provided initialization configuration.
        Args:
            initialization (Optional[dict]): A dictionary containing initialization settings; None is treated as empty.
        """
        if initialization is None:
            initialization = {}
        if initialization.get("zero_init_condition_transition", False):
self.diffusion_conditioning.transition_z1.zero_init()
self.diffusion_conditioning.transition_z2.zero_init()
self.diffusion_conditioning.transition_s1.zero_init()
self.diffusion_conditioning.transition_s2.zero_init()
self.atom_attention_encoder.linear_init(
zero_init_atom_encoder_residual_linear=initialization.get(
"zero_init_atom_encoder_residual_linear", False
),
he_normal_init_atom_encoder_small_mlp=initialization.get(
"he_normal_init_atom_encoder_small_mlp", False
),
he_normal_init_atom_encoder_output=initialization.get(
"he_normal_init_atom_encoder_output", False
),
)
if initialization.get("glorot_init_self_attention", False):
for (
block
) in (
self.atom_attention_encoder.atom_transformer.diffusion_transformer.blocks
):
block.attention_pair_bias.glorot_init()
for block in self.diffusion_transformer.blocks:
if initialization.get("zero_init_adaln", False):
block.attention_pair_bias.layernorm_a.zero_init()
block.conditioned_transition_block.adaln.zero_init()
if initialization.get("zero_init_residual_condition_transition", False):
nn.init.zeros_(
block.conditioned_transition_block.linear_nobias_b.weight
)
def f_forward(
self,
r_noisy: torch.Tensor,
t_hat_noise_level: torch.Tensor,
input_feature_dict: dict[str, Union[torch.Tensor, int, float, dict]],
s_inputs: torch.Tensor,
s_trunk: torch.Tensor,
z_trunk: torch.Tensor,
inplace_safe: bool = False,
chunk_size: Optional[int] = None,
) -> torch.Tensor:
"""The raw network to be trained.
As in EDM equation (7), this is F_theta(c_in * x, c_noise(sigma)).
Here, c_noise(sigma) is computed in Conditioning module.
Args:
r_noisy (torch.Tensor): scaled x_noisy (i.e., c_in * x)
[..., N_sample, N_atom, 3]
t_hat_noise_level (torch.Tensor): the noise level, as well as the time step t
[..., N_sample]
input_feature_dict (dict[str, Union[torch.Tensor, int, float, dict]]): input feature
s_inputs (torch.Tensor): single embedding from InputFeatureEmbedder
[..., N_tokens, c_s_inputs]
s_trunk (torch.Tensor): single feature embedding from PairFormer (Alg17)
[..., N_tokens, c_s]
z_trunk (torch.Tensor): pair feature embedding from PairFormer (Alg17)
[..., N_tokens, N_tokens, c_z]
inplace_safe (bool): Whether it is safe to use inplace operations. Defaults to False.
chunk_size (Optional[int]): Chunk size for memory-efficient operations. Defaults to None.
        Returns:
            torch.Tensor: token-level activations (the atom decoder is omitted in this class)
                [..., N_sample, N_tokens, c_token]
"""
N_sample = r_noisy.size(-3)
assert t_hat_noise_level.size(-1) == N_sample
blocks_per_ckpt = self.blocks_per_ckpt
if not torch.is_grad_enabled():
blocks_per_ckpt = None
        # Conditioning, shared across different samples.
        # diffusion_conditioning consumes 7-8 GB when the token count is 768,
        # so checkpoint it here if blocks_per_ckpt is not None.
if blocks_per_ckpt:
checkpoint_fn = get_checkpoint_fn()
s_single, z_pair = checkpoint_fn(
self.diffusion_conditioning,
t_hat_noise_level,
input_feature_dict,
s_inputs,
s_trunk,
z_trunk,
inplace_safe,
)
else:
s_single, z_pair = self.diffusion_conditioning(
t_hat_noise_level=t_hat_noise_level,
input_feature_dict=input_feature_dict,
s_inputs=s_inputs,
s_trunk=s_trunk,
z_trunk=z_trunk,
inplace_safe=inplace_safe,
) # [..., N_sample, N_token, c_s], [..., N_token, N_token, c_z]
# Expand embeddings to match N_sample
s_trunk = expand_at_dim(
s_trunk, dim=-3, n=N_sample
) # [..., N_sample, N_token, c_s]
z_pair = expand_at_dim(
z_pair, dim=-4, n=N_sample
) # [..., N_sample, N_token, N_token, c_z]
        # Fine-grained checkpoint for finetuning stage 2 (token count 768) to avoid OOM
if blocks_per_ckpt and self.use_fine_grained_checkpoint:
checkpoint_fn = get_checkpoint_fn()
a_token, q_skip, c_skip, p_skip = checkpoint_fn(
self.atom_attention_encoder,
input_feature_dict,
r_noisy,
s_trunk,
z_pair,
inplace_safe,
chunk_size,
)
else:
# Sequence-local Atom Attention and aggregation to coarse-grained tokens
a_token, q_skip, c_skip, p_skip = self.atom_attention_encoder(
input_feature_dict=input_feature_dict,
r_l=r_noisy,
s=s_trunk,
z=z_pair,
inplace_safe=inplace_safe,
chunk_size=chunk_size,
)
# Full self-attention on token level.
if inplace_safe:
a_token += self.linear_no_bias_s(
self.layernorm_s(s_single)
) # [..., N_sample, N_token, c_token]
else:
a_token = a_token + self.linear_no_bias_s(
self.layernorm_s(s_single)
) # [..., N_sample, N_token, c_token]
a_token = self.diffusion_transformer(
a=a_token,
s=s_single,
z=z_pair,
inplace_safe=inplace_safe,
chunk_size=chunk_size,
)
        # a_token = self.layernorm_a(a_token)  # final LayerNorm is skipped; raw transformer output is returned
return a_token
def forward(
self,
x_noisy: torch.Tensor,
t_hat_noise_level: torch.Tensor,
input_feature_dict: dict[str, Union[torch.Tensor, int, float, dict]],
s_inputs: torch.Tensor,
s_trunk: torch.Tensor,
z_trunk: torch.Tensor,
inplace_safe: bool = False,
chunk_size: Optional[int] = None,
) -> torch.Tensor:
"""One step denoise: x_noisy, noise_level -> x_denoised
Args:
x_noisy (torch.Tensor): the noisy version of the input atom coords
                [..., N_sample, N_atom, 3]
t_hat_noise_level (torch.Tensor): the noise level, as well as the time step t
[..., N_sample]
input_feature_dict (dict[str, Union[torch.Tensor, int, float, dict]]): input meta feature dict
s_inputs (torch.Tensor): single embedding from InputFeatureEmbedder
[..., N_tokens, c_s_inputs]
s_trunk (torch.Tensor): single feature embedding from PairFormer (Alg17)
[..., N_tokens, c_s]
z_trunk (torch.Tensor): pair feature embedding from PairFormer (Alg17)
[..., N_tokens, N_tokens, c_z]
inplace_safe (bool): Whether it is safe to use inplace operations. Defaults to False.
chunk_size (Optional[int]): Chunk size for memory-efficient operations. Defaults to None.
        Returns:
            torch.Tensor: token-level activations (no coordinate update is produced)
                [..., N_sample, N_tokens, c_token]
"""
# Scale positions to dimensionless vectors with approximately unit variance
# As in EDM:
# r_noisy = (c_in * x_noisy)
# where c_in = 1 / sqrt(sigma_data^2 + sigma^2)
r_noisy = (
x_noisy
/ torch.sqrt(self.sigma_data**2 + t_hat_noise_level**2)[..., None, None]
)
a_token = self.f_forward(
r_noisy=r_noisy,
t_hat_noise_level=t_hat_noise_level,
input_feature_dict=input_feature_dict,
s_inputs=s_inputs,
s_trunk=s_trunk,
z_trunk=z_trunk,
inplace_safe=inplace_safe,
chunk_size=chunk_size,
)
return a_token
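
# A hedged sketch (not part of the original file) of one inference denoising step
# built on DiffusionModule and DiffusionSchedule. The trunk outputs
# (input_feature_dict, s_inputs, s_trunk, z_trunk) are assumed to come from the
# Protenix Pairformer; the plain Euler update below is a simplification of the
# AF3 sampler (no noise re-injection / churn).
@torch.no_grad()
def _euler_denoise_step(
    model: DiffusionModule,
    x: torch.Tensor,  # [..., N_sample, N_atom, 3]
    sigma_cur: torch.Tensor,  # [..., N_sample]
    sigma_next: torch.Tensor,  # [..., N_sample]
    input_feature_dict: dict,
    s_inputs: torch.Tensor,
    s_trunk: torch.Tensor,
    z_trunk: torch.Tensor,
) -> torch.Tensor:
    x_denoised = model(
        x_noisy=x,
        t_hat_noise_level=sigma_cur,
        input_feature_dict=input_feature_dict,
        s_inputs=s_inputs,
        s_trunk=s_trunk,
        z_trunk=z_trunk,
    )
    d = (x - x_denoised) / sigma_cur[..., None, None]  # EDM score direction
    return x + (sigma_next - sigma_cur)[..., None, None] * d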