from typing import Any, Callable, Optional

import torch

from protenix.model.utils import centre_random_augmentation


class TrainingNoiseSampler:
    """
    Sample the noise level of training samples.
    """

    def __init__(
        self,
        p_mean: float = -1.2,
        p_std: float = 1.5,
        sigma_data: float = 16.0,
    ) -> None:
        """Sampler for training noise levels

        Args:
            p_mean (float, optional): mean of the underlying Gaussian (in log-space). Defaults to -1.2.
            p_std (float, optional): std of the underlying Gaussian (in log-space). Defaults to 1.5.
            sigma_data (float, optional): data scale. Defaults to 16.0; this is 1.0 in the original EDM.
        """
        self.sigma_data = sigma_data
        self.p_mean = p_mean
        self.p_std = p_std
        print(f"train scheduler {self.sigma_data}")

    def __call__(
        self, size: torch.Size, device: torch.device = torch.device("cpu")
    ) -> torch.Tensor:
        """Sample noise levels.

        Args:
            size (torch.Size): the target size
            device (torch.device, optional): target device. Defaults to torch.device("cpu").

        Returns:
            torch.Tensor: sampled noise levels
        """
        # Log-normal noise levels: sigma = sigma_data * exp(p_mean + p_std * n), n ~ N(0, 1).
        rnd_normal = torch.randn(size=size, device=device)
        noise_level = (rnd_normal * self.p_std + self.p_mean).exp() * self.sigma_data
        return noise_level
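
# A minimal usage sketch (shapes are illustrative, not part of the API):
#
#     sampler = TrainingNoiseSampler()              # p_mean=-1.2, p_std=1.5, sigma_data=16.0
#     sigma = sampler(size=torch.Size([2, 48]))     # [batch=2, N_sample=48], strictly positive
#     # log(sigma / 16) is then distributed as N(-1.2, 1.5**2)
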
class InferenceNoiseScheduler:
    """
    Scheduler for the noise levels (time steps) used at inference.
    """

    def __init__(
        self,
        s_max: float = 160.0,
        s_min: float = 4e-4,
        rho: float = 7,
        sigma_data: float = 16.0,
    ) -> None:
        """Scheduler parameters

        Args:
            s_max (float, optional): maximal noise level. Defaults to 160.0.
            s_min (float, optional): minimal noise level. Defaults to 4e-4.
            rho (float, optional): exponent controlling the spacing of the time steps. Defaults to 7.
            sigma_data (float, optional): data scale. Defaults to 16.0; this is 1.0 in the original EDM.
        """
        self.sigma_data = sigma_data
        self.s_max = s_max
        self.s_min = s_min
        self.rho = rho
        print(f"inference scheduler {self.sigma_data}")

    def __call__(
        self,
        N_step: int = 200,
        device: torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float32,
    ) -> torch.Tensor:
        """Schedule the noise levels (time steps). No sampling is performed.

        Args:
            N_step (int, optional): number of time steps. Defaults to 200.
            device (torch.device, optional): target device. Defaults to torch.device("cpu").
            dtype (torch.dtype, optional): target dtype. Defaults to torch.float32.

        Returns:
            torch.Tensor: noise levels (time steps)
                [N_step + 1]
        """
        step_size = 1 / N_step
        step_indices = torch.arange(N_step + 1, device=device, dtype=dtype)
        t_step_list = (
            self.sigma_data
            * (
                self.s_max ** (1 / self.rho)
                + step_indices
                * step_size
                * (self.s_min ** (1 / self.rho) - self.s_max ** (1 / self.rho))
            )
            ** self.rho
        )
        # The final step goes exactly to 0, i.e. the clean-data limit.
        t_step_list[..., -1] = 0

        return t_step_list
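
# A minimal usage sketch (illustrative values): the schedule has N_step + 1
# entries; with the defaults it starts at sigma_data * s_max = 2560.0 and ends at 0.
#
#     scheduler = InferenceNoiseScheduler()
#     t_steps = scheduler(N_step=200)   # [201], monotonically decreasing
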
def sample_diffusion(
    denoise_net: Callable,
    input_feature_dict: dict[str, Any],
    s_inputs: torch.Tensor,
    s_trunk: torch.Tensor,
    z_trunk: torch.Tensor,
    noise_schedule: torch.Tensor,
    N_sample: int = 1,
    gamma0: float = 0.8,
    gamma_min: float = 1.0,
    noise_scale_lambda: float = 1.003,
    step_scale_eta: float = 1.5,
    diffusion_chunk_size: Optional[int] = None,
    inplace_safe: bool = False,
    attn_chunk_size: Optional[int] = None,
) -> torch.Tensor:
    """Implements Algorithm 18 in AF3.
    It performs denoising steps from the highest noise level down to zero.
    The time steps (= noise levels) are given by noise_schedule.

    Args:
        denoise_net (Callable): the network that performs the denoising step.
        input_feature_dict (dict[str, Any]): input meta feature dict
        s_inputs (torch.Tensor): single embedding from InputFeatureEmbedder
            [..., N_tokens, c_s_inputs]
        s_trunk (torch.Tensor): single feature embedding from PairFormer (Alg17)
            [..., N_tokens, c_s]
        z_trunk (torch.Tensor): pair feature embedding from PairFormer (Alg17)
            [..., N_tokens, N_tokens, c_z]
        noise_schedule (torch.Tensor): noise-level schedule, which also gives the time steps since sigma = t.
            [N_iterations]
        N_sample (int): number of generated samples
        gamma0 (float): param in Alg.18.
        gamma_min (float): param in Alg.18.
        noise_scale_lambda (float): param in Alg.18.
        step_scale_eta (float): param in Alg.18.
        diffusion_chunk_size (Optional[int]): chunk size over the sample dimension. Defaults to None.
        inplace_safe (bool): whether to use inplace operations safely. Defaults to False.
        attn_chunk_size (Optional[int]): chunk size for the attention operation. Defaults to None.

    Returns:
        torch.Tensor: the denoised coordinates of x at inference
            [..., N_sample, N_atom, 3]
    """
    N_atom = input_feature_dict["atom_to_token_idx"].size(-1)
    batch_shape = s_inputs.shape[:-2]
    device = s_inputs.device
    dtype = s_inputs.dtype

    def _chunk_sample_diffusion(chunk_n_sample, inplace_safe):
        # Initialize from pure noise at the first (largest) noise level.
        x_l = noise_schedule[0] * torch.randn(
            size=(*batch_shape, chunk_n_sample, N_atom, 3), device=device, dtype=dtype
        )

        for c_tau_last, c_tau in zip(noise_schedule[:-1], noise_schedule[1:]):
            # Re-center and randomly rotate/translate the current coordinates.
            x_l = (
                centre_random_augmentation(x_input_coords=x_l, N_sample=1)
                .squeeze(dim=-3)
                .to(dtype)
            )

            # Stochastic "churn": raise the noise level from c_tau_last to t_hat.
            gamma = float(gamma0) if c_tau > gamma_min else 0
            t_hat = c_tau_last * (gamma + 1)

            delta_noise_level = torch.sqrt(t_hat**2 - c_tau_last**2)
            x_noisy = x_l + noise_scale_lambda * delta_noise_level * torch.randn(
                size=x_l.shape, device=device, dtype=dtype
            )

            # Broadcast the scalar noise level over batch and sample dims.
            t_hat = (
                t_hat.reshape((1,) * (len(batch_shape) + 1))
                .expand(*batch_shape, chunk_n_sample)
                .to(dtype)
            )

            x_denoised = denoise_net(
                x_noisy=x_noisy,
                t_hat_noise_level=t_hat,
                input_feature_dict=input_feature_dict,
                s_inputs=s_inputs,
                s_trunk=s_trunk,
                z_trunk=z_trunk,
                chunk_size=attn_chunk_size,
                inplace_safe=inplace_safe,
            )

            # Euler step along the denoising direction, scaled by step_scale_eta.
            delta = (x_noisy - x_denoised) / t_hat[..., None, None]
            dt = c_tau - t_hat
            x_l = x_noisy + step_scale_eta * dt[..., None, None] * delta

        return x_l

    if diffusion_chunk_size is None:
        x_l = _chunk_sample_diffusion(N_sample, inplace_safe=inplace_safe)
    else:
        # Generate samples in chunks to bound peak memory.
        x_l = []
        no_chunks = N_sample // diffusion_chunk_size + (
            N_sample % diffusion_chunk_size != 0
        )
        for i in range(no_chunks):
            chunk_n_sample = (
                diffusion_chunk_size
                if i < no_chunks - 1
                else N_sample - i * diffusion_chunk_size
            )
            chunk_x_l = _chunk_sample_diffusion(
                chunk_n_sample, inplace_safe=inplace_safe
            )
            x_l.append(chunk_x_l)
        x_l = torch.cat(x_l, dim=-3)
    return x_l
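
# A minimal usage sketch (tensor names and shapes are hypothetical; denoise_net
# is any callable with the keyword interface used above):
#
#     noise_schedule = InferenceNoiseScheduler()(N_step=200, device=device)
#     coords = sample_diffusion(
#         denoise_net=denoise_net,
#         input_feature_dict=feats,          # must contain "atom_to_token_idx"
#         s_inputs=s_inputs,                 # [..., N_tokens, c_s_inputs]
#         s_trunk=s_trunk,                   # [..., N_tokens, c_s]
#         z_trunk=z_trunk,                   # [..., N_tokens, N_tokens, c_z]
#         noise_schedule=noise_schedule,
#         N_sample=5,
#     )                                      # [..., 5, N_atom, 3]
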
def sample_diffusion_training(
    noise_sampler: TrainingNoiseSampler,
    denoise_net: Callable,
    label_dict: dict[str, Any],
    input_feature_dict: dict[str, Any],
    s_inputs: torch.Tensor,
    s_trunk: torch.Tensor,
    z_trunk: torch.Tensor,
    N_sample: int = 1,
    diffusion_chunk_size: Optional[int] = None,
) -> tuple[torch.Tensor, ...]:
    """Implements diffusion training as described in AF3 Appendix at page 23.
    A noise level is drawn per training sample, the augmented ground-truth
    coordinates are perturbed accordingly, and a single denoising step is performed.

    Args:
        noise_sampler (TrainingNoiseSampler): sampler for the training noise levels.
        denoise_net (Callable): the network that performs the denoising step.
        label_dict (dict[str, Any]): a dictionary containing the following keys.
            "coordinate": the ground-truth coordinates
                [..., N_atom, 3]
            "coordinate_mask": whether true coordinates exist.
                [..., N_atom]
        input_feature_dict (dict[str, Any]): input meta feature dict
        s_inputs (torch.Tensor): single embedding from InputFeatureEmbedder
            [..., N_tokens, c_s_inputs]
        s_trunk (torch.Tensor): single feature embedding from PairFormer (Alg17)
            [..., N_tokens, c_s]
        z_trunk (torch.Tensor): pair feature embedding from PairFormer (Alg17)
            [..., N_tokens, N_tokens, c_z]
        N_sample (int): number of training samples
        diffusion_chunk_size (Optional[int]): chunk size over the sample dimension. Defaults to None.

    Returns:
        tuple[torch.Tensor, ...]:
            x_gt_augment: the augmented ground-truth coordinates [..., N_sample, N_atom, 3]
            x_denoised: the denoised coordinates [..., N_sample, N_atom, 3]
            sigma: the sampled noise levels [..., N_sample]
    """
    batch_size_shape = label_dict["coordinate"].shape[:-2]
    device = label_dict["coordinate"].device
    dtype = label_dict["coordinate"].dtype

    # Augment the ground truth: N_sample random rotations/translations per structure.
    x_gt_augment = centre_random_augmentation(
        x_input_coords=label_dict["coordinate"],
        N_sample=N_sample,
        mask=label_dict["coordinate_mask"],
    ).to(dtype)

    # One noise level per sample, then Gaussian noise scaled by that level.
    sigma = noise_sampler(size=(*batch_size_shape, N_sample), device=device).to(dtype)
    noise = torch.randn_like(x_gt_augment, dtype=dtype) * sigma[..., None, None]

    if diffusion_chunk_size is None:
        x_denoised = denoise_net(
            x_noisy=x_gt_augment + noise,
            t_hat_noise_level=sigma,
            input_feature_dict=input_feature_dict,
            s_inputs=s_inputs,
            s_trunk=s_trunk,
            z_trunk=z_trunk,
        )
    else:
        # Denoise in chunks over the sample dimension to bound peak memory.
        x_denoised = []
        no_chunks = N_sample // diffusion_chunk_size + (
            N_sample % diffusion_chunk_size != 0
        )
        for i in range(no_chunks):
            x_noisy_i = (x_gt_augment + noise)[
                ..., i * diffusion_chunk_size : (i + 1) * diffusion_chunk_size, :, :
            ]
            t_hat_noise_level_i = sigma[
                ..., i * diffusion_chunk_size : (i + 1) * diffusion_chunk_size
            ]
            x_denoised_i = denoise_net(
                x_noisy=x_noisy_i,
                t_hat_noise_level=t_hat_noise_level_i,
                input_feature_dict=input_feature_dict,
                s_inputs=s_inputs,
                s_trunk=s_trunk,
                z_trunk=z_trunk,
            )
            x_denoised.append(x_denoised_i)
        x_denoised = torch.cat(x_denoised, dim=-3)

    return x_gt_augment, x_denoised, sigma
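
# A minimal usage sketch (inputs are hypothetical): draw N_sample noise levels,
# perturb the augmented ground truth, denoise once, and compare the outputs
# (e.g. with the sigma-weighted MSE loss of AF3):
#
#     x_gt, x_denoised, sigma = sample_diffusion_training(
#         noise_sampler=TrainingNoiseSampler(),
#         denoise_net=denoise_net,
#         label_dict={"coordinate": x, "coordinate_mask": mask},
#         input_feature_dict=feats,
#         s_inputs=s_inputs,
#         s_trunk=s_trunk,
#         z_trunk=z_trunk,
#         N_sample=48,
#     )
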
def structure_predictor(
    coordinate: torch.Tensor,
    denoise_net: Callable,
    label_dict: dict[str, Any],
    input_feature_dict: dict[str, Any],
    s_inputs: torch.Tensor,
    s_trunk: torch.Tensor,
    z_trunk: torch.Tensor,
    N_sample: int = 1,
    diffusion_chunk_size: Optional[int] = None,
) -> tuple[torch.Tensor, ...]:
    """Runs a single denoising pass of denoise_net on the given coordinates
    at a fixed noise level (sigma = 16, matching the default sigma_data).

    Args:
        coordinate (torch.Tensor): the input coordinates
            [..., N_sample, N_atom, 3]
        denoise_net (Callable): the network that performs the denoising step.
        label_dict (dict[str, Any]): label dict; unused in this function.
        input_feature_dict (dict[str, Any]): input meta feature dict
        s_inputs (torch.Tensor): single embedding from InputFeatureEmbedder
            [..., N_tokens, c_s_inputs]
        s_trunk (torch.Tensor): single feature embedding from PairFormer (Alg17)
            [..., N_tokens, c_s]
        z_trunk (torch.Tensor): pair feature embedding from PairFormer (Alg17)
            [..., N_tokens, N_tokens, c_z]
        N_sample (int): number of samples
        diffusion_chunk_size (Optional[int]): chunk size over the sample dimension. Defaults to None.

    Returns:
        tuple[torch.Tensor, ...]:
            x_denoised: the denoised coordinates [..., N_sample, N_atom, 3]
            sigma: the fixed noise levels [..., N_sample]
    """
    batch_size_shape = coordinate.shape[:-3]
    device = coordinate.device
    dtype = coordinate.dtype

    # Fixed noise level for every sample, matching the default sigma_data of 16.
    sigma = torch.ones(size=(*batch_size_shape, N_sample), device=device).to(dtype)
    sigma *= 16

    if diffusion_chunk_size is None:
        x_denoised = denoise_net(
            x_noisy=coordinate,
            t_hat_noise_level=sigma,
            input_feature_dict=input_feature_dict,
            s_inputs=s_inputs,
            s_trunk=s_trunk,
            z_trunk=z_trunk,
        )
    else:
        # Denoise in chunks over the sample dimension to bound peak memory.
        x_denoised = []
        no_chunks = N_sample // diffusion_chunk_size + (
            N_sample % diffusion_chunk_size != 0
        )
        for i in range(no_chunks):
            x_noisy_i = coordinate[
                ..., i * diffusion_chunk_size : (i + 1) * diffusion_chunk_size, :, :
            ]
            t_hat_noise_level_i = sigma[
                ..., i * diffusion_chunk_size : (i + 1) * diffusion_chunk_size
            ]
            x_denoised_i = denoise_net(
                x_noisy=x_noisy_i,
                t_hat_noise_level=t_hat_noise_level_i,
                input_feature_dict=input_feature_dict,
                s_inputs=s_inputs,
                s_trunk=s_trunk,
                z_trunk=z_trunk,
            )
            x_denoised.append(x_denoised_i)
        x_denoised = torch.cat(x_denoised, dim=-3)

    return x_denoised, sigma
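
# A minimal usage sketch (inputs are hypothetical): one denoising pass over
# existing coordinates at the fixed noise level sigma = 16:
#
#     x_denoised, sigma = structure_predictor(
#         coordinate=coords,                 # [..., N_sample, N_atom, 3]
#         denoise_net=denoise_net,
#         label_dict=label_dict,             # unused by this function
#         input_feature_dict=feats,
#         s_inputs=s_inputs,
#         s_trunk=s_trunk,
#         z_trunk=z_trunk,
#         N_sample=coords.shape[-3],
#     )
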
def watermark_decoder(
    coordinate: torch.Tensor,
    denoise_net: Callable,
    input_feature_dict: dict[str, Any],
    s_inputs: torch.Tensor,
    s_trunk: torch.Tensor,
    z_trunk: torch.Tensor,
    N_sample: int = 1,
    diffusion_chunk_size: Optional[int] = None,
) -> tuple[torch.Tensor, ...]:
    """Decodes token-level features from the given coordinates with a single
    pass of denoise_net at a fixed noise level (sigma = 16, matching the default sigma_data).

    Args:
        coordinate (torch.Tensor): the input coordinates
            [..., N_sample, N_atom, 3]
        denoise_net (Callable): the network that produces the token-level output.
        input_feature_dict (dict[str, Any]): input meta feature dict
        s_inputs (torch.Tensor): single embedding from InputFeatureEmbedder
            [..., N_tokens, c_s_inputs]
        s_trunk (torch.Tensor): single feature embedding from PairFormer (Alg17)
            [..., N_tokens, c_s]
        z_trunk (torch.Tensor): pair feature embedding from PairFormer (Alg17)
            [..., N_tokens, N_tokens, c_z]
        N_sample (int): number of samples
        diffusion_chunk_size (Optional[int]): chunk size over the sample dimension. Defaults to None.

    Returns:
        tuple[torch.Tensor, ...]:
            coordinate: the input coordinates, unchanged [..., N_sample, N_atom, 3]
            a_token: the token-level output of denoise_net
            sigma: the fixed noise levels [..., N_sample]
    """
    batch_size_shape = coordinate.shape[:-3]
    device = coordinate.device
    dtype = coordinate.dtype

    # Fixed noise level for every sample, matching the default sigma_data of 16.
    sigma = torch.ones(size=(*batch_size_shape, N_sample), device=device).to(dtype)
    sigma *= 16

    if diffusion_chunk_size is None:
        a_token = denoise_net(
            x_noisy=coordinate,
            t_hat_noise_level=sigma,
            input_feature_dict=input_feature_dict,
            s_inputs=s_inputs,
            s_trunk=s_trunk,
            z_trunk=z_trunk,
        )
    else:
        # Decode in chunks over the sample dimension to bound peak memory.
        a_token = []
        no_chunks = N_sample // diffusion_chunk_size + (
            N_sample % diffusion_chunk_size != 0
        )
        for i in range(no_chunks):
            x_noisy_i = coordinate[
                ..., i * diffusion_chunk_size : (i + 1) * diffusion_chunk_size, :, :
            ]
            t_hat_noise_level_i = sigma[
                ..., i * diffusion_chunk_size : (i + 1) * diffusion_chunk_size
            ]
            a_token_i = denoise_net(
                x_noisy=x_noisy_i,
                t_hat_noise_level=t_hat_noise_level_i,
                input_feature_dict=input_feature_dict,
                s_inputs=s_inputs,
                s_trunk=s_trunk,
                z_trunk=z_trunk,
            )
            a_token.append(a_token_i)
        a_token = torch.cat(a_token, dim=-3)

    return coordinate, a_token, sigma
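
# A minimal usage sketch (inputs are hypothetical): recover the token-level
# output a_token from (possibly watermarked) coordinates in one pass:
#
#     coords, a_token, sigma = watermark_decoder(
#         coordinate=coords,                 # [..., N_sample, N_atom, 3]
#         denoise_net=decoder_net,           # a network returning token features
#         input_feature_dict=feats,
#         s_inputs=s_inputs,
#         s_trunk=s_trunk,
#         z_trunk=z_trunk,
#         N_sample=coords.shape[-3],
#     )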