|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Optional, Union |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
from protenix.model.modules.embedders import FourierEmbedding, RelativePositionEncoding |
|
from protenix.model.modules.primitives import LinearNoBias, Transition |
|
from protenix.model.modules.transformer import ( |
|
AtomAttentionDecoder, |
|
AtomAttentionEncoder, |
|
DiffusionTransformer, |
|
) |
|
from protenix.model.utils import expand_at_dim |
|
from protenix.openfold_local.model.primitives import LayerNorm |
|
from protenix.openfold_local.utils.checkpointing import get_checkpoint_fn |
|
|
|
|
|
class DiffusionConditioning(nn.Module):
    """
    Implements Algorithm 21 in AF3.

    Produces the conditioning signals for the diffusion module: a pair
    embedding built from the trunk pair representation plus relative position
    encoding, and a per-sample single embedding built from the trunk single
    representation, the input embedding, and a Fourier embedding of the
    noise level.
    """

    def __init__(
        self,
        sigma_data: float = 16.0,
        c_z: int = 128,
        c_s: int = 384,
        c_s_inputs: int = 449,
        c_noise_embedding: int = 256,
    ) -> None:
        """
        Args:
            sigma_data (float, optional): the standard deviation of the data. Defaults to 16.0.
            c_z (int, optional): hidden dim [for pair embedding]. Defaults to 128.
            c_s (int, optional): hidden dim [for single embedding]. Defaults to 384.
            c_s_inputs (int, optional): input embedding dim from InputEmbedder. Defaults to 449.
            c_noise_embedding (int, optional): noise embedding dim. Defaults to 256.
        """
        super(DiffusionConditioning, self).__init__()
        self.sigma_data = sigma_data
        self.c_z = c_z
        self.c_s = c_s
        self.c_s_inputs = c_s_inputs

        # Pair conditioning: [z_trunk, relpe] (2 * c_z) -> c_z.
        self.relpe = RelativePositionEncoding(c_z=c_z)
        self.layernorm_z = LayerNorm(2 * self.c_z)
        self.linear_no_bias_z = LinearNoBias(
            in_features=2 * self.c_z, out_features=self.c_z
        )
        self.transition_z1 = Transition(c_in=self.c_z, n=2)
        self.transition_z2 = Transition(c_in=self.c_z, n=2)

        # Single conditioning: [s_trunk, s_inputs] (c_s + c_s_inputs) -> c_s.
        self.layernorm_s = LayerNorm(self.c_s + self.c_s_inputs)
        self.linear_no_bias_s = LinearNoBias(
            in_features=self.c_s + self.c_s_inputs, out_features=self.c_s
        )

        # Noise-level embedding: Fourier features of the (log-scaled) noise level.
        self.fourier_embedding = FourierEmbedding(c=c_noise_embedding)
        self.layernorm_n = LayerNorm(c_noise_embedding)
        self.linear_no_bias_n = LinearNoBias(
            in_features=c_noise_embedding, out_features=self.c_s
        )

        self.transition_s1 = Transition(c_in=self.c_s, n=2)
        self.transition_s2 = Transition(c_in=self.c_s, n=2)
        # Bug fix: removed a stray debug `print(...)` that was left in the
        # constructor and spammed stdout on every model instantiation.

    def forward(
        self,
        t_hat_noise_level: torch.Tensor,
        input_feature_dict: dict[str, Union[torch.Tensor, int, float, dict]],
        s_inputs: torch.Tensor,
        s_trunk: torch.Tensor,
        z_trunk: torch.Tensor,
        inplace_safe: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            t_hat_noise_level (torch.Tensor): the noise level
                [..., N_sample]
            input_feature_dict (dict[str, Union[torch.Tensor, int, float, dict]]): input meta feature dict
            s_inputs (torch.Tensor): single embedding from InputFeatureEmbedder
                [..., N_tokens, c_s_inputs]
            s_trunk (torch.Tensor): single feature embedding from PairFormer (Alg17)
                [..., N_tokens, c_s]
            z_trunk (torch.Tensor): pair feature embedding from PairFormer (Alg17)
                [..., N_tokens, N_tokens, c_z]
            inplace_safe (bool): Whether it is safe to use inplace operations.
        Returns:
            tuple[torch.Tensor, torch.Tensor]: embeddings s and z
                - s (torch.Tensor): [..., N_sample, N_tokens, c_s]
                - z (torch.Tensor): [..., N_tokens, N_tokens, c_z]
        """
        # Pair embedding: concatenate trunk pair features with relative
        # position encoding, project, then apply two residual transitions.
        pair_z = torch.cat(
            tensors=[z_trunk, self.relpe(input_feature_dict)], dim=-1
        )
        pair_z = self.linear_no_bias_z(self.layernorm_z(pair_z))
        if inplace_safe:
            pair_z += self.transition_z1(pair_z)
            pair_z += self.transition_z2(pair_z)
        else:
            pair_z = pair_z + self.transition_z1(pair_z)
            pair_z = pair_z + self.transition_z2(pair_z)

        # Single embedding: concatenate trunk single features with input
        # embedding and project down to c_s.
        single_s = torch.cat(
            tensors=[s_trunk, s_inputs], dim=-1
        )
        single_s = self.linear_no_bias_s(self.layernorm_s(single_s))
        # Noise embedding with c_noise(sigma) = log(sigma / sigma_data) / 4
        # (EDM-style noise preconditioning).
        noise_n = self.fourier_embedding(
            t_hat_noise_level=torch.log(input=t_hat_noise_level / self.sigma_data) / 4
        ).to(
            single_s.dtype
        )
        # Broadcast-add: single embedding gets an N_sample axis, the noise
        # embedding gets an N_tokens axis.
        single_s = single_s.unsqueeze(dim=-3) + self.linear_no_bias_n(
            self.layernorm_n(noise_n)
        ).unsqueeze(
            dim=-2
        )
        if inplace_safe:
            single_s += self.transition_s1(single_s)
            single_s += self.transition_s2(single_s)
        else:
            single_s = single_s + self.transition_s1(single_s)
            single_s = single_s + self.transition_s2(single_s)
        # For very large token counts at inference time, proactively release
        # cached CUDA memory (no-op when CUDA is uninitialized).
        if not self.training and pair_z.shape[-2] > 2000:
            torch.cuda.empty_cache()
        return single_s, pair_z
|
|
|
|
|
class DiffusionSchedule:
    """Noise-level schedules for EDM-style diffusion.

    Provides a log-normal noise-level sampler for training and a
    polynomially interpolated noise-level schedule for inference.
    """

    def __init__(
        self,
        sigma_data: float = 16.0,
        s_max: float = 160.0,
        s_min: float = 4e-4,
        p: float = 7.0,
        dt: float = 1 / 200,
        p_mean: float = -1.2,
        p_std: float = 1.5,
    ) -> None:
        """
        Args:
            sigma_data (float, optional): The standard deviation of the data. Defaults to 16.0.
            s_max (float, optional): The maximum noise level. Defaults to 160.0.
            s_min (float, optional): The minimum noise level. Defaults to 4e-4.
            p (float, optional): The exponent for the noise schedule. Defaults to 7.0.
            dt (float, optional): The time step size. Defaults to 1/200.
            p_mean (float, optional): The mean of the log-normal distribution for noise level sampling. Defaults to -1.2.
            p_std (float, optional): The standard deviation of the log-normal distribution for noise level sampling. Defaults to 1.5.
        """
        self.sigma_data = sigma_data
        self.s_max = s_max
        self.s_min = s_min
        self.p = p
        self.dt = dt
        self.p_mean = p_mean
        self.p_std = p_std
        # Number of discrete time steps covering [0, 1] at resolution dt.
        self.T = int(1 / dt) + 1

    def get_train_noise_schedule(self) -> torch.Tensor:
        """Sample one training noise level: sigma_data * exp(p_mean + p_std * N(0, 1))."""
        gaussian_sample = torch.randn(1)
        return self.sigma_data * torch.exp(self.p_mean + self.p_std * gaussian_sample)

    def get_inference_noise_schedule(self) -> torch.Tensor:
        """Return the inference noise schedule over T time steps in [0, 1].

        Interpolates linearly between the p-th roots of s_max and s_min,
        then raises back to the p-th power and scales by sigma_data.
        """
        # The tiny epsilon keeps the endpoint t = 1 inside the arange.
        t_steps = torch.arange(start=0, end=1 + 1e-10, step=self.dt)
        root_max = self.s_max ** (1 / self.p)
        root_min = self.s_min ** (1 / self.p)
        interpolated = root_max + t_steps * (root_min - root_max)
        return self.sigma_data * interpolated**self.p
|
|
|
|
|
class DiffusionModule(nn.Module):
    """
    Implements Algorithm 20 in AF3.

    One-step denoiser: conditions on trunk embeddings and the noise level,
    then predicts denoised atom coordinates via atom-level attention
    encoder/decoder around a token-level diffusion transformer.
    """

    def __init__(
        self,
        sigma_data: float = 16.0,
        c_atom: int = 128,
        c_atompair: int = 16,
        c_token: int = 768,
        c_s: int = 384,
        c_z: int = 128,
        c_s_inputs: int = 449,
        atom_encoder: Optional[dict[str, int]] = None,
        transformer: Optional[dict[str, int]] = None,
        atom_decoder: Optional[dict[str, int]] = None,
        blocks_per_ckpt: Optional[int] = None,
        use_fine_grained_checkpoint: bool = False,
        initialization: Optional[dict[str, Union[str, float, bool]]] = None,
    ) -> None:
        """
        Args:
            sigma_data (float, optional): the standard deviation of the data. Defaults to 16.0.
            c_atom (int, optional): embedding dim for atom feature. Defaults to 128.
            c_atompair (int, optional): embedding dim for atompair feature. Defaults to 16.
            c_token (int, optional): feature channel of token (single a). Defaults to 768.
            c_s (int, optional): hidden dim [for single embedding]. Defaults to 384.
            c_z (int, optional): hidden dim [for pair embedding]. Defaults to 128.
            c_s_inputs (int, optional): hidden dim [for single input embedding]. Defaults to 449.
            atom_encoder (dict[str, int], optional): configs in AtomAttentionEncoder. Defaults to {"n_blocks": 3, "n_heads": 4}.
            transformer (dict[str, int], optional): configs in DiffusionTransformer. Defaults to {"n_blocks": 24, "n_heads": 16}.
            atom_decoder (dict[str, int], optional): configs in AtomAttentionDecoder. Defaults to {"n_blocks": 3, "n_heads": 4}.
            blocks_per_ckpt: number of atom_encoder/transformer/atom_decoder blocks in each activation checkpoint
                Size of each chunk. A higher value corresponds to fewer
                checkpoints, and trades memory for speed. If None, no checkpointing is performed.
            use_fine_grained_checkpoint: whether use fine-gained checkpoint for finetuning stage 2
                only effective if blocks_per_ckpt is not None.
            initialization: initialize the diffusion module according to initialization config.
        """
        super(DiffusionModule, self).__init__()
        # Bug fix: the config dicts used to be mutable default arguments,
        # which are shared across all instances; resolve defaults here instead.
        if atom_encoder is None:
            atom_encoder = {"n_blocks": 3, "n_heads": 4}
        if transformer is None:
            transformer = {"n_blocks": 24, "n_heads": 16}
        if atom_decoder is None:
            atom_decoder = {"n_blocks": 3, "n_heads": 4}

        self.sigma_data = sigma_data
        self.c_atom = c_atom
        self.c_atompair = c_atompair
        self.c_token = c_token
        self.c_s_inputs = c_s_inputs
        self.c_s = c_s
        self.c_z = c_z

        self.blocks_per_ckpt = blocks_per_ckpt
        self.use_fine_grained_checkpoint = use_fine_grained_checkpoint

        self.diffusion_conditioning = DiffusionConditioning(
            sigma_data=self.sigma_data, c_z=c_z, c_s=c_s, c_s_inputs=c_s_inputs
        )
        self.atom_attention_encoder = AtomAttentionEncoder(
            **atom_encoder,
            c_atom=c_atom,
            c_atompair=c_atompair,
            c_token=c_token,
            has_coords=True,
            c_s=c_s,
            c_z=c_z,
            blocks_per_ckpt=blocks_per_ckpt,
        )

        self.layernorm_s = LayerNorm(c_s)
        self.linear_no_bias_s = LinearNoBias(in_features=c_s, out_features=c_token)
        self.diffusion_transformer = DiffusionTransformer(
            **transformer,
            c_a=c_token,
            c_s=c_s,
            c_z=c_z,
            blocks_per_ckpt=blocks_per_ckpt,
        )
        self.layernorm_a = LayerNorm(c_token)
        self.atom_attention_decoder = AtomAttentionDecoder(
            **atom_decoder,
            c_token=c_token,
            c_atom=c_atom,
            c_atompair=c_atompair,
            blocks_per_ckpt=blocks_per_ckpt,
        )
        self.init_parameters(initialization)

    def init_parameters(self, initialization: Optional[dict] = None):
        """
        Initializes the parameters of the diffusion module according to the provided initialization configuration.

        Args:
            initialization (Optional[dict]): A dictionary containing initialization settings.
                If None, the default initialization of the submodules is kept.
        """
        # Bug fix: `initialization` defaults to None in __init__, but this
        # method previously called `.get` on it unconditionally, raising
        # AttributeError on default construction. Treat None as "no config".
        if initialization is None:
            initialization = {}

        if initialization.get("zero_init_condition_transition", False):
            self.diffusion_conditioning.transition_z1.zero_init()
            self.diffusion_conditioning.transition_z2.zero_init()
            self.diffusion_conditioning.transition_s1.zero_init()
            self.diffusion_conditioning.transition_s2.zero_init()

        self.atom_attention_encoder.linear_init(
            zero_init_atom_encoder_residual_linear=initialization.get(
                "zero_init_atom_encoder_residual_linear", False
            ),
            he_normal_init_atom_encoder_small_mlp=initialization.get(
                "he_normal_init_atom_encoder_small_mlp", False
            ),
            he_normal_init_atom_encoder_output=initialization.get(
                "he_normal_init_atom_encoder_output", False
            ),
        )

        if initialization.get("glorot_init_self_attention", False):
            for (
                block
            ) in (
                self.atom_attention_encoder.atom_transformer.diffusion_transformer.blocks
            ):
                block.attention_pair_bias.glorot_init()

        for block in self.diffusion_transformer.blocks:
            if initialization.get("zero_init_adaln", False):
                block.attention_pair_bias.layernorm_a.zero_init()
                block.conditioned_transition_block.adaln.zero_init()
            if initialization.get("zero_init_residual_condition_transition", False):
                nn.init.zeros_(
                    block.conditioned_transition_block.linear_nobias_b.weight
                )

        if initialization.get("zero_init_atom_decoder_linear", False):
            nn.init.zeros_(self.atom_attention_decoder.linear_no_bias_a.weight)

        if initialization.get("zero_init_dit_output", False):
            nn.init.zeros_(self.atom_attention_decoder.linear_no_bias_out.weight)

    def f_forward(
        self,
        r_noisy: torch.Tensor,
        t_hat_noise_level: torch.Tensor,
        input_feature_dict: dict[str, Union[torch.Tensor, int, float, dict]],
        s_inputs: torch.Tensor,
        s_trunk: torch.Tensor,
        z_trunk: torch.Tensor,
        inplace_safe: bool = False,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        """The raw network to be trained.
        As in EDM equation (7), this is F_theta(c_in * x, c_noise(sigma)).
        Here, c_noise(sigma) is computed in Conditioning module.

        Args:
            r_noisy (torch.Tensor): scaled x_noisy (i.e., c_in * x)
                [..., N_sample, N_atom, 3]
            t_hat_noise_level (torch.Tensor): the noise level, as well as the time step t
                [..., N_sample]
            input_feature_dict (dict[str, Union[torch.Tensor, int, float, dict]]): input feature
            s_inputs (torch.Tensor): single embedding from InputFeatureEmbedder
                [..., N_tokens, c_s_inputs]
            s_trunk (torch.Tensor): single feature embedding from PairFormer (Alg17)
                [..., N_tokens, c_s]
            z_trunk (torch.Tensor): pair feature embedding from PairFormer (Alg17)
                [..., N_tokens, N_tokens, c_z]
            inplace_safe (bool): Whether it is safe to use inplace operations. Defaults to False.
            chunk_size (Optional[int]): Chunk size for memory-efficient operations. Defaults to None.

        Returns:
            torch.Tensor: coordinates update
                [..., N_sample, N_atom, 3]
        """
        N_sample = r_noisy.size(-3)
        assert t_hat_noise_level.size(-1) == N_sample

        # Activation checkpointing is only meaningful when gradients flow.
        blocks_per_ckpt = self.blocks_per_ckpt
        if not torch.is_grad_enabled():
            blocks_per_ckpt = None

        # Conditioning (optionally checkpointed; positional args must match
        # DiffusionConditioning.forward's signature order).
        if blocks_per_ckpt:
            checkpoint_fn = get_checkpoint_fn()
            s_single, z_pair = checkpoint_fn(
                self.diffusion_conditioning,
                t_hat_noise_level,
                input_feature_dict,
                s_inputs,
                s_trunk,
                z_trunk,
                inplace_safe,
            )
        else:
            s_single, z_pair = self.diffusion_conditioning(
                t_hat_noise_level=t_hat_noise_level,
                input_feature_dict=input_feature_dict,
                s_inputs=s_inputs,
                s_trunk=s_trunk,
                z_trunk=z_trunk,
                inplace_safe=inplace_safe,
            )

        # Broadcast trunk embeddings across the N_sample axis.
        s_trunk = expand_at_dim(
            s_trunk, dim=-3, n=N_sample
        )
        z_pair = expand_at_dim(
            z_pair, dim=-4, n=N_sample
        )

        # Atom-level encoding (fine-grained checkpoint optional).
        if blocks_per_ckpt and self.use_fine_grained_checkpoint:
            checkpoint_fn = get_checkpoint_fn()
            a_token, q_skip, c_skip, p_skip = checkpoint_fn(
                self.atom_attention_encoder,
                input_feature_dict,
                r_noisy,
                s_trunk,
                z_pair,
                inplace_safe,
                chunk_size,
            )
        else:
            a_token, q_skip, c_skip, p_skip = self.atom_attention_encoder(
                input_feature_dict=input_feature_dict,
                r_l=r_noisy,
                s=s_trunk,
                z=z_pair,
                inplace_safe=inplace_safe,
                chunk_size=chunk_size,
            )

        # Inject the conditioned single embedding into the token activations.
        if inplace_safe:
            a_token += self.linear_no_bias_s(
                self.layernorm_s(s_single)
            )
        else:
            a_token = a_token + self.linear_no_bias_s(
                self.layernorm_s(s_single)
            )
        a_token = self.diffusion_transformer(
            a=a_token,
            s=s_single,
            z=z_pair,
            inplace_safe=inplace_safe,
            chunk_size=chunk_size,
        )

        a_token = self.layernorm_a(a_token)

        # Atom-level decoding back to a coordinate update.
        if blocks_per_ckpt and self.use_fine_grained_checkpoint:
            checkpoint_fn = get_checkpoint_fn()
            r_update = checkpoint_fn(
                self.atom_attention_decoder,
                input_feature_dict,
                a_token,
                q_skip,
                c_skip,
                p_skip,
                inplace_safe,
                chunk_size,
            )
        else:
            r_update = self.atom_attention_decoder(
                input_feature_dict=input_feature_dict,
                a=a_token,
                q_skip=q_skip,
                c_skip=c_skip,
                p_skip=p_skip,
                inplace_safe=inplace_safe,
                chunk_size=chunk_size,
            )

        return r_update

    def forward(
        self,
        x_noisy: torch.Tensor,
        t_hat_noise_level: torch.Tensor,
        input_feature_dict: dict[str, Union[torch.Tensor, int, float, dict]],
        s_inputs: torch.Tensor,
        s_trunk: torch.Tensor,
        z_trunk: torch.Tensor,
        inplace_safe: bool = False,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        """One step denoise: x_noisy, noise_level -> x_denoised

        Args:
            x_noisy (torch.Tensor): the noisy version of the input atom coords
                [..., N_sample, N_atom,3]
            t_hat_noise_level (torch.Tensor): the noise level, as well as the time step t
                [..., N_sample]
            input_feature_dict (dict[str, Union[torch.Tensor, int, float, dict]]): input meta feature dict
            s_inputs (torch.Tensor): single embedding from InputFeatureEmbedder
                [..., N_tokens, c_s_inputs]
            s_trunk (torch.Tensor): single feature embedding from PairFormer (Alg17)
                [..., N_tokens, c_s]
            z_trunk (torch.Tensor): pair feature embedding from PairFormer (Alg17)
                [..., N_tokens, N_tokens, c_z]
            inplace_safe (bool): Whether it is safe to use inplace operations. Defaults to False.
            chunk_size (Optional[int]): Chunk size for memory-efficient operations. Defaults to None.

        Returns:
            torch.Tensor: the denoised coordinates of x
                [..., N_sample, N_atom,3]
        """
        # EDM input preconditioning: r_noisy = c_in * x_noisy with
        # c_in = 1 / sqrt(sigma_data^2 + sigma^2).
        r_noisy = (
            x_noisy
            / torch.sqrt(self.sigma_data**2 + t_hat_noise_level**2)[..., None, None]
        )

        r_update = self.f_forward(
            r_noisy=r_noisy,
            t_hat_noise_level=t_hat_noise_level,
            input_feature_dict=input_feature_dict,
            s_inputs=s_inputs,
            s_trunk=s_trunk,
            z_trunk=z_trunk,
            inplace_safe=inplace_safe,
            chunk_size=chunk_size,
        )

        # EDM output preconditioning: x_denoised = c_skip * x_noisy
        # + c_out * F_theta, expressed via s_ratio = sigma / sigma_data.
        s_ratio = (t_hat_noise_level / self.sigma_data)[..., None, None].to(
            r_update.dtype
        )
        x_denoised = (
            1 / (1 + s_ratio**2) * x_noisy
            + t_hat_noise_level[..., None, None] / torch.sqrt(1 + s_ratio**2) * r_update
        ).to(r_update.dtype)

        return x_denoised
|
|
|
|
|
|
|
|
|
class Struct_encoder(nn.Module):
    """
    Structure encoder variant of the diffusion module (Algorithm 20 in AF3),
    with a smaller diffusion transformer and an extra `watermark` dimension.
    """

    def __init__(
        self,
        sigma_data: float = 16.0,
        c_atom: int = 128,
        c_atompair: int = 16,
        c_token: int = 768,
        c_s: int = 384,
        c_z: int = 128,
        c_s_inputs: int = 449,
        watermark: int = 32,
        atom_encoder: Optional[dict[str, int]] = None,
        transformer: Optional[dict[str, int]] = None,
        atom_decoder: Optional[dict[str, int]] = None,
        blocks_per_ckpt: Optional[int] = None,
        use_fine_grained_checkpoint: bool = False,
        initialization: Optional[dict[str, Union[str, float, bool]]] = None,
    ) -> None:
        """
        Args:
            sigma_data (float, optional): the standard deviation of the data. Defaults to 16.0.
            c_atom (int, optional): embedding dim for atom feature. Defaults to 128.
            c_atompair (int, optional): embedding dim for atompair feature. Defaults to 16.
            c_token (int, optional): feature channel of token (single a). Defaults to 768.
            c_s (int, optional): hidden dim [for single embedding]. Defaults to 384.
            c_z (int, optional): hidden dim [for pair embedding]. Defaults to 128.
            c_s_inputs (int, optional): hidden dim [for single input embedding]. Defaults to 449.
            watermark (int, optional): watermark embedding dim. Defaults to 32.
                NOTE(review): only stored as an attribute here — presumably
                consumed by callers; confirm its intended use.
            atom_encoder (dict[str, int], optional): configs in AtomAttentionEncoder. Defaults to {"n_blocks": 3, "n_heads": 4}.
            transformer (dict[str, int], optional): configs in DiffusionTransformer. Defaults to {"n_blocks": 6, "n_heads": 16}.
            atom_decoder (dict[str, int], optional): configs in AtomAttentionDecoder. Defaults to {"n_blocks": 3, "n_heads": 4}.
            blocks_per_ckpt: number of atom_encoder/transformer/atom_decoder blocks in each activation checkpoint
                Size of each chunk. A higher value corresponds to fewer
                checkpoints, and trades memory for speed. If None, no checkpointing is performed.
            use_fine_grained_checkpoint: whether use fine-gained checkpoint for finetuning stage 2
                only effective if blocks_per_ckpt is not None.
            initialization: initialize the diffusion module according to initialization config.
        """
        super(Struct_encoder, self).__init__()
        # Bug fix: the config dicts used to be mutable default arguments,
        # which are shared across all instances; resolve defaults here instead.
        if atom_encoder is None:
            atom_encoder = {"n_blocks": 3, "n_heads": 4}
        if transformer is None:
            transformer = {"n_blocks": 6, "n_heads": 16}
        if atom_decoder is None:
            atom_decoder = {"n_blocks": 3, "n_heads": 4}

        self.sigma_data = sigma_data
        self.c_atom = c_atom
        self.c_atompair = c_atompair
        self.c_token = c_token
        self.c_s_inputs = c_s_inputs
        self.c_s = c_s
        self.c_z = c_z
        self.watermark = watermark

        self.blocks_per_ckpt = blocks_per_ckpt
        self.use_fine_grained_checkpoint = use_fine_grained_checkpoint

        self.diffusion_conditioning = DiffusionConditioning(
            sigma_data=self.sigma_data, c_z=c_z, c_s=c_s, c_s_inputs=c_s_inputs
        )
        self.atom_attention_encoder = AtomAttentionEncoder(
            **atom_encoder,
            c_atom=c_atom,
            c_atompair=c_atompair,
            c_token=c_token,
            has_coords=True,
            c_s=c_s,
            c_z=c_z,
            blocks_per_ckpt=blocks_per_ckpt,
        )

        self.layernorm_s = LayerNorm(c_s)
        self.linear_no_bias_s = LinearNoBias(in_features=c_s, out_features=c_token)
        self.diffusion_transformer = DiffusionTransformer(
            **transformer,
            c_a=c_token,
            c_s=c_s,
            c_z=c_z,
            blocks_per_ckpt=blocks_per_ckpt,
        )
        self.layernorm_a = LayerNorm(c_token)
        self.atom_attention_decoder = AtomAttentionDecoder(
            **atom_decoder,
            c_token=c_token,
            c_atom=c_atom,
            c_atompair=c_atompair,
            blocks_per_ckpt=blocks_per_ckpt,
        )
        self.init_parameters(initialization)

    def init_parameters(self, initialization: Optional[dict] = None):
        """
        Initializes the parameters of the diffusion module according to the provided initialization configuration.

        Args:
            initialization (Optional[dict]): A dictionary containing initialization settings.
                If None, the default initialization of the submodules is kept.
        """
        # Bug fix: `initialization` defaults to None in __init__, but this
        # method previously called `.get` on it unconditionally, raising
        # AttributeError on default construction. Treat None as "no config".
        if initialization is None:
            initialization = {}

        if initialization.get("zero_init_condition_transition", False):
            self.diffusion_conditioning.transition_z1.zero_init()
            self.diffusion_conditioning.transition_z2.zero_init()
            self.diffusion_conditioning.transition_s1.zero_init()
            self.diffusion_conditioning.transition_s2.zero_init()

        self.atom_attention_encoder.linear_init(
            zero_init_atom_encoder_residual_linear=initialization.get(
                "zero_init_atom_encoder_residual_linear", False
            ),
            he_normal_init_atom_encoder_small_mlp=initialization.get(
                "he_normal_init_atom_encoder_small_mlp", False
            ),
            he_normal_init_atom_encoder_output=initialization.get(
                "he_normal_init_atom_encoder_output", False
            ),
        )

        if initialization.get("glorot_init_self_attention", False):
            for (
                block
            ) in (
                self.atom_attention_encoder.atom_transformer.diffusion_transformer.blocks
            ):
                block.attention_pair_bias.glorot_init()

        for block in self.diffusion_transformer.blocks:
            if initialization.get("zero_init_adaln", False):
                block.attention_pair_bias.layernorm_a.zero_init()
                block.conditioned_transition_block.adaln.zero_init()
            if initialization.get("zero_init_residual_condition_transition", False):
                nn.init.zeros_(
                    block.conditioned_transition_block.linear_nobias_b.weight
                )

        if initialization.get("zero_init_atom_decoder_linear", False):
            nn.init.zeros_(self.atom_attention_decoder.linear_no_bias_a.weight)

        if initialization.get("zero_init_dit_output", False):
            nn.init.zeros_(self.atom_attention_decoder.linear_no_bias_out.weight)

    def f_forward(
        self,
        r_noisy: torch.Tensor,
        t_hat_noise_level: torch.Tensor,
        input_feature_dict: dict[str, Union[torch.Tensor, int, float, dict]],
        s_inputs: torch.Tensor,
        s_trunk: torch.Tensor,
        z_trunk: torch.Tensor,
        inplace_safe: bool = False,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        """The raw network to be trained.
        As in EDM equation (7), this is F_theta(c_in * x, c_noise(sigma)).
        Here, c_noise(sigma) is computed in Conditioning module.

        Args:
            r_noisy (torch.Tensor): scaled x_noisy (i.e., c_in * x)
                [..., N_sample, N_atom, 3]
            t_hat_noise_level (torch.Tensor): the noise level, as well as the time step t
                [..., N_sample]
            input_feature_dict (dict[str, Union[torch.Tensor, int, float, dict]]): input feature
            s_inputs (torch.Tensor): single embedding from InputFeatureEmbedder
                [..., N_tokens, c_s_inputs]
            s_trunk (torch.Tensor): single feature embedding from PairFormer (Alg17)
                [..., N_tokens, c_s]
            z_trunk (torch.Tensor): pair feature embedding from PairFormer (Alg17)
                [..., N_tokens, N_tokens, c_z]
            inplace_safe (bool): Whether it is safe to use inplace operations. Defaults to False.
            chunk_size (Optional[int]): Chunk size for memory-efficient operations. Defaults to None.

        Returns:
            torch.Tensor: coordinates update
                [..., N_sample, N_atom, 3]
        """
        N_sample = r_noisy.size(-3)
        assert t_hat_noise_level.size(-1) == N_sample

        # Activation checkpointing is only meaningful when gradients flow.
        blocks_per_ckpt = self.blocks_per_ckpt
        if not torch.is_grad_enabled():
            blocks_per_ckpt = None

        # Conditioning (optionally checkpointed; positional args must match
        # DiffusionConditioning.forward's signature order).
        if blocks_per_ckpt:
            checkpoint_fn = get_checkpoint_fn()
            s_single, z_pair = checkpoint_fn(
                self.diffusion_conditioning,
                t_hat_noise_level,
                input_feature_dict,
                s_inputs,
                s_trunk,
                z_trunk,
                inplace_safe,
            )
        else:
            s_single, z_pair = self.diffusion_conditioning(
                t_hat_noise_level=t_hat_noise_level,
                input_feature_dict=input_feature_dict,
                s_inputs=s_inputs,
                s_trunk=s_trunk,
                z_trunk=z_trunk,
                inplace_safe=inplace_safe,
            )

        # Broadcast trunk embeddings across the N_sample axis.
        s_trunk = expand_at_dim(
            s_trunk, dim=-3, n=N_sample
        )
        z_pair = expand_at_dim(
            z_pair, dim=-4, n=N_sample
        )

        # Atom-level encoding (fine-grained checkpoint optional).
        if blocks_per_ckpt and self.use_fine_grained_checkpoint:
            checkpoint_fn = get_checkpoint_fn()
            a_token, q_skip, c_skip, p_skip = checkpoint_fn(
                self.atom_attention_encoder,
                input_feature_dict,
                r_noisy,
                s_trunk,
                z_pair,
                inplace_safe,
                chunk_size,
            )
        else:
            a_token, q_skip, c_skip, p_skip = self.atom_attention_encoder(
                input_feature_dict=input_feature_dict,
                r_l=r_noisy,
                s=s_trunk,
                z=z_pair,
                inplace_safe=inplace_safe,
                chunk_size=chunk_size,
            )

        # Inject the conditioned single embedding into the token activations.
        if inplace_safe:
            a_token += self.linear_no_bias_s(
                self.layernorm_s(s_single)
            )
        else:
            a_token = a_token + self.linear_no_bias_s(
                self.layernorm_s(s_single)
            )
        a_token = self.diffusion_transformer(
            a=a_token,
            s=s_single,
            z=z_pair,
            inplace_safe=inplace_safe,
            chunk_size=chunk_size,
        )

        a_token = self.layernorm_a(a_token)

        # Atom-level decoding back to a coordinate update.
        if blocks_per_ckpt and self.use_fine_grained_checkpoint:
            checkpoint_fn = get_checkpoint_fn()
            r_update = checkpoint_fn(
                self.atom_attention_decoder,
                input_feature_dict,
                a_token,
                q_skip,
                c_skip,
                p_skip,
                inplace_safe,
                chunk_size,
            )
        else:
            r_update = self.atom_attention_decoder(
                input_feature_dict=input_feature_dict,
                a=a_token,
                q_skip=q_skip,
                c_skip=c_skip,
                p_skip=p_skip,
                inplace_safe=inplace_safe,
                chunk_size=chunk_size,
            )

        return r_update

    def forward(
        self,
        x_noisy: torch.Tensor,
        t_hat_noise_level: torch.Tensor,
        input_feature_dict: dict[str, Union[torch.Tensor, int, float, dict]],
        s_inputs: torch.Tensor,
        s_trunk: torch.Tensor,
        z_trunk: torch.Tensor,
        inplace_safe: bool = False,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        """One step denoise: x_noisy, noise_level -> x_denoised

        Args:
            x_noisy (torch.Tensor): the noisy version of the input atom coords
                [..., N_sample, N_atom,3]
            t_hat_noise_level (torch.Tensor): the noise level, as well as the time step t
                [..., N_sample]
            input_feature_dict (dict[str, Union[torch.Tensor, int, float, dict]]): input meta feature dict
            s_inputs (torch.Tensor): single embedding from InputFeatureEmbedder
                [..., N_tokens, c_s_inputs]
            s_trunk (torch.Tensor): single feature embedding from PairFormer (Alg17)
                [..., N_tokens, c_s]
            z_trunk (torch.Tensor): pair feature embedding from PairFormer (Alg17)
                [..., N_tokens, N_tokens, c_z]
            inplace_safe (bool): Whether it is safe to use inplace operations. Defaults to False.
            chunk_size (Optional[int]): Chunk size for memory-efficient operations. Defaults to None.

        Returns:
            torch.Tensor: the denoised coordinates of x
                [..., N_sample, N_atom,3]
        """
        # EDM input preconditioning: r_noisy = c_in * x_noisy with
        # c_in = 1 / sqrt(sigma_data^2 + sigma^2).
        r_noisy = (
            x_noisy
            / torch.sqrt(self.sigma_data**2 + t_hat_noise_level**2)[..., None, None]
        )

        r_update = self.f_forward(
            r_noisy=r_noisy,
            t_hat_noise_level=t_hat_noise_level,
            input_feature_dict=input_feature_dict,
            s_inputs=s_inputs,
            s_trunk=s_trunk,
            z_trunk=z_trunk,
            inplace_safe=inplace_safe,
            chunk_size=chunk_size,
        )

        # EDM output preconditioning: x_denoised = c_skip * x_noisy
        # + c_out * F_theta, expressed via s_ratio = sigma / sigma_data.
        s_ratio = (t_hat_noise_level / self.sigma_data)[..., None, None].to(
            r_update.dtype
        )
        x_denoised = (
            1 / (1 + s_ratio**2) * x_noisy
            + t_hat_noise_level[..., None, None] / torch.sqrt(1 + s_ratio**2) * r_update
        ).to(r_update.dtype)

        return x_denoised
|
|
|
|
|
class Struct_decoder(nn.Module): |
|
""" |
|
Implements Algorithm 20 in AF3 |
|
""" |
|
|
|
def __init__( |
|
self, |
|
sigma_data: float = 16.0, |
|
c_atom: int = 128, |
|
c_atompair: int = 16, |
|
c_token: int = 768, |
|
c_s: int = 384, |
|
c_z: int = 128, |
|
c_s_inputs: int = 449, |
|
watermark: int = 32, |
|
atom_encoder: dict[str, int] = {"n_blocks": 3, "n_heads": 4}, |
|
transformer: dict[str, int] = {"n_blocks": 6, "n_heads": 16}, |
|
atom_decoder: dict[str, int] = {"n_blocks": 3, "n_heads": 4}, |
|
blocks_per_ckpt: Optional[int] = None, |
|
use_fine_grained_checkpoint: bool = False, |
|
initialization: Optional[dict[str, Union[str, float, bool]]] = None, |
|
) -> None: |
|
""" |
|
Args: |
|
sigma_data (torch.float, optional): the standard deviation of the data. Defaults to 16.0. |
|
c_atom (int, optional): embedding dim for atom feature. Defaults to 128. |
|
c_atompair (int, optional): embedding dim for atompair feature. Defaults to 16. |
|
c_token (int, optional): feature channel of token (single a). Defaults to 768. |
|
c_s (int, optional): hidden dim [for single embedding]. Defaults to 384. |
|
c_z (int, optional): hidden dim [for pair embedding]. Defaults to 128. |
|
c_s_inputs (int, optional): hidden dim [for single input embedding]. Defaults to 449. |
|
atom_encoder (dict[str, int], optional): configs in AtomAttentionEncoder. Defaults to {"n_blocks": 3, "n_heads": 4}. |
|
transformer (dict[str, int], optional): configs in DiffusionTransformer. Defaults to {"n_blocks": 24, "n_heads": 16}. |
|
atom_decoder (dict[str, int], optional): configs in AtomAttentionDecoder. Defaults to {"n_blocks": 3, "n_heads": 4}. |
|
blocks_per_ckpt: number of atom_encoder/transformer/atom_decoder blocks in each activation checkpoint |
|
Size of each chunk. A higher value corresponds to fewer |
|
checkpoints, and trades memory for speed. If None, no checkpointing is performed. |
|
use_fine_grained_checkpoint: whether use fine-gained checkpoint for finetuning stage 2 |
|
only effective if blocks_per_ckpt is not None. |
|
initialization: initialize the diffusion module according to initialization config. |
|
""" |
|
|
|
super(Struct_decoder, self).__init__() |
|
self.sigma_data = sigma_data |
|
self.c_atom = c_atom |
|
self.c_atompair = c_atompair |
|
self.c_token = c_token |
|
self.c_s_inputs = c_s_inputs |
|
self.c_s = c_s |
|
self.c_z = c_z |
|
|
|
|
|
self.blocks_per_ckpt = blocks_per_ckpt |
|
self.use_fine_grained_checkpoint = use_fine_grained_checkpoint |
|
|
|
self.diffusion_conditioning = DiffusionConditioning( |
|
sigma_data=self.sigma_data, c_z=c_z, c_s=c_s, c_s_inputs=c_s_inputs |
|
) |
|
self.atom_attention_encoder = AtomAttentionEncoder( |
|
**atom_encoder, |
|
c_atom=c_atom, |
|
c_atompair=c_atompair, |
|
c_token=c_token, |
|
has_coords=True, |
|
c_s=c_s, |
|
c_z=c_z, |
|
blocks_per_ckpt=blocks_per_ckpt, |
|
) |
|
|
|
self.layernorm_s = LayerNorm(c_s) |
|
self.linear_no_bias_s = LinearNoBias(in_features=c_s, out_features=c_token) |
|
self.diffusion_transformer = DiffusionTransformer( |
|
**transformer, |
|
c_a=c_token, |
|
c_s=c_s, |
|
c_z=c_z, |
|
blocks_per_ckpt=blocks_per_ckpt, |
|
) |
|
self.layernorm_a = LayerNorm(c_token) |
|
self.init_parameters(initialization) |
|
|
|
def init_parameters(self, initialization: dict): |
|
""" |
|
Initializes the parameters of the diffusion module according to the provided initialization configuration. |
|
|
|
Args: |
|
initialization (dict): A dictionary containing initialization settings. |
|
""" |
|
if initialization.get("zero_init_condition_transition", False): |
|
self.diffusion_conditioning.transition_z1.zero_init() |
|
self.diffusion_conditioning.transition_z2.zero_init() |
|
self.diffusion_conditioning.transition_s1.zero_init() |
|
self.diffusion_conditioning.transition_s2.zero_init() |
|
|
|
self.atom_attention_encoder.linear_init( |
|
zero_init_atom_encoder_residual_linear=initialization.get( |
|
"zero_init_atom_encoder_residual_linear", False |
|
), |
|
he_normal_init_atom_encoder_small_mlp=initialization.get( |
|
"he_normal_init_atom_encoder_small_mlp", False |
|
), |
|
he_normal_init_atom_encoder_output=initialization.get( |
|
"he_normal_init_atom_encoder_output", False |
|
), |
|
) |
|
|
|
if initialization.get("glorot_init_self_attention", False): |
|
for ( |
|
block |
|
) in ( |
|
self.atom_attention_encoder.atom_transformer.diffusion_transformer.blocks |
|
): |
|
block.attention_pair_bias.glorot_init() |
|
|
|
for block in self.diffusion_transformer.blocks: |
|
if initialization.get("zero_init_adaln", False): |
|
block.attention_pair_bias.layernorm_a.zero_init() |
|
block.conditioned_transition_block.adaln.zero_init() |
|
if initialization.get("zero_init_residual_condition_transition", False): |
|
nn.init.zeros_( |
|
block.conditioned_transition_block.linear_nobias_b.weight |
|
) |
|
|
|
    def f_forward(
        self,
        r_noisy: torch.Tensor,
        t_hat_noise_level: torch.Tensor,
        input_feature_dict: dict[str, Union[torch.Tensor, int, float, dict]],
        s_inputs: torch.Tensor,
        s_trunk: torch.Tensor,
        z_trunk: torch.Tensor,
        inplace_safe: bool = False,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        """The raw network to be trained.
        As in EDM equation (7), this is F_theta(c_in * x, c_noise(sigma)).
        Here, c_noise(sigma) is computed in the Conditioning module.

        Args:
            r_noisy (torch.Tensor): scaled x_noisy (i.e., c_in * x)
                [..., N_sample, N_atom, 3]
            t_hat_noise_level (torch.Tensor): the noise level, as well as the time step t
                [..., N_sample]
            input_feature_dict (dict[str, Union[torch.Tensor, int, float, dict]]): input feature
            s_inputs (torch.Tensor): single embedding from InputFeatureEmbedder
                [..., N_tokens, c_s_inputs]
            s_trunk (torch.Tensor): single feature embedding from PairFormer (Alg17)
                [..., N_tokens, c_s]
            z_trunk (torch.Tensor): pair feature embedding from PairFormer (Alg17)
                [..., N_tokens, N_tokens, c_z]
            inplace_safe (bool): Whether it is safe to use inplace operations. Defaults to False.
            chunk_size (Optional[int]): Chunk size for memory-efficient operations. Defaults to None.

        Returns:
            torch.Tensor: the token activations after the diffusion transformer.
                NOTE(review): the original doc claimed a coordinates update
                [..., N_sample, N_atom, 3], but the code returns the
                diffusion-transformer output directly without applying
                self.layernorm_a or any atom-level decoder, so the result
                appears to be token-level — confirm the intended contract.
        """
        # Each diffusion sample must come with exactly one noise level.
        N_sample = r_noisy.size(-3)
        assert t_hat_noise_level.size(-1) == N_sample

        # Activation checkpointing only pays off when gradients are tracked;
        # disable it entirely for inference / no-grad evaluation passes.
        blocks_per_ckpt = self.blocks_per_ckpt
        if not torch.is_grad_enabled():
            blocks_per_ckpt = None

        # Conditioning: combine trunk embeddings with the noise level to get the
        # single (s_single) and pair (z_pair) conditioning tensors.
        if blocks_per_ckpt:
            checkpoint_fn = get_checkpoint_fn()
            # Checkpointing requires positional args; their order must match the
            # keyword call in the else-branch below.
            s_single, z_pair = checkpoint_fn(
                self.diffusion_conditioning,
                t_hat_noise_level,
                input_feature_dict,
                s_inputs,
                s_trunk,
                z_trunk,
                inplace_safe,
            )
        else:
            s_single, z_pair = self.diffusion_conditioning(
                t_hat_noise_level=t_hat_noise_level,
                input_feature_dict=input_feature_dict,
                s_inputs=s_inputs,
                s_trunk=s_trunk,
                z_trunk=z_trunk,
                inplace_safe=inplace_safe,
            )

        # Broadcast the trunk embeddings over an N_sample dimension so each
        # diffusion sample sees the same trunk context.
        s_trunk = expand_at_dim(
            s_trunk, dim=-3, n=N_sample
        )
        z_pair = expand_at_dim(
            z_pair, dim=-4, n=N_sample
        )

        # Atom attention encoder: aggregates atom-level features (including the
        # noisy coordinates r_noisy) into token activations plus skip tensors.
        if blocks_per_ckpt and self.use_fine_grained_checkpoint:
            checkpoint_fn = get_checkpoint_fn()
            # Positional order must match AtomAttentionEncoder.forward
            # (keyword call in the else-branch below).
            a_token, q_skip, c_skip, p_skip = checkpoint_fn(
                self.atom_attention_encoder,
                input_feature_dict,
                r_noisy,
                s_trunk,
                z_pair,
                inplace_safe,
                chunk_size,
            )
        else:

            a_token, q_skip, c_skip, p_skip = self.atom_attention_encoder(
                input_feature_dict=input_feature_dict,
                r_l=r_noisy,
                s=s_trunk,
                z=z_pair,
                inplace_safe=inplace_safe,
                chunk_size=chunk_size,
            )

        # Inject the conditioned single representation into the token stream.
        # The in-place branch mutates a_token to avoid an extra allocation.
        if inplace_safe:
            a_token += self.linear_no_bias_s(
                self.layernorm_s(s_single)
            )
        else:
            a_token = a_token + self.linear_no_bias_s(
                self.layernorm_s(s_single)
            )
        # Token-level diffusion transformer conditioned on s_single / z_pair.
        a_token = self.diffusion_transformer(
            a=a_token,
            s=s_single,
            z=z_pair,
            inplace_safe=inplace_safe,
            chunk_size=chunk_size,
        )

        # NOTE(review): q_skip / c_skip / p_skip are computed but never used,
        # and self.layernorm_a (built in __init__) is never applied — confirm
        # whether an atom-decoder stage was intentionally removed here.
        return a_token
|
|
|
def forward( |
|
self, |
|
x_noisy: torch.Tensor, |
|
t_hat_noise_level: torch.Tensor, |
|
input_feature_dict: dict[str, Union[torch.Tensor, int, float, dict]], |
|
s_inputs: torch.Tensor, |
|
s_trunk: torch.Tensor, |
|
z_trunk: torch.Tensor, |
|
inplace_safe: bool = False, |
|
chunk_size: Optional[int] = None, |
|
) -> torch.Tensor: |
|
"""One step denoise: x_noisy, noise_level -> x_denoised |
|
|
|
Args: |
|
x_noisy (torch.Tensor): the noisy version of the input atom coords |
|
[..., N_sample, N_atom,3] |
|
t_hat_noise_level (torch.Tensor): the noise level, as well as the time step t |
|
[..., N_sample] |
|
input_feature_dict (dict[str, Union[torch.Tensor, int, float, dict]]): input meta feature dict |
|
s_inputs (torch.Tensor): single embedding from InputFeatureEmbedder |
|
[..., N_tokens, c_s_inputs] |
|
s_trunk (torch.Tensor): single feature embedding from PairFormer (Alg17) |
|
[..., N_tokens, c_s] |
|
z_trunk (torch.Tensor): pair feature embedding from PairFormer (Alg17) |
|
[..., N_tokens, N_tokens, c_z] |
|
inplace_safe (bool): Whether it is safe to use inplace operations. Defaults to False. |
|
chunk_size (Optional[int]): Chunk size for memory-efficient operations. Defaults to None. |
|
|
|
Returns: |
|
torch.Tensor: the denoised coordinates of x |
|
[..., N_sample, N_atom,3] |
|
""" |
|
|
|
|
|
|
|
|
|
r_noisy = ( |
|
x_noisy |
|
/ torch.sqrt(self.sigma_data**2 + t_hat_noise_level**2)[..., None, None] |
|
) |
|
|
|
a_token = self.f_forward( |
|
r_noisy=r_noisy, |
|
t_hat_noise_level=t_hat_noise_level, |
|
input_feature_dict=input_feature_dict, |
|
s_inputs=s_inputs, |
|
s_trunk=s_trunk, |
|
z_trunk=z_trunk, |
|
inplace_safe=inplace_safe, |
|
chunk_size=chunk_size, |
|
) |
|
|
|
return a_token |
|
|