# Copyright 2024 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Optional

import torch

from protenix.model.utils import centre_random_augmentation


class TrainingNoiseSampler:
    """
    Sample the noise level of training samples.
    """

    def __init__(
        self,
        p_mean: float = -1.2,
        p_std: float = 1.5,
        sigma_data: float = 16.0,  # NOTE: in EDM, this is 1.0
    ) -> None:
        """Sampler for training noise levels

        Args:
            p_mean (float, optional): mean of the underlying Gaussian. Defaults to -1.2.
            p_std (float, optional): std of the underlying Gaussian. Defaults to 1.5.
            sigma_data (float, optional): data scale. Defaults to 16.0, but this is 1.0 in EDM.
        """
        self.sigma_data = sigma_data
        self.p_mean = p_mean
        self.p_std = p_std
        print(f"train scheduler {self.sigma_data}")

    def __call__(
        self, size: torch.Size, device: torch.device = torch.device("cpu")
    ) -> torch.Tensor:
        """Sampling

        Args:
            size (torch.Size): the target size
            device (torch.device, optional): target device. Defaults to torch.device("cpu").

        Returns:
            torch.Tensor: sampled noise levels
        """
        rnd_normal = torch.randn(size=size, device=device)
        noise_level = (rnd_normal * self.p_std + self.p_mean).exp() * self.sigma_data
        return noise_level
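
# Usage sketch (illustrative only; the shape below is an assumption, not part
# of the API):
#
#     sampler = TrainingNoiseSampler()                      # log-normal levels
#     sigma = sampler(size=torch.Size((batch, N_sample)))   # [batch, N_sample]
#
# Each entry equals sigma_data * exp(p_mean + p_std * n) with n ~ N(0, 1),
# i.e. the EDM log-normal training noise distribution rescaled by sigma_data.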

class InferenceNoiseScheduler:
    """
    Scheduler for noise levels (time steps).
    """

    def __init__(
        self,
        s_max: float = 160.0,
        s_min: float = 4e-4,
        rho: float = 7,
        sigma_data: float = 16.0,  # NOTE: in EDM, this is 1.0
    ) -> None:
        """Scheduler parameters

        Args:
            s_max (float, optional): maximal noise level. Defaults to 160.0.
            s_min (float, optional): minimal noise level. Defaults to 4e-4.
            rho (float, optional): exponent controlling the interpolation. Defaults to 7.
            sigma_data (float, optional): data scale. Defaults to 16.0, but this is 1.0 in EDM.
        """
        self.sigma_data = sigma_data
        self.s_max = s_max
        self.s_min = s_min
        self.rho = rho
        print(f"inference scheduler {self.sigma_data}")

    def __call__(
        self,
        N_step: int = 200,
        device: torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float32,
    ) -> torch.Tensor:
        """Schedule the noise levels (time steps). No sampling is performed.

        Args:
            N_step (int, optional): number of time steps. Defaults to 200.
            device (torch.device, optional): target device. Defaults to torch.device("cpu").
            dtype (torch.dtype, optional): target dtype. Defaults to torch.float32.

        Returns:
            torch.Tensor: noise levels (time steps), shape [N_step + 1]
        """
        step_size = 1 / N_step
        step_indices = torch.arange(N_step + 1, device=device, dtype=dtype)
        t_step_list = (
            self.sigma_data
            * (
                self.s_max ** (1 / self.rho)
                + step_indices
                * step_size
                * (self.s_min ** (1 / self.rho) - self.s_max ** (1 / self.rho))
            )
            ** self.rho
        )
        # replace the last time step by 0
        t_step_list[..., -1] = 0  # t_N = 0

        return t_step_list
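
# Schedule sketch (illustrative): this is the Karras et al. (EDM) polynomial
# interpolation in sigma^(1/rho), rescaled by sigma_data:
#
#     t_i = sigma_data * (s_max^(1/rho) + (i / N_step) * (s_min^(1/rho) - s_max^(1/rho)))^rho
#
# so t_0 = sigma_data * s_max and t_{N_step} would be sigma_data * s_min before
# it is overwritten with exactly 0. For example:
#
#     scheduler = InferenceNoiseScheduler()
#     t = scheduler(N_step=200)   # t[0] == 2560.0 (= 16.0 * 160.0), t[-1] == 0.0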

def sample_diffusion(
    denoise_net: Callable,
    input_feature_dict: dict[str, Any],
    s_inputs: torch.Tensor,
    s_trunk: torch.Tensor,
    z_trunk: torch.Tensor,
    noise_schedule: torch.Tensor,
    N_sample: int = 1,
    gamma0: float = 0.8,
    gamma_min: float = 1.0,
    noise_scale_lambda: float = 1.003,
    step_scale_eta: float = 1.5,
    diffusion_chunk_size: Optional[int] = None,
    inplace_safe: bool = False,
    attn_chunk_size: Optional[int] = None,
) -> torch.Tensor:
    """Implements Algorithm 18 in AF3.
    It performs iterative denoising from the highest noise level t_0 down to
    t_N = 0. The time steps (= noise levels) are given by noise_schedule.

    Args:
        denoise_net (Callable): the network that performs the denoising step.
        input_feature_dict (dict[str, Any]): input meta feature dict
        s_inputs (torch.Tensor): single embedding from InputFeatureEmbedder
            [..., N_tokens, c_s_inputs]
        s_trunk (torch.Tensor): single feature embedding from PairFormer (Alg17)
            [..., N_tokens, c_s]
        z_trunk (torch.Tensor): pair feature embedding from PairFormer (Alg17)
            [..., N_tokens, N_tokens, c_z]
        noise_schedule (torch.Tensor): noise-level schedule (which is also the
            time steps, since sigma = t) [N_iterations]
        N_sample (int): number of generated samples
        gamma0 (float): param in Alg.18.
        gamma_min (float): param in Alg.18.
        noise_scale_lambda (float): param in Alg.18.
        step_scale_eta (float): param in Alg.18.
        diffusion_chunk_size (Optional[int]): chunk size for the diffusion operation. Defaults to None.
        inplace_safe (bool): whether to use inplace operations safely. Defaults to False.
        attn_chunk_size (Optional[int]): chunk size for the attention operation. Defaults to None.

    Returns:
        torch.Tensor: the denoised coordinates of x at inference time
            [..., N_sample, N_atom, 3]
    """
    N_atom = input_feature_dict["atom_to_token_idx"].size(-1)
    batch_shape = s_inputs.shape[:-2]
    device = s_inputs.device
    dtype = s_inputs.dtype

    def _chunk_sample_diffusion(chunk_n_sample, inplace_safe):
        # init noise
        # [..., N_sample, N_atom, 3]
        x_l = noise_schedule[0] * torch.randn(
            size=(*batch_shape, chunk_n_sample, N_atom, 3), device=device, dtype=dtype
        )  # NOTE: set seed in distributed training

        for c_tau_last, c_tau in zip(noise_schedule[:-1], noise_schedule[1:]):
            # [..., N_sample, N_atom, 3]
            x_l = (
                centre_random_augmentation(x_input_coords=x_l, N_sample=1)
                .squeeze(dim=-3)
                .to(dtype)
            )

            # Denoise with a predictor-corrector sampler
            # 1. Add noise to move x_{c_tau_last} to x_{t_hat}
            gamma = float(gamma0) if c_tau > gamma_min else 0
            t_hat = c_tau_last * (gamma + 1)

            delta_noise_level = torch.sqrt(t_hat**2 - c_tau_last**2)
            x_noisy = x_l + noise_scale_lambda * delta_noise_level * torch.randn(
                size=x_l.shape, device=device, dtype=dtype
            )

            # 2. Denoise from x_{t_hat} to x_{c_tau}
            # Euler step only
            t_hat = (
                t_hat.reshape((1,) * (len(batch_shape) + 1))
                .expand(*batch_shape, chunk_n_sample)
                .to(dtype)
            )

            x_denoised = denoise_net(
                x_noisy=x_noisy,
                t_hat_noise_level=t_hat,
                input_feature_dict=input_feature_dict,
                s_inputs=s_inputs,
                s_trunk=s_trunk,
                z_trunk=z_trunk,
                chunk_size=attn_chunk_size,
                inplace_safe=inplace_safe,
            )

            delta = (x_noisy - x_denoised) / t_hat[
                ..., None, None
            ]  # Line 9 of AF3 Alg.18 uses 'x_l_hat' instead, which we believe is a typo.
            dt = c_tau - t_hat
            x_l = x_noisy + step_scale_eta * dt[..., None, None] * delta

        return x_l

    if diffusion_chunk_size is None:
        x_l = _chunk_sample_diffusion(N_sample, inplace_safe=inplace_safe)
    else:
        x_l = []
        no_chunks = N_sample // diffusion_chunk_size + (
            N_sample % diffusion_chunk_size != 0
        )
        for i in range(no_chunks):
            chunk_n_sample = (
                diffusion_chunk_size
                if i < no_chunks - 1
                else N_sample - i * diffusion_chunk_size
            )
            chunk_x_l = _chunk_sample_diffusion(
                chunk_n_sample, inplace_safe=inplace_safe
            )
            x_l.append(chunk_x_l)
        x_l = torch.cat(x_l, dim=-3)  # [..., N_sample, N_atom, 3]
    return x_l
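
# Usage sketch (illustrative; `net` and `feat` stand in for the real Protenix
# denoising module and feature dict, which are assumptions here):
#
#     schedule = InferenceNoiseScheduler()(N_step=200, device=s_inputs.device)
#     coords = sample_diffusion(
#         denoise_net=net,
#         input_feature_dict=feat,
#         s_inputs=s_inputs,
#         s_trunk=s_trunk,
#         z_trunk=z_trunk,
#         noise_schedule=schedule,
#         N_sample=5,
#     )   # [..., 5, N_atom, 3]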

def sample_diffusion_training(
    noise_sampler: TrainingNoiseSampler,
    denoise_net: Callable,
    label_dict: dict[str, Any],
    input_feature_dict: dict[str, Any],
    s_inputs: torch.Tensor,
    s_trunk: torch.Tensor,
    z_trunk: torch.Tensor,
    N_sample: int = 1,
    diffusion_chunk_size: Optional[int] = None,
) -> tuple[torch.Tensor, ...]:
    """Implements diffusion training as described in AF3 Appendix at page 23.
    Each sample is randomly augmented, corrupted with an independently sampled
    noise level, and denoised in a single step.

    Args:
        noise_sampler (TrainingNoiseSampler): sampler for the training noise levels.
        denoise_net (Callable): the network that performs the denoising step.
        label_dict (dict[str, Any]): a dictionary containing the following:
            "coordinate": the ground-truth coordinates
                [..., N_atom, 3]
            "coordinate_mask": whether true coordinates exist.
                [..., N_atom]
        input_feature_dict (dict[str, Any]): input meta feature dict
        s_inputs (torch.Tensor): single embedding from InputFeatureEmbedder
            [..., N_tokens, c_s_inputs]
        s_trunk (torch.Tensor): single feature embedding from PairFormer (Alg17)
            [..., N_tokens, c_s]
        z_trunk (torch.Tensor): pair feature embedding from PairFormer (Alg17)
            [..., N_tokens, N_tokens, c_z]
        N_sample (int): number of training samples
        diffusion_chunk_size (Optional[int]): chunk size over the N_sample
            dimension. Defaults to None.

    Returns:
        tuple[torch.Tensor, ...]: the augmented ground-truth coordinates and the
            denoised coordinates (each [..., N_sample, N_atom, 3]), and the
            sampled noise levels sigma [..., N_sample].
    """
    batch_size_shape = label_dict["coordinate"].shape[:-2]
    device = label_dict["coordinate"].device
    dtype = label_dict["coordinate"].dtype
    # Create N_sample versions of the input structure by randomly rotating and translating
    x_gt_augment = centre_random_augmentation(
        x_input_coords=label_dict["coordinate"],
        N_sample=N_sample,
        mask=label_dict["coordinate_mask"],
    ).to(
        dtype
    )  # [..., N_sample, N_atom, 3]

    # Add independent noise to each structure
    # sigma: independent noise levels [..., N_sample]
    sigma = noise_sampler(size=(*batch_size_shape, N_sample), device=device).to(dtype)
    # noise: [..., N_sample, N_atom, 3]
    noise = torch.randn_like(x_gt_augment, dtype=dtype) * sigma[..., None, None]

    # Get denoising outputs [..., N_sample, N_atom, 3]
    if diffusion_chunk_size is None:
        x_denoised = denoise_net(
            x_noisy=x_gt_augment + noise,
            t_hat_noise_level=sigma,
            input_feature_dict=input_feature_dict,
            s_inputs=s_inputs,
            s_trunk=s_trunk,
            z_trunk=z_trunk,
        )
    else:
        x_denoised = []
        no_chunks = N_sample // diffusion_chunk_size + (
            N_sample % diffusion_chunk_size != 0
        )
        for i in range(no_chunks):
            x_noisy_i = (x_gt_augment + noise)[
                ..., i * diffusion_chunk_size : (i + 1) * diffusion_chunk_size, :, :
            ]
            t_hat_noise_level_i = sigma[
                ..., i * diffusion_chunk_size : (i + 1) * diffusion_chunk_size
            ]
            x_denoised_i = denoise_net(
                x_noisy=x_noisy_i,
                t_hat_noise_level=t_hat_noise_level_i,
                input_feature_dict=input_feature_dict,
                s_inputs=s_inputs,
                s_trunk=s_trunk,
                z_trunk=z_trunk,
            )
            x_denoised.append(x_denoised_i)
        x_denoised = torch.cat(x_denoised, dim=-3)

    return x_gt_augment, x_denoised, sigma
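
# Usage sketch (illustrative; the weighting below is the standard EDM loss
# weight, an assumption here rather than the loss used by this codebase):
#
#     x_gt, x_hat, sigma = sample_diffusion_training(
#         noise_sampler, net, label_dict, feat, s_inputs, s_trunk, z_trunk,
#         N_sample=48,
#     )
#     w = (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2   # [..., N_sample]
#     loss = (w[..., None, None] * (x_hat - x_gt) ** 2).mean()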

def structure_predictor(
    coordinate: torch.Tensor,
    denoise_net: Callable,
    label_dict: dict[str, Any],
    input_feature_dict: dict[str, Any],
    s_inputs: torch.Tensor,
    s_trunk: torch.Tensor,
    z_trunk: torch.Tensor,
    N_sample: int = 1,
    diffusion_chunk_size: Optional[int] = None,
) -> tuple[torch.Tensor, ...]:
    """Runs the denoising network once over the given coordinates at a fixed
    noise level (sigma = sigma_data = 16). Unlike sample_diffusion, no noise
    schedule is iterated.

    Args:
        coordinate (torch.Tensor): input coordinates to denoise
            [..., N_sample, N_atom, 3]
        denoise_net (Callable): the network that performs the denoising step.
        label_dict (dict[str, Any]): ground-truth labels; accepted but not used
            in this function.
        input_feature_dict (dict[str, Any]): input meta feature dict
        s_inputs (torch.Tensor): single embedding from InputFeatureEmbedder
            [..., N_tokens, c_s_inputs]
        s_trunk (torch.Tensor): single feature embedding from PairFormer (Alg17)
            [..., N_tokens, c_s]
        z_trunk (torch.Tensor): pair feature embedding from PairFormer (Alg17)
            [..., N_tokens, N_tokens, c_z]
        N_sample (int): number of samples (the size of the -3 dim of coordinate)
        diffusion_chunk_size (Optional[int]): chunk size over the N_sample
            dimension. Defaults to None.

    Returns:
        tuple[torch.Tensor, ...]: the denoised coordinates
            [..., N_sample, N_atom, 3] and the fixed noise levels sigma
            [..., N_sample].
    """
    batch_size_shape = coordinate.shape[:-3]
    device = coordinate.device
    dtype = coordinate.dtype
    # Use a fixed noise level sigma = sigma_data = 16 for every sample
    sigma = torch.ones(size=(*batch_size_shape, N_sample), device=device).to(dtype)
    sigma *= 16

    if diffusion_chunk_size is None:
        x_denoised = denoise_net(
            x_noisy=coordinate,
            t_hat_noise_level=sigma,
            input_feature_dict=input_feature_dict,
            s_inputs=s_inputs,
            s_trunk=s_trunk,
            z_trunk=z_trunk,
        )
    else:
        x_denoised = []
        no_chunks = N_sample // diffusion_chunk_size + (
            N_sample % diffusion_chunk_size != 0
        )
        for i in range(no_chunks):
            x_noisy_i = coordinate[
                ..., i * diffusion_chunk_size : (i + 1) * diffusion_chunk_size, :, :
            ]
            t_hat_noise_level_i = sigma[
                ..., i * diffusion_chunk_size : (i + 1) * diffusion_chunk_size
            ]
            x_denoised_i = denoise_net(
                x_noisy=x_noisy_i,
                t_hat_noise_level=t_hat_noise_level_i,
                input_feature_dict=input_feature_dict,
                s_inputs=s_inputs,
                s_trunk=s_trunk,
                z_trunk=z_trunk,
            )
            x_denoised.append(x_denoised_i)
        x_denoised = torch.cat(x_denoised, dim=-3)

    return x_denoised, sigma
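
# Usage sketch (illustrative; `net`, `feat`, and `coords` are assumptions):
#
#     x_hat, sigma = structure_predictor(
#         coordinate=coords,          # [..., N_sample, N_atom, 3]
#         denoise_net=net,
#         label_dict=label_dict,      # accepted but not read
#         input_feature_dict=feat,
#         s_inputs=s_inputs, s_trunk=s_trunk, z_trunk=z_trunk,
#         N_sample=coords.shape[-3],
#     )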

def watermark_decoder(
    coordinate: torch.Tensor,
    denoise_net: Callable,
    input_feature_dict: dict[str, Any],
    s_inputs: torch.Tensor,
    s_trunk: torch.Tensor,
    z_trunk: torch.Tensor,
    N_sample: int = 1,
    diffusion_chunk_size: Optional[int] = None,
) -> tuple[torch.Tensor, ...]:
    """Runs the denoising network once over the given coordinates at a fixed
    noise level (sigma = sigma_data = 16) and returns its raw output, a_token,
    for watermark decoding, alongside the unmodified input coordinates.

    Args:
        coordinate (torch.Tensor): input coordinates
            [..., N_sample, N_atom, 3]
        denoise_net (Callable): the network that performs the denoising step.
        input_feature_dict (dict[str, Any]): input meta feature dict
        s_inputs (torch.Tensor): single embedding from InputFeatureEmbedder
            [..., N_tokens, c_s_inputs]
        s_trunk (torch.Tensor): single feature embedding from PairFormer (Alg17)
            [..., N_tokens, c_s]
        z_trunk (torch.Tensor): pair feature embedding from PairFormer (Alg17)
            [..., N_tokens, N_tokens, c_z]
        N_sample (int): number of samples (the size of the -3 dim of coordinate)
        diffusion_chunk_size (Optional[int]): chunk size over the N_sample
            dimension. Defaults to None.

    Returns:
        tuple[torch.Tensor, ...]: the unmodified input coordinates
            [..., N_sample, N_atom, 3], the network outputs a_token, and the
            fixed noise levels sigma [..., N_sample].
    """
    batch_size_shape = coordinate.shape[:-3]
    device = coordinate.device
    dtype = coordinate.dtype
    # Use a fixed noise level sigma = sigma_data = 16 for every sample
    sigma = torch.ones(size=(*batch_size_shape, N_sample), device=device).to(dtype)
    sigma *= 16

    if diffusion_chunk_size is None:
        a_token = denoise_net(
            x_noisy=coordinate,
            t_hat_noise_level=sigma,
            input_feature_dict=input_feature_dict,
            s_inputs=s_inputs,
            s_trunk=s_trunk,
            z_trunk=z_trunk,
        )
    else:
        a_token = []
        no_chunks = N_sample // diffusion_chunk_size + (
            N_sample % diffusion_chunk_size != 0
        )
        for i in range(no_chunks):
            x_noisy_i = coordinate[
                ..., i * diffusion_chunk_size : (i + 1) * diffusion_chunk_size, :, :
            ]
            t_hat_noise_level_i = sigma[
                ..., i * diffusion_chunk_size : (i + 1) * diffusion_chunk_size
            ]
            a_token_i = denoise_net(
                x_noisy=x_noisy_i,
                t_hat_noise_level=t_hat_noise_level_i,
                input_feature_dict=input_feature_dict,
                s_inputs=s_inputs,
                s_trunk=s_trunk,
                z_trunk=z_trunk,
            )
            a_token.append(a_token_i)
        a_token = torch.cat(a_token, dim=-3)

    return coordinate, a_token, sigma
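
# Usage sketch (illustrative; how a_token is consumed downstream is an
# assumption, and the decoder head named below is hypothetical):
#
#     coords, a_token, sigma = watermark_decoder(
#         coordinate=coords, denoise_net=net, input_feature_dict=feat,
#         s_inputs=s_inputs, s_trunk=s_trunk, z_trunk=z_trunk,
#         N_sample=coords.shape[-3],
#     )
#     bits = watermark_head(a_token)   # hypothetical per-token decoding head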