# Copyright Generate Biomedicines, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
By contrast, our model learns to reverse a correlated noise process to match the distance statistics of natural proteins,
which have scaling laws that are well understood from biophysics
"""
"""Layers for perturbing protein structure with noise.
This module contains pytorch layers for perturbing protein structure with noise,
which can be useful both for data augmentation, benchmarking, or denoising based
training.
"""
from typing import Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad
from tqdm.auto import tqdm
from chroma.constants import AA20
from chroma.data.xcs import validate_XC
from chroma.layers import basic, sde
from chroma.layers.structure import backbone, hbonds, mvn, rmsd
## Gaussian noise schedule
class GaussianNoiseSchedule:
"""
A general noise schedule for the General Gaussian Forward Path, where noise is added
to the input signal.
The noise is modeled as Gaussian noise with mean `alpha_t x_0` and variance
    `sigma_t^2`, with `x_0 ~ p(x_0)`. The time range of the noise schedule is
parameterized with a user-specified logarithmic signal-to-noise ratio (SNR) range,
where `snr_t = alpha_t^2 / sigma_t^2` is the SNR at time `t`.
In addition, the object defines a quantity called the scaled signal-to-noise ratio
(`ssnr_t`), which is given by `ssnr_t = alpha_t^2 / (alpha_t^2 + sigma_t^2)`
and is a helpful quantity for analyzing the performance of signal processing
algorithms under different noise conditions.
    This object implements a few standard noise schedules:
'log_snr': variance-preserving process with linear log SNR schedule
(https://arxiv.org/abs/2107.00630)
'ot_linear': OT schedule
(https://arxiv.org/abs/2210.02747)
        've_log_snr': variance-exploding process with linear log SNR schedule
(https://arxiv.org/abs/2011.13456 with log SNR noise schedule)
    Users can also implement their own schedules by specifying alpha_func, sigma_func
and compute_t_range.
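    Example (a minimal usage sketch; exact values depend on the chosen schedule
    and log SNR range):
        schedule = GaussianNoiseSchedule(kind="log_snr", log_snr_range=(-7.0, 13.5))
        t = torch.tensor([0.25, 0.5, 0.75])
        alpha, sigma = schedule.alpha(t), schedule.sigma(t)
        # For the variance-preserving 'log_snr' schedule, alpha^2 + sigma^2 = 1 and
        # log SNR interpolates linearly between l_max (t = 0) and l_min (t = 1).
        log_snr = schedule.log_SNR(t)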
"""
def __init__(
self, log_snr_range: Tuple[float, float] = (-7.0, 13.5), kind: str = "log_snr",
) -> None:
super().__init__()
if kind not in ["log_snr", "ot_linear", "ve_log_snr"]:
raise NotImplementedError(
f"noise type {kind} is not implemented, only"
" log_snr and ot_linear are supported "
)
self.kind = kind
self.log_snr_range = log_snr_range
l_min, l_max = self.log_snr_range
# map t \in [0, 1] to match the prescribed log_snr range
self.t_max = self.compute_t_range(l_min)
self.t_min = self.compute_t_range(l_max)
self._eps = 1e-5
def t_map(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
"""map t in [0, 1] to [t_min, t_max]
Args:
t (Union[float, torch.Tensor]): time
Returns:
torch.Tensor: mapped time
"""
if not isinstance(t, torch.Tensor):
t = torch.Tensor([t]).float()
t_max = self.t_max.to(t.device)
t_min = self.t_min.to(t.device)
t_tilde = t_min + (t_max - t_min) * t
return t_tilde
def derivative(self, t: torch.Tensor, func: Callable) -> torch.Tensor:
"""compute derivative of a function, it supports bached single variable inputs
Args:
t (torch.Tensor): time variable at which derivatives are taken
func (Callable): function for derivative calculation
Returns:
torch.Tensor: derivative that is detached from the computational graph
"""
with torch.enable_grad():
t.requires_grad_(True)
derivative = grad(func(t).sum(), t, create_graph=False)[0].detach()
t.requires_grad_(False)
return derivative
def tensor_check(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
"""convert input to torch.Tensor if it is a float
Args:
t ( Union[float, torch.Tensor]): input
Returns:
torch.Tensor: converted torch.Tensor
"""
if not isinstance(t, torch.Tensor):
t = torch.Tensor([t]).float()
return t
def alpha_func(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
"""alpha function that scales the mean, usually goes from 1. to 0.
Args:
t (Union[float, torch.Tensor]): time in [0, 1]
Returns:
torch.Tensor: alpha value
"""
t = self.tensor_check(t)
if self.kind == "log_snr":
l_min, l_max = self.log_snr_range
t_min, t_max = self.t_min, self.t_max
log_snr = (1 - t) * l_max + t * l_min
log_alpha = 0.5 * (log_snr - F.softplus(log_snr))
alpha = log_alpha.exp()
return alpha
elif self.kind == "ve_log_snr":
return 1 - torch.relu(-t) # make this differentiable
elif self.kind == "ot_linear":
return 1 - t
def sigma_func(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
"""sigma function that scales the standard deviation, usually goes from 0. to 1.
Args:
t (Union[float, torch.Tensor]): time in [0, 1]
Returns:
torch.Tensor: sigma value
"""
t = self.tensor_check(t)
l_min, l_max = self.log_snr_range
if self.kind == "log_snr":
alpha = self.alpha(t)
return (1 - alpha.pow(2)).sqrt()
elif self.kind == "ve_log_snr":
# compute sigma value given snr range
l_min, l_max = self.log_snr_range
t_min, t_max = self.t_min, self.t_max
log_snr = (1 - t) * l_max + t * l_min
return torch.exp(-log_snr / 2)
elif self.kind == "ot_linear":
return t
def alpha(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
"""compute alpha value for the mapped time in [t_min, t_max]
Args:
t (Union[float, torch.Tensor]): time in [0, 1]
Returns:
torch.Tensor: alpha value
"""
return self.alpha_func(self.t_map(t))
def sigma(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
"""compute sigma value for mapped time in [t_min, t_max]
Args:
t (Union[float, torch.Tensor]): time in [0, 1]
Returns:
torch.Tensor: sigma value
"""
return self.sigma_func(self.t_map(t))
def alpha_deriv(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
"""compute alpha derivative for mapped time in [t_min, t_max]
Args:
t (Union[float, torch.Tensor]): time in [0, 1]
Returns:
torch.Tensor: time derivative of alpha_func
"""
t_tilde = self.t_map(t)
alpha_deriv_t = self.derivative(t_tilde, self.alpha_func).detach()
return alpha_deriv_t
def sigma_deriv(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
"""compute sigma derivative for the mapped time in [t_min, t_max]
Args:
t (Union[float, torch.Tensor]): time in [0, 1]
Returns:
torch.Tensor: sigma derivative
"""
t_tilde = self.t_map(t)
sigma_deriv_t = self.derivative(t_tilde, self.sigma_func).detach()
return sigma_deriv_t
def beta(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
"""compute the drift coefficient for the OU process of the form
$dx = -\frac{1}{2} \beta(t) x dt + g(t) dw_t$
Args:
t (Union[float, torch.Tensor]): t in [0, 1]
Returns:
torch.Tensor: beta(t)
"""
# t = self.t_map(t)
        alpha = self.alpha(t).detach()
        alpha_deriv_t = self.alpha_deriv(t)
        beta = -2.0 * alpha_deriv_t / alpha
return beta
def g(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
"""compute drift coefficient for the OU process:
$dx = -\frac{1}{2} \beta(t) x dt + g(t) dw_t$
Args:
t (Union[float, torch.Tensor]): t in [0, 1]
Returns:
torch.Tensor: g(t)
"""
if self.kind == "log_snr":
t = self.t_map(t)
g = self.beta(t).sqrt()
else:
alpha_deriv = self.alpha_deriv(t)
alpha_prime_div_alpha = alpha_deriv / self.alpha(t)
sigma_deriv = self.sigma_deriv(t)
sigma_prime_div_sigma = sigma_deriv / self.sigma(t)
g_sq = (
2
* (sigma_deriv - alpha_prime_div_alpha * self.sigma(t))
* self.sigma(t)
)
g = g_sq.sqrt()
return g
def SNR(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
"""Signal-to-Noise(SNR) ratio mapped in the allowed log_SNR range
Args:
t (Union[float, torch.Tensor]): time in [0, 1]
Returns:
torch.Tensor: SNR value
"""
t = self.tensor_check(t)
if self.kind == "log_snr":
SNR = self.log_SNR(t).exp()
else:
SNR = self.alpha(t).pow(2) / (self.sigma(t).pow(2))
return SNR
def log_SNR(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
"""log SNR value
Args:
t (Union[float, torch.Tensor]): time in [0, 1]
Returns:
torch.Tensor: log SNR value
"""
t = self.tensor_check(t)
if self.kind == "log_snr":
l_min, l_max = self.log_snr_range
log_snr = (1 - t) * l_max + t * l_min
elif self.kind == "ot_linear":
log_snr = self.SNR(t).log()
return log_snr
def compute_t_range(self, log_snr: Union[float, torch.Tensor]) -> torch.Tensor:
"""Given log(SNR) range : l_max, l_min to compute the time range.
Hand-derivation is required for specific noise schedules.
This function is essentially the inverse of logSNR(t)
Args:
log_snr (Union[float, torch.Tensor]): logSNR value
Returns:
torch.Tensor: the inverse logSNR
"""
log_snr = self.tensor_check(log_snr)
l_min, l_max = self.log_snr_range
if self.kind == "log_snr":
t = (1 / (l_min - l_max)) * (log_snr - l_max)
elif self.kind == "ot_linear":
t = ((0.5 * log_snr).exp() + 1).reciprocal()
elif self.kind == "ve_log_snr":
t = (1 / (l_min - l_max)) * (log_snr - l_max)
return t
def SNR_derivative(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
"""the derivative of SNR(t)
Args:
t (Union[float, torch.Tensor]): t in [0, 1]
Returns:
torch.Tensor: SNR derivative
"""
t = self.tensor_check(t)
if self.kind == "log_snr":
snr_deriv = self.SNR(t) * (self.log_snr_range[0] - self.log_snr_range[1])
elif self.kind == "ot_linear":
snr_deriv = self.derivative(t, self.SNR)
return snr_deriv
def SSNR(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
"""Signal to Signal+Noise Ratio (SSNR) = alpha^2 / (alpha^2 + sigma^2)
        SSNR decreases monotonically from 1 to 0 as t goes from 0 to 1.
Args:
t (Union[float, torch.Tensor]): time in [0, 1]
Returns:
torch.Tensor: SSNR value
"""
t = self.tensor_check(t)
return self.SNR(t) / (self.SNR(t) + 1)
def SSNR_inv(self, ssnr: torch.Tensor) -> torch.Tensor:
"""the inverse of SSNR
Args:
ssnr (torch.Tensor): ssnr in [0, 1]
Returns:
torch.Tensor: time in [0, 1]
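        Example (a small round-trip sketch for the 'log_snr' schedule; other kinds
        follow the same pattern up to numerical precision):
            schedule = GaussianNoiseSchedule(kind="log_snr")
            t = torch.tensor([0.1, 0.5, 0.9])
            # SSNR_inv inverts SSNR, so the round trip recovers t
            t_recovered = schedule.SSNR_inv(schedule.SSNR(t))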
"""
l_min, l_max = self.log_snr_range
if self.kind == "log_snr":
return ((ssnr / (1 - ssnr)).log() - l_max) / (l_min - l_max)
elif self.kind == "ot_linear":
            # the value of SSNR_inv at ssnr = 0.5 needs to be determined with L'Hôpital's rule
            # the inverse SSNR function is solved analytically:
            # see Wolfram Alpha result: https://tinyurl.com/bdh4es5a
singularity_check = (ssnr - 0.5).abs() < self._eps
ssnr_mask = singularity_check.float()
ssnr = ssnr_mask * (0.5 + self._eps) + (1.0 - ssnr_mask) * ssnr
return (ssnr + (-ssnr * (ssnr - 1)).sqrt() - 1) / (2 * ssnr - 1)
def SSNR_inv_deriv(self, ssnr: Union[float, torch.Tensor]) -> torch.Tensor:
"""SSNR_inv derivative. SSNR_inv is a CDF like quantity, so its derivative is a PDF-like quantity
Args:
ssnr (Union[float, torch.Tensor]): SSNR in [0, 1]
Returns:
torch.Tensor: derivative of SSNR
"""
ssnr = self.tensor_check(ssnr)
deriv = self.derivative(ssnr, self.SSNR_inv)
return deriv
def prob_SSNR(self, ssnr: Union[float, torch.Tensor]) -> torch.Tensor:
"""compute prob (SSNR(t)), the minus sign is accounted for the inversion of integration range
Args:
ssnr (Union[float, torch.Tensor]): SSNR value
Returns:
torch.Tensor: Prob(SSNR)
"""
return -self.SSNR_inv_deriv(ssnr)
def linear_logsnr_grid(self, N: int, tspan: Tuple[float, float]) -> torch.Tensor:
"""Map uniform time grid to respect logSNR schedule
Args:
N (int): number of steps
tspan (Tuple[float, float]): time span (t_start, t_end)
Returns:
torch.Tensor: time grid as torch.Tensor
"""
logsnr_noise = GaussianNoiseSchedule(
kind="log_snr", log_snr_range=self.log_snr_range
)
ts = torch.linspace(tspan[0], tspan[1], N + 1)
SSNR_vp = logsnr_noise.SSNR(ts)
grid = self.SSNR_inv(SSNR_vp)
# map from t_tilde back to t
grid = (grid - self.t_min) / (self.t_max - self.t_min)
return grid
## Noise time embedding layer
class NoiseTimeEmbedding(nn.Module):
"""
A class that implements a noise time embedding layer.
Args:
dim_embedding (int): The dimension of the output embedding vector.
noise_schedule (GaussianNoiseSchedule): A GaussianNoiseSchedule object that
defines the noise schedule function.
rff_scale (float, optional): The scaling factor for the random Fourier features.
Default is 0.8.
feature_type (str, optional): The type of feature to use for the time embedding.
Either "t" or "log_snr". Default is "log_snr".
Inputs:
t (float): time in (1.0, 0.0).
log_alpha (torch.Tensor, optional): A tensor of log alpha values with
shape `(batch_size,)`.
Outputs:
time_h (torch.Tensor): A tensor of noise time embeddings with shape
`(batch_size, dim_embedding)`.
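    Example (a minimal sketch; `dim_embedding=128` and the schedule settings are
    placeholder choices, not values prescribed by this module):
        schedule = GaussianNoiseSchedule(kind="log_snr")
        embed = NoiseTimeEmbedding(dim_embedding=128, noise_schedule=schedule)
        t = torch.rand(8)     # one diffusion time per batch element
        time_h = embed(t)     # random Fourier features of log SNR(t)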
"""
def __init__(
self,
dim_embedding: int,
noise_schedule: GaussianNoiseSchedule,
rff_scale: float = 0.8,
feature_type: str = "log_snr",
) -> None:
super(NoiseTimeEmbedding, self).__init__()
self.noise_schedule = noise_schedule
self.feature_type = feature_type
self.fourier_features = basic.FourierFeaturization(
d_input=1, d_model=dim_embedding, trainable=False, scale=rff_scale
)
def forward(
self, t: torch.Tensor, log_alpha: Optional[torch.Tensor] = None
) -> torch.Tensor:
if not isinstance(t, torch.Tensor):
t = torch.Tensor([t]).float().to(self.fourier_features.B.device)
if t.dim() == 0:
t = t[None]
h = {"t": lambda: t, "log_snr": lambda: self.noise_schedule.log_SNR(t)}[
self.feature_type
]()
time_h = self.fourier_features(h[:, None, None])
return time_h
## Diffusion
class DiffusionChainCov(nn.Module):
def __init__(
self,
log_snr_range: Tuple[float, float] = (-7.0, 13.5),
noise_schedule: str = "log_snr",
sigma_translation: float = 1.0,
covariance_model: str = "brownian",
complex_scaling: bool = False,
**kwargs,
) -> None:
"""Diffusion backbone noise, with chain-structured covariance.
This class implements a diffusion backbone noise model. The model uses a
chain-structured covariance matrix capturing the spatial correlations between
residues along the backbone. The model also supports different noise schedules
and integration schemes for the stochastic differential equation (SDE) that
        defines the diffusion process. This class also implements various inference
        algorithms by reversing the forward diffusion with a user-specified
        conditioner program.
Args:
log_snr_range (tuple, optional): log SNR range. Defaults to (-7.0, 13.5).
noise_schedule (str, optional): noise schedule type. Defaults to "log_snr".
sigma_translation (float, optional): Scaling factor for the translation
component of the covariance matrix. Defaults to 1.0.
            covariance_model (str, optional): covariance model. Defaults to "brownian".
complex_scaling (bool, optional): Whether to scale the complex component
of the covariance matrix by the translation component. Defaults to False.
**kwargs: Additional arguments for the base Gaussian distribution and
the SDE integration.
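        Example (a minimal noising sketch; the random single-chain tensors X and C
        below are placeholders, not data shipped with this module):
            diffusion = DiffusionChainCov(covariance_model="brownian")
            X = torch.randn(1, 64, 4, 3)      # backbone N, CA, C, O coordinates
            C = torch.ones(1, 64).long()      # a single chain of 64 residues
            X_noised = diffusion(X, C, t=0.5) # sample from the forward marginal at t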
"""
super().__init__()
self.noise_schedule = GaussianNoiseSchedule(
log_snr_range=log_snr_range, kind=noise_schedule,
)
if covariance_model in ["brownian", "globular"]:
self.base_gaussian = mvn.BackboneMVNGlobular(
sigma_translation=sigma_translation,
covariance_model=covariance_model,
complex_scaling=complex_scaling,
)
elif covariance_model == "residue_gas":
self.base_gaussian = mvn.BackboneMVNResidueGas()
self.loss_rmsd = rmsd.BackboneRMSD()
self._eps = 1e-5
self.sde_funcs = {
"langevin": self.langevin,
"reverse_sde": self.reverse_sde,
"ode": self.ode,
}
self.integrate_funcs = {
"euler_maruyama": sde.sde_integrate,
"heun": sde.sde_integrate_heun,
}
def sample_t(
self,
C: torch.LongTensor,
t: Optional[torch.Tensor] = None,
inverse_CDF: Optional[Callable] = None,
) -> torch.Tensor:
"""Sample a random time index for each batch element
Inputs:
C (torch.LongTensor): Chain tensor with shape `(batch_size, num_residues)`.
t (torch.Tensor, optional): Time index with shape `(batch_size,)`.
If not given, a random time index will be sampled. Defaults to None.
Outputs:
t (float): Time index with shape `(batch_size,)`.
"""
if t is not None:
if not isinstance(t, torch.Tensor):
t = torch.Tensor([t]).float()
return t
num_batch = C.size(0)
if self.training:
# Sample correlated but marginally uniform t
# for variance reduction (Kingma et al 2021)
u = torch.rand([])
ix = torch.arange(num_batch) / num_batch
t = torch.remainder(u + ix, 1)
else:
t = torch.rand([num_batch])
if inverse_CDF is not None:
t = inverse_CDF(t)
t = t.to(C.device)
return t
def sde_forward(self, X, C, t, Z=None):
"""Sample an Euler-Maruyama step on forwards SDE.
That is to say, Euler-Maruyama integration would
correspond to the update.
`X_new = X + dt * f + sqrt(dt) * gZ`
Args:
Returns:
f (Tensor): Drift term with shape `()`.
gZ (Tensor): Diffusion term with shape `()`.
"""
# Sample random perturbation
if Z is None:
Z = torch.randn_like(X)
Z = Z.reshape(X.shape[0], -1, 3)
R_Z = self.base_gaussian._multiply_R(Z, C).reshape(X.shape)
X = backbone.center_X(X, C)
beta = self.noise_schedule.beta(t)
f = -beta * X / 2.0
gZ = self.noise_schedule.g(t)[:, None, None] * R_Z
return f, gZ
def _schedule_coefficients(
self,
t: torch.Tensor,
inverse_temperature: float = 1.0,
langevin_isothermal: bool = True,
) -> Tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
]:
"""
        A method that computes the schedule coefficients for reverse-time sampling.
Args:
t (float): time in (1.0, 0.0).
inverse_temperature (float, optional): The inverse temperature parameter for
                the Langevin dynamics. Default is 1.0.
langevin_isothermal (bool, optional): A flag that indicates whether to use
isothermal or non-isothermal Langevin dynamics. Default is True.
Returns:
alpha (torch.Tensor): A tensor of alpha values with shape `(batch_size, 1, 1)`.
sigma (torch.Tensor): A tensor of sigma values with shape `(batch_size, 1, 1)`.
beta (torch.Tensor): A tensor of beta values with shape `(batch_size, 1, 1)`.
g (torch.Tensor): A tensor of g values with shape `(batch_size, 1, 1)`.
            lambda_t (torch.Tensor): A tensor of lambda_t values with shape `(batch_size, 1, 1)`.
lambda_langevin (torch.Tensor): A tensor of lambda_langevin values with
shape `(batch_size, 1, 1)`.
"""
        # Schedule coefficients
alpha = self.noise_schedule.alpha(t)[:, None, None].to(t.device)
sigma = self.noise_schedule.sigma(t)[:, None, None].to(t.device)
beta = self.noise_schedule.beta(t)[:, None, None].to(t.device)
g = self.noise_schedule.g(t)[:, None, None].to(t.device)
# Temperature coefficients
lambda_t = (
inverse_temperature
* (sigma.pow(2) + alpha.pow(2))
/ (inverse_temperature * sigma.pow(2) + alpha.pow(2))
)
lambda_langevin = inverse_temperature if langevin_isothermal else lambda_t
return alpha, sigma, beta, g, lambda_t, lambda_langevin
@validate_XC()
def langevin(
self,
X: torch.Tensor,
X0_func: Callable,
C: torch.LongTensor,
t: Union[torch.Tensor, float],
conditioner: Callable = None,
Z: Union[torch.Tensor, None] = None,
inverse_temperature: float = 1.0,
langevin_factor: float = 0.0,
langevin_isothermal: bool = True,
align_X0: bool = True,
):
"""Return the drift and diffusion components of the Langevin dynamics for the
reverse process
Args:
X (torch.Tensor): A tensor of protein backbone structure with shape
`(batch_size, num_residues, 4, 3)`.
            X0_func (Callable): A denoising function for protein backbone
                geometry.
C (torch.LongTensor): A chain map tensor with shape `(batch_size, num_residues)`.
t (float): time in (1.0, 0.0).
            conditioner (Callable, optional): A conditioner that performs constrained
transformation (see examples in chroma.layers.structure.conditioners).
Z (torch.Tensor, optional): A tensor of random noise with
shape `(batch_size, num_residues, 4, 3)`. Default is None.
inverse_temperature (float, optional): The inverse temperature parameter
for the Langevin dynamics. Default is 1.0.
langevin_factor (float, optional): The scaling factor for the Langevin noise.
                Default is 0.0.
langevin_isothermal (bool, optional): A flag that indicates whether to use
isothermal or non-isothermal Langevin dynamics. Default is True.
align_X0 (bool, optional): A flag that indicates whether to align the noised
X and denoised X for score function calculation.
Returns:
f (torch.Tensor): A tensor of drift terms with shape
`(batch_size, num_residues, 4, 3)`.
gZ (torch.Tensor): A tensor of diffusion terms with shape
`(batch_size, num_residues, 4, 3)`.
"""
alpha, sigma, beta, g, lambda_t, lambda_langevin = self._schedule_coefficients(
t,
inverse_temperature=inverse_temperature,
langevin_isothermal=langevin_isothermal,
)
Z = torch.randn_like(X) if Z is None else Z
score = self.score(X, X0_func, C, t, conditioner, align_X0=align_X0)
score_transformed = self.base_gaussian.multiply_covariance(score, C)
f = -g.pow(2) * lambda_langevin * langevin_factor / 2.0 * score_transformed
gZ = g * np.sqrt(langevin_factor) * self.base_gaussian._multiply_R(Z, C)
return f, gZ
@validate_XC()
def reverse_sde(
self,
X: torch.Tensor,
X0_func: Callable,
C: torch.LongTensor,
t: Union[torch.Tensor, float],
conditioner: Callable = None,
Z: Union[torch.Tensor, None] = None,
inverse_temperature: float = 1.0,
langevin_factor: float = 0.0,
langevin_isothermal: bool = True,
align_X0: bool = True,
):
"""Return the drift and diffusion components of the reverse SDE.
Args:
X (torch.Tensor): A tensor of protein backbone structure with shape
`(batch_size, num_residues, 4, 3)`.
            X0_func (Callable): A denoising function for the protein backbone
geometry.
C (torch.LongTensor): A tensor of condition features with shape
`(batch_size, num_residues)`.
t (float): time in (1.0, 0.0).
            conditioner (Callable, optional): A conditioner that performs constrained
transformation (see examples in chroma.layers.structure.conditioners).
Z (torch.Tensor, optional): A tensor of random noise with shape
`(batch_size, num_residues, 4, 3)`. Default is None.
inverse_temperature (float, optional): The inverse temperature parameter
for the Langevin dynamics. Default is 1.0.
langevin_factor (float, optional): The scaling factor for the Langevin noise.
Default is 0.0.
langevin_isothermal (bool, optional): A flag that indicates whether to use
isothermal or non-isothermal Langevin dynamics. Default is True.
align_X0 (bool, optional): A flag that indicates whether to align the noised
X and denoised X for score function calculation.
Returns:
f (torch.Tensor): A tensor of drift terms with shape
`(batch_size, num_residues, 4, 3)`.
gZ (torch.Tensor): A tensor of diffusion terms with shape
`(batch_size, num_residues, 4, 3)`.
"""
# Schedule management
alpha, sigma, beta, g, lambda_t, lambda_langevin = self._schedule_coefficients(
t,
inverse_temperature=inverse_temperature,
langevin_isothermal=langevin_isothermal,
)
score_scale_t = lambda_t + lambda_langevin * langevin_factor / 2.0
# Impute missing data
Z = torch.randn_like(X) if Z is None else Z
# X = backbone.center_X(X, C)
score = self.score(X, X0_func, C, t, conditioner, align_X0=align_X0)
score_transformed = self.base_gaussian.multiply_covariance(score, C)
f = (
beta * (-1 / 2) * backbone.center_X(X, C)
- g.pow(2) * score_scale_t * score_transformed
)
gZ = g * np.sqrt(1.0 + langevin_factor) * self.base_gaussian._multiply_R(Z, C)
return f, gZ
@validate_XC()
def ode(
self,
X: torch.Tensor,
X0_func: Callable,
C: torch.LongTensor,
t: Union[torch.Tensor, float],
conditioner: Callable = None,
Z: Union[torch.Tensor, None] = None,
inverse_temperature: float = 1.0,
langevin_factor: float = 0.0,
langevin_isothermal: bool = True,
align_X0: bool = True,
detach_X0: bool = True,
):
"""Return the drift and diffusion components of the probability flow ODE.
Args:
X (torch.Tensor): A tensor of protein backbone structure with shape
`(batch_size, num_residues, 4, 3)`.
X0_func (Callable): A denoising function that returns a protein backbone
geometry `(batch_size, num_residues, 4, 3)`.
C (torch.LongTensor): A tensor of condition features with shape
`(batch_size, num_residues)`.
t (float): time in (1.0, 0.0).
            conditioner (Callable, optional): A conditioner that performs constrained
transformation (see examples in chroma.layers.structure.conditioners).
Z (torch.Tensor, optional): A tensor of random noise with shape
`(batch_size, num_residues, 4, 3)`. Default is None.
inverse_temperature (float, optional): The inverse temperature parameter
for the Langevin dynamics. Default is 1.0.
langevin_factor (float, optional): The scaling factor for the Langevin
noise. Default is 0.0.
langevin_isothermal (bool, optional): A flag that indicates whether to use
isothermal or non-isothermal Langevin dynamics. Default is True.
align_X0 (bool, optional): A flag that indicates whether to align
the noised X and denoised X for score function calculation.
Returns:
f (torch.Tensor): A tensor of drift terms with shape
`(batch_size, num_residues, 4, 3)`.
gZ (torch.Tensor): A tensor of diffusion terms with shape
`(batch_size, num_residues, 4, 3)`.
"""
# Schedule management
alpha, sigma, beta, g, lambda_t, lambda_langevin = self._schedule_coefficients(
t,
inverse_temperature=inverse_temperature,
langevin_isothermal=langevin_isothermal,
)
# Impute missing data
X = backbone.center_X(X, C)
score = self.score(
X, X0_func, C, t, conditioner, align_X0=align_X0, detach_X0=detach_X0
)
score_transformed = self.base_gaussian.multiply_covariance(score, C)
f = (-1 / 2) * beta * X - 0.5 * lambda_langevin * g.pow(2) * score_transformed
gZ = torch.zeros_like(f)
return f, gZ
@validate_XC()
def energy(
self,
X: torch.Tensor,
X0_func: Callable,
C: torch.Tensor,
t: torch.Tensor,
detach_X0: bool = True,
align_X0: bool = True,
) -> torch.Tensor:
"""Compute the diffusion energy as a function of denoised X
Args:
X (torch.Tensor): A tensor of protein backbone coordinates with shape
`(batch_size, num_residues, 4, 3)`.
            X0_func (Callable): A denoising function for protein backbone
geometry.
C (torch.LongTensor): A tensor of condition features with shape
`(batch_size, num_residues)`.
t (float): time in (1.0, 0.0).
detach_X0 (bool, optional): A flag that indicates whether to detach the
                denoised X for score function evaluation.
align_X0 (bool, optional): A flag that indicates whether to align the
noised X and denoised X for score function calculation.
Returns:
U_diffusion (torch.Tensor): A tensor of diffusion energy values with
shape `(batch_size,)`.
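        Note (a sketch of the computation, following `_X_to_Z` below): the energy is
        the squared norm of the whitened residual,
            U(X) = 1/2 * || R^{-1} (X - mu_t(X0)) / sigma_t ||^2,
        where mu_t(X0) rescales the denoised structure X0 around its center of mass
        by alpha_t and R is the square-root factor of the chain-structured covariance.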
"""
X = backbone.impute_masked_X(X, C)
alpha = self.noise_schedule.alpha(t).to(X.device)
sigma = self.noise_schedule.sigma(t).to(X.device)
if detach_X0:
with torch.no_grad():
X0 = X0_func(X, C, t=t)
else:
X0 = X0_func(X, C, t=t)
if align_X0:
X0, _ = self.loss_rmsd.align(X0, X, C, align_unmasked=True)
if detach_X0:
X0 = X0.detach()
Z = self._X_to_Z(X, X0, C, alpha, sigma)
U_diffusion = (0.5 * (Z ** 2)).sum([1, 2, 3])
return U_diffusion
@validate_XC()
def score(
self,
X: torch.Tensor,
X0_func: Callable,
C: torch.Tensor,
t: Union[torch.Tensor, float],
conditioner: Callable = None,
detach_X0: bool = True,
align_X0: bool = True,
U_traj: List = [],
) -> torch.Tensor:
"""Compute the score function
Args:
            X (torch.Tensor): A tensor of protein backbone geometry with shape
`(batch_size, num_residues, 4, 3)`.
            X0_func (Callable): A denoising function for protein backbone
geometry.
C (torch.LongTensor): A tensor of chain map with shape
`(batch_size, num_residues)`.
t (Union[torch.Tensor, float]): time in (1.0, 0.0).
            conditioner (Callable, optional): A conditioner that performs constrained
transformation (see examples in chroma.layers.structure.conditioners).
detach_X0 (bool, optional): A flag that indicates whether to detach the
denoised X for score function evaluation
align_X0 (bool, optional): A flag that indicates whether to align the
noised X and denoised X for score function calculation.
U_traj (List, optional): Record diffusion energy as a list.
Returns:
score (torch.Tensor): A tensor of score values with shape
`(batch_size, num_residues, 4, 3)`.
"""
X = backbone.impute_masked_X(X, C)
with torch.enable_grad():
X = X.detach().clone()
X.requires_grad = True
# Apply optional conditioner transformations to state and energy
Xt, Ct, U_conditioner = X, C, 0.0
St = torch.zeros(Ct.shape, device=Xt.device).long()
Ot = F.one_hot(St, len(AA20)).float()
if conditioner is not None:
Xt, Ct, _, U_conditioner, _ = conditioner(X, C, Ot, U_conditioner, t)
U_conditioner = torch.as_tensor(U_conditioner)
# Compute system energy
U_diffusion = self.energy(
Xt, X0_func, Ct, t, detach_X0=detach_X0, align_X0=align_X0
)
U_traj.append(U_diffusion.detach().cpu())
# Compute score function as negative energy gradient
U_total = U_diffusion.sum() + U_conditioner.sum()
U_total.backward()
score = -X.grad
score = score.masked_fill((C <= 0)[..., None, None], 0.0)
return score
def elbo(self, X0_pred, X0, C, t):
"""ITD ELBO as a weighted average of denoising error,
inspired by https://arxiv.org/abs/2302.03792"""
if not isinstance(t, torch.Tensor):
t = torch.Tensor([t]).float().to(X0.device)
# Interpolate missing data with Brownian Bridge posterior
X0 = backbone.impute_masked_X(X0, C)
X0_pred = backbone.impute_masked_X(X0_pred, C)
# Compute whitened residual
dX = (X0 - X0_pred).reshape([X0.shape[0], -1, 3])
R_inv_dX = self.base_gaussian._multiply_R_inverse(dX, C)
# Average per atom, including over "missing" positions that we filled in
weight = 0.5 * self.noise_schedule.SNR_derivative(t)[:, None, None, None]
snr = self.noise_schedule.SNR(t)[:, None, None, None]
loss_itd = (
weight * (R_inv_dX.pow(2) - 1 / (1 + snr))
- 0.5 * np.log(np.pi * 2.0 * np.e)
).reshape(X0.shape)
# Compute average per-atom loss (including over missing regions)
mask = (C != 0).float()
mask_atoms = mask.reshape(mask.shape + (1, 1)).expand([-1, -1, 4, 1])
# Per-complex
elbo_gap = (mask_atoms * loss_itd).sum([1, 2, 3])
logdet = self.base_gaussian.log_determinant(C)
elbo_unnormalized = elbo_gap - logdet
# Normalize per atom
elbo = elbo_unnormalized / (mask_atoms.sum([1, 2, 3]) + self._eps)
# Compute batch average
weights = mask_atoms.sum([1, 2, 3])
elbo_batch = (weights * elbo).sum() / (weights.sum() + self._eps)
return elbo, elbo_batch
def pseudoelbo(self, loss_per_residue, C, t):
"""Compute pseudo-ELBOs as weighted averages of other errors."""
if not isinstance(t, torch.Tensor):
t = torch.Tensor([t]).float().to(C.device)
        # Average per atom, including over "missing" positions that we filled in
weight = 0.5 * self.noise_schedule.SNR_derivative(t)[:, None]
loss = weight * loss_per_residue
# Compute average loss
mask = (C > 0).float()
pseudoelbo = (mask * loss).sum(-1) / (mask.sum(-1) + self._eps)
pseudoelbo_batch = (mask * loss).sum() / (mask.sum() + self._eps)
return pseudoelbo, pseudoelbo_batch
def _baoab_sample_step(
self,
_x,
p,
C,
t,
dt,
score_func,
gamma=2.0,
kT=1.0,
n_equil=1,
ode_boost=True,
langevin_isothermal=False,
):
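        """One BAOAB-discretized Langevin sampling step, optionally followed by a
        probability-flow ODE step ("ode_boost")."""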
gamma = torch.Tensor([gamma]).to(_x.device)
(
alpha,
sigma,
beta,
g,
lambda_t,
lambda_langevin,
) = self._schedule_coefficients(
t, inverse_temperature=1 / kT, langevin_isothermal=langevin_isothermal,
)
def baoab_step(_x, p, t):
Z = torch.randn_like(_x)
c1 = torch.exp(-gamma * dt)
c3 = torch.sqrt((1 / lambda_t) * (1 - c1 ** 2))
# BAOAB scheme
p_half = p + score_func(t, C, _x) * dt / 2 # B
_x_half = (
_x
+ g.pow(2) * self.base_gaussian.multiply_covariance(p_half, C) * dt / 2
) # A
p_half2 = c1 * p_half + c3 * (
1 / g
) * self.base_gaussian._multiply_R_inverse_transpose(
Z, C
) # O
_x = (
_x_half
+ g.pow(2) * self.base_gaussian.multiply_covariance(p_half2, C) * dt / 2
) # A
p = p_half2 + score_func(t, C, _x) * dt / 2 # B
return _x, p
def ode_step(t, _x):
score = score_func(t, C, _x)
score_transformed = self.base_gaussian.multiply_covariance(score, C)
_x = _x + 0.5 * (_x + score_transformed) * g.pow(2) * dt
return _x
for i in range(n_equil):
_x, p = baoab_step(_x, p, t)
if ode_boost:
_x = ode_step(t, _x)
return _x, p
@torch.no_grad()
def sample_sde(
self,
X0_func: Callable,
C: torch.LongTensor,
X_init: Optional[torch.Tensor] = None,
conditioner: Optional[Callable] = None,
N: int = 100,
tspan: Tuple[float, float] = (1.0, 0.001),
inverse_temperature: float = 1.0,
langevin_factor: float = 0.0,
langevin_isothermal: bool = True,
sde_func: str = "reverse_sde",
integrate_func: str = "euler_maruyama",
initialize_noise: bool = True,
remap_time: bool = False,
remove_drift_translate: bool = False,
remove_noise_translate: bool = False,
align_X0: bool = True,
) -> Dict[str, torch.Tensor]:
"""Sample from the SDE using a numerical integration scheme.
This function samples from the stochastic differential equation (SDE) defined
        by the model using a numerical integration scheme such as Euler-Maruyama or
        Heun. The SDE can be either in the forward or reverse direction. The function
also supports optional conditioning on external variables and adding Langevin
noise to the SDE dynamics.
Args:
X0_func (Callable): A denoising function that maps `(X, C, t)` to `X0`.
C (torch.LongTensor): Conditioner tensor with shape `(num_batch,
num_residues)`.
            X_init (torch.Tensor, optional): Initial state tensor with shape
                `(num_batch, num_residues, 4, 3)` or None. If None, the initial state
                is sampled from the noised prior at `tspan[0]`.
conditioner (Callable, optional): A function that transforms X, C, U, t.
If None, no conditioning will be applied.
N (int): Number of integration steps.
tspan (Tuple[float,float]): Time span for integration.
inverse_temperature (float): Inverse temperature parameter for SDE.
langevin_factor (float): Langevin factor for adding noise to SDE.
langevin_isothermal (bool): Whether to use isothermal or adiabatic Langevin
dynamics.
sde_func (str): Which SDE function to use ('reverse_sde', 'langevin' or 'ode').
integrate_func (str): Which integration function to use ('euler_maruyama'
or 'heun').
initialize_noise (bool): Whether to initialize the state with noise.
remap_time (bool): Whether to remap the time grid according to the noise
schedule.
remove_drift_translate (bool): Whether to remove the net translational
component from the drift term.
remove_noise_translate (bool): Whether to remove the net translational
component from the noise term.
align_X0 (bool): Whether to Kabsch align X0 with X before computing SDE terms.
Returns:
outputs (Dict[str, torch.Tensor]): A dictionary of output tensors with the
following keys:
- 'C': The conditioned tensor with shape `(num_batch,num_residues)`.
- 'X_sample': The final sampled state tensor with shape `(num_batch,
num_residues ,4 ,3)`.
- 'X_trajectory': A list of state tensors along the trajectory with
shape `(num_batch,num_residues ,4 ,3)` each.
- 'Xhat_trajectory': A list of transformed state tensors along the
trajectory with shape `(num_batch,num_residues ,4 ,3)` each.
- 'Xunc_trajectory': A list of unconstrained state tensors along the
trajectory with shape `(num_batch,num_residues ,4 ,3)` each.
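        Example (a minimal sketch with an identity denoiser standing in for a trained
        network; real use passes the denoising network's X0 prediction function):
            diffusion = DiffusionChainCov()
            C = torch.ones(1, 64).long()       # a single chain of 64 residues
            X0_func = lambda X, C, t: X        # placeholder denoiser
            out = diffusion.sample_sde(X0_func, C, N=100)
            X_sample = out["X_sample"]         # (1, 64, 4, 3)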
"""
# Setup SDE integration
integrate_func = self.integrate_funcs[integrate_func]
sde_func = self.sde_funcs[sde_func]
T_grid = (
self.noise_schedule.linear_logsnr_grid(N=N, tspan=tspan).to(C.device)
if remap_time
else torch.linspace(tspan[0], tspan[1], N + 1).to(C.device)
)
# Intercept the X0 function for tracking Xt and Xhat
Xhat_trajectory = []
Xt_trajectory = []
U_trajectory = []
def _X0_func(_X, _C, t):
_X0 = X0_func(_X, _C, t)
Xt_trajectory.append(_X.detach())
Xhat_trajectory.append(_X0.detach())
return _X0
def sdefun(_t, _X):
f, gZ = sde_func(
_X,
_X0_func,
C,
_t,
conditioner=conditioner,
inverse_temperature=inverse_temperature,
langevin_factor=langevin_factor,
langevin_isothermal=langevin_isothermal,
align_X0=align_X0,
)
# Remove net translational component
if remove_drift_translate:
f = backbone.center_X(f, C)
if remove_noise_translate:
gZ = backbone.center_X(gZ, C)
return f, gZ
# Initialization
if initialize_noise and X_init is not None:
X_init = self.forward(X_init, C, t=tspan[0]).detach()
elif X_init is None:
X_init = torch.zeros(list(C.shape) + [4, 3], device=C.device)
X_init = self.forward(X_init, C, t=tspan[0]).detach()
# Determine output shape via a test forward pass
if conditioner:
with torch.enable_grad():
X_init_test = X_init.clone()
X_init_test.requires_grad = True
S_test = torch.zeros(C.shape, device=X_init.device).long()
O_test = F.one_hot(S_test, len(AA20)).float()
U_test = 0.0
t_test = torch.tensor([0.0], device=X_init.device)
_, Ct, _, _, _ = conditioner(X_init_test, C, O_test, U_test, t_test)
else:
Ct = C
# Integrate
X_trajectory = integrate_func(sdefun, X_init, tspan, N=N, T_grid=T_grid)
# Return constrained coordinates
outputs = {
"C": Ct,
"X_sample": Xt_trajectory[-1],
"X_trajectory": [Xt_trajectory[-1]] + Xt_trajectory,
"Xhat_trajectory": Xhat_trajectory,
"Xunc_trajectory": X_trajectory,
}
return outputs
@torch.no_grad()
def estimate_pseudoelbo_X(
self,
X0_func,
X,
C,
num_samples=50,
deterministic_seed=0,
return_elbo_t=False,
noise=True,
):
with torch.random.fork_rng():
torch.random.manual_seed(deterministic_seed)
mask = (C > 0).float()
mask_atoms = mask.reshape(list(mask.shape) + [1, 1]).expand([-1, -1, 4, 1])
elbo = []
T = np.linspace(1e-4, 1.0, num_samples)
for t in tqdm(T.tolist()):
X_noise = self.forward(X, C, t=t) if noise else X
X_denoise = X0_func(X_noise, C, t)
elbo_t = -self.noise_schedule.SNR_derivative(t).to(X.device) * (
((mask_atoms * (X_denoise - X) / 10.0) ** 2).sum([1, 2, 3])
/ mask_atoms.sum([1, 2, 3])
)
elbo.append(elbo_t)
elbo = torch.stack(elbo, 0)
if not return_elbo_t:
elbo = elbo.mean(0)
return elbo
def _score_direct(
self, Xt, X0_func, C, t, align_X0=True,
):
        """Compute the score function directly. (Sometimes numerically unstable.)"""
        X0 = X0_func(Xt, C, t)
alpha = self.noise_schedule.alpha(t).to(Xt.device)
sigma = self.noise_schedule.sigma(t).to(Xt.device)
# Impute sensibly behaved values in masked regions for numerical stability
# X0 = backbone.impute_masked_X(X0, C)
Xt = backbone.impute_masked_X(Xt, C)
if align_X0:
X0, _ = self.loss_rmsd.align(X0, Xt, C, align_unmasked=True)
# Compute mean
X_mu = self._mean(X0, C, alpha)
X_mu = backbone.impute_masked_X(X_mu, C)
dX = Xt - X_mu
Ci_dX = self.base_gaussian.multiply_inverse_covariance(dX, C)
score = -Ci_dX / sigma.pow(2)[:, None, None, None]
# Mask
score = score.masked_fill((C <= 0)[..., None, None], 0.0)
return score
def estimate_logp(
self,
X0_func: Callable,
X_sample: torch.Tensor,
C: torch.LongTensor,
N: int,
return_trace_t: bool = False,
):
"""Estimate the model logP for given protein backboones
(num_batch, num_residues, 4, 3) by the Continuous Normalizing Flow formalism
Reference:
https://arxiv.org/abs/1810.01367
https://arxiv.org/abs/1806.07366
Args:
X0_func (Callable): A function that returns the initial protein backboone
(num) features given a condition.
X_sample (torch.Tensor): A tensor of protein backboone (num) features with
shape
`(batch_size, num_residues, 4, 3)`.
C (torch.Tensor): A tensor of condition features with shape `(batch_size,
num_residues)`.
N (int, optional): number of ode integration steps
return_trace_t (bool, optional): A flag that indicates whether to return the
log |df / dx| for each time step for the integrated log Jacobian trance.
Default is False.
Returns:
elbo (torch.Tensor): A tensor of logP value
if return_elbo_t is False, or `(N)` if return_elbo_t
is True.
"""
def divergence(fn, x, t):
"""Calculate Divergance with Stochastic Trace Estimator"""
vec_eps = torch.randn_like(x)
fn_out, eps_J_prod = torch.autograd.functional.vjp(
fn, (t, x), vec_eps, create_graph=False
)
eps_J_eps = (
(eps_J_prod[1] * vec_eps).reshape(x.shape[0], -1).sum(-1).unsqueeze(-1)
)
return fn_out, eps_J_eps
def flow_gradient(
X, X0_func, C, t,
):
"""Compute the time gradient from the probability flow ODE."""
_, _, beta, g, _, _ = self._schedule_coefficients(t)
score = self._score_direct(X, X0_func, C, t)
dXdt = (-1 / 2) * beta * X - 0.5 * g.pow(2) * score
return dXdt
def odefun(_t, _X):
_t = _t.detach()
f = flow_gradient(_X, X0_func, C, _t,)
return f
        # forward integration to noise
X_sample = backbone.center_X(X_sample, C)
X_sample = backbone.impute_masked_X(X_sample, C)
C = C.abs()
out = self.sample_sde(
X0_func=X0_func,
C=C,
X_init=X_sample,
N=N,
sde_func="ode",
tspan=(0, 1.0),
inverse_temperature=1.0,
langevin_factor=0.0,
initialize_noise=False,
align_X0=False,
)
X_flow = out["X_trajectory"][1:]
# get ode function
ddlogp = []
for i, t in enumerate(tqdm(torch.linspace(1e-2, 1.0, len(X_flow)))):
with torch.enable_grad():
dlogP = divergence(odefun, X_flow[i], t[None].to(C.device))[1]
ddlogp.append(dlogP.item())
logp_x1 = self.base_gaussian.log_prob(X_flow[-1], C).item()
if return_trace_t:
return np.array(ddlogp) / ((C > 0).float().sum().item() * 4)
else:
return (logp_x1 + np.array(ddlogp).mean()) / (
(C > 0).float().sum().item() * 4
)
@torch.no_grad()
@validate_XC(all_atom=False)
def estimate_elbo(
self,
X0_func: Callable,
X: torch.Tensor,
C: torch.LongTensor,
num_samples: int = 50,
deterministic_seed: int = 0,
return_elbo_t: bool = False,
grad_logprob_Y_func: Optional[Callable] = None,
) -> torch.Tensor:
"""Estimate the evidence lower bound (ELBO) for given protein backboones
(num_batch, num_residues, 4, 3) and condition.
Args:
            X0_func (Callable): A denoising function that maps `(X, C, t)` to `X0`.
            X (torch.Tensor): A tensor of protein backbone coordinates with shape
                `(batch_size, num_residues, 4, 3)`.
C (torch.Tensor): A tensor of condition features with shape `(batch_size,
num_residues)`.
num_samples (int, optional): The number of time steps to sample for
estimating the ELBO. Default is 50.
deterministic_seed (int, optional): The seed for generating random noise.
Default is 0.
return_elbo_t (bool, optional): A flag that indicates whether to return the
ELBO for each time step or the average ELBO. Default is False.
            grad_logprob_Y_func (Optional[Callable], optional): A function that returns
                the gradient of the log probability of the observed data given a time
                step and a noised structure. Default is None.
Returns:
elbo (torch.Tensor): A tensor of ELBO values with shape `(batch_size,)`
if return_elbo_t is False, or `(num_samples, batch_size)` if return_elbo_t
is True.
"""
X = backbone.impute_masked_X(X, C)
with torch.random.fork_rng():
torch.random.manual_seed(deterministic_seed)
mask = (C > 0).float()
mask_atoms = mask.reshape(list(mask.shape) + [1, 1]).expand([-1, -1, 4, 1])
elbo = []
T = np.linspace(1e-4, 1.0, num_samples)
for t in tqdm(T.tolist()):
X_noise = self.forward(X, C, t=t)
X_denoise = X0_func(X_noise, C, t)
# Adjust X-hat estimate with aux-grad
if grad_logprob_Y_func is not None:
with torch.random.fork_rng():
grad = grad_logprob_Y_func(t, X_noise)
sigma_square = (
self.noise_schedule.sigma(t).square().to(X.device)
)
dXhat = sigma_square * self.base_gaussian.multiply_covariance(
grad, C
)
dXhat = backbone.center_X(dXhat, C)
X_denoise = X_denoise + dXhat
elbo_t, _ = self.elbo(X_denoise, X, C, t)
elbo.append(elbo_t)
elbo_t = torch.stack(elbo, 0)
if return_elbo_t:
return elbo_t
else:
return elbo_t.mean(0)
def conditional_X0(
        self, X0: torch.Tensor, score: torch.Tensor, C: torch.Tensor, t: torch.Tensor
) -> torch.Tensor:
"""Use Bayes theorem and Tweedie formula to obtain a conditional X0 given
        prior X0 and a conditional score \nabla_x \log p(y | x):
            X0 <- X0 + (sigma_t^2 / alpha_t) * \Sigma * score
Args:
X0 (torch.Tensor): backbone coordinates of size (batch, num_residues, 4, 3)
score (torch.Tensor): of size (batch, num_residues, 4, 3)
C (torch.Tensor): of size (batch, num_residues)
t (torch.Tensor): of size (batch,)
Returns:
X0 (torch.Tensor): updated conditional X0 of size (batch, num_residues, 4, 3)
"""
alpha, sigma, _, _, _, _ = self._schedule_coefficients(t)
X_update = sigma.pow(2).div(alpha)[
..., None
] * self.base_gaussian.multiply_covariance(score, C)
return X0 + X_update
def _mean(self, X, C, alpha):
"""Build the diffusion kernel mean given alpha"""
# Compute the MVN mean
X_mu = backbone.scale_around_mean(X, C, alpha)
return X_mu
def _X_to_Z(self, X_sample, X, C, alpha, sigma):
"""Convert from output space to standardized space"""
# Impute missing data with conditional means
X = backbone.impute_masked_X(X, C)
X_sample = backbone.impute_masked_X(X_sample, C)
# sigma = self.noise_schedule.sigma(t).to(X.device)
# Step 4. [Inverse] Add mean
X_mu = self._mean(X, C, alpha)
X_mu = backbone.impute_masked_X(X_mu, C)
X_noise = (X_sample - X_mu).reshape(X.shape[0], -1, 3)
# Step 3. [Inverse] Scale noise by sigma
X_noise = X_noise / sigma[:, None, None]
# Step 1 & 2. Multiply Z by inverse square root of covariance
Z = self.base_gaussian._multiply_R_inverse(X_noise, C)
return Z
def _Z_to_X(self, Z, X, C, alpha, sigma):
"""Convert from standardized space to output space"""
# Step 1 & 2. Multiply Z by square root of covariance
dX = self.base_gaussian._multiply_R(Z, C)
# Step 3. Scale noise by alpha
dX = sigma[:, None, None, None] * dX.reshape(X.shape)
# Step 4. Add mean
X_mu = self._mean(X, C, alpha)
X_sample = X_mu + dX
return X_sample
def sample_conditional(
self, X: torch.Tensor, C: torch.LongTensor, t: torch.Tensor, s: torch.Tensor
) -> torch.Tensor:
"""
Samples from the forward process q(x_{t} | x_{s}) for t > s.
See appendix A.1 in [https://arxiv.org/pdf/2107.00630.pdf]. `forward` does this for s = 0.
Args:
X (torch.Tensor): Input coordinates with shape `(batch_size, num_residues,
4, 3)` at time `t0`.
C (torch.Tensor): Chain tensor with shape `(batch_size, num_residues)`.
t (torch.Tensor): Time index with shape `(batch_size,)`.
s (torch.Tensor): Time index with shape `(batch_size,)`.
Returns:
X_sample (torch.Tensor): Sampled coordinates from the forward diffusion
marginals with shape `(batch_size, num_residues, 4, 3)`.
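        Note (the marginal used below, as in appendix A.1 of the reference): with
            alpha_{t|s} = alpha_t / alpha_s,
            sigma_{t|s}^2 = sigma_t^2 - alpha_{t|s}^2 * sigma_s^2,
        the sample is X_t = alpha_{t|s} * X_s + sigma_{t|s} * R z, where R z is a draw
        from the chain-structured Gaussian prior.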
"""
assert (t > s).all()
X = backbone.impute_masked_X(X, C)
# Do we need this?
X = backbone.center_X(X, C)
alpha_ts = self.noise_schedule.alpha(t) / self.noise_schedule.alpha(s)
sigma_ts = (
self.noise_schedule.sigma(t).pow(2)
- alpha_ts.pow(2) * self.noise_schedule.sigma(s).pow(2)
).sqrt()
X_sample = alpha_ts * X + sigma_ts * self.base_gaussian.sample(C)
# Do we need this?
X_sample = backbone.center_X(X_sample - X, C) + X
return X_sample
@validate_XC(all_atom=False)
def forward(
self, X: torch.Tensor, C: torch.LongTensor, t: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Sample from the forwards diffusion marginals at time t
Inputs:
X (torch.Tensor): Input coordinates with shape `(batch_size, num_residues,
4, 3)`.
C (torch.LongTensor): Chain tensor with shape `(batch_size, num_residues)`.
t (torch.Tensor, optional): Time index with shape `(batch_size,)`. If not
given, a random time index will be sampled. Defaults to None.
Outputs:
X_sample (torch.Tensor): Sampled coordinates from the forward diffusion
marginals with shape `(batch_size, num_residues, 4, 3)`.
t (torch.Tensor, optional): Time index with shape `(batch_size,)`. Only
returned if t is not given as input.
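        Note (the marginal sampled below): X_t = alpha_t * X + sigma_t * X_prior, where
        X_prior is a draw from the chain-structured Gaussian prior; the perturbation is
        re-centered so that the center of mass of X is preserved.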
"""
# Draw a sample from the prior
X_prior = self.base_gaussian.sample(C)
# Sample time if not given
t_input = t
t = self.sample_t(C, t)
alpha = self.noise_schedule.alpha(t)[:, None, None, None].to(X.device)
sigma = self.noise_schedule.sigma(t)[:, None, None, None].to(X.device)
X_sample = alpha * X + sigma * X_prior
X_sample = backbone.center_X(X_sample - X, C) + X
if t_input is None:
return X_sample, t
else:
return X_sample
## Loss
class ReconstructionLosses(nn.Module):
"""Compute diffusion reconstruction losses for protein backbones.
Args:
diffusion (DiffusionChainCov): Diffusion object parameterizing a
forwards diffusion over protein backbones.
loss_scale (float): Length scale parameter used for setting loss error
scaling in units of Angstroms. Default is 10, which corresponds to
using units of nanometers.
rmsd_method (str): Method used for computing RMSD superpositions. Can
be "symeig" (default) or "power" for power iteration.
Inputs:
X0_pred (torch.Tensor): Denoised coordinates with shape
`(num_batch, num_residues, 4, 3)`.
X (torch.Tensor): Unperturbed coordinates with shape
`(num_batch, num_residues, 4, 3)`.
C (torch.LongTensor): Chain map with shape `(num_batch, num_residues)`.
t (torch.Tensor): Diffusion time with shape `(batch_size,)`.
Should be on [0,1].
Outputs:
losses (dict): Dictionary of reconstructions computed across different
metrics. Metrics prefixed with `batch_` will be batch-averaged scalars
while other metrics should be per batch member with shape
`(num_batch, ...)`.
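    Example (a minimal sketch with random tensors standing in for real denoiser
    predictions):
        diffusion = DiffusionChainCov()
        loss_module = ReconstructionLosses(diffusion)
        X = torch.randn(2, 64, 4, 3)
        X0_pred = torch.randn(2, 64, 4, 3)
        C = torch.ones(2, 64).long()
        t = torch.rand(2)
        losses = loss_module(X0_pred, X, C, t)   # dict of per-item and batch metrics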
"""
def __init__(
self,
diffusion: DiffusionChainCov,
loss_scale: float = 10.0,
rmsd_method: str = "symeig",
):
super().__init__()
self.noise_perturb = diffusion
self.loss_scale = loss_scale
self._loss_eps = 1e-5
# Auxiliary losses
self.loss_rmsd = rmsd.BackboneRMSD(method=rmsd_method)
self.loss_fragment = rmsd.LossFragmentRMSD(method=rmsd_method)
self.loss_fragment_pair = rmsd.LossFragmentPairRMSD(method=rmsd_method)
self.loss_neighborhood = rmsd.LossNeighborhoodRMSD(method=rmsd_method)
self.loss_hbond = hbonds.LossBackboneHBonds()
self.loss_distance = backbone.LossBackboneResidueDistance()
self.loss_functions = {
"elbo": self._loss_elbo,
"rmsd": self._loss_rmsd,
"pseudoelbo": self._loss_pseudoelbo,
"fragment": self._loss_fragment,
"pair": self._loss_pair,
"neighborhood": self._loss_neighborhood,
"distance": self._loss_distance,
"hbonds": self._loss_hbonds,
}
def _batch_average(self, loss, C):
weights = (C > 0).float().sum(-1)
return (weights * loss).sum() / (weights.sum() + self._loss_eps)
def _loss_elbo(self, losses, X0_pred, X, C, t, w=None, X_t_2=None):
losses["elbo"], losses["batch_elbo"] = self.noise_perturb.elbo(X0_pred, X, C, t)
def _loss_rmsd(self, losses, X0_pred, X, C, t, w=None, X_t_2=None):
_, rmsd_denoise = self.loss_rmsd.align(X, X0_pred, C)
_, rmsd_noise = self.loss_rmsd.align(X, X_t_2, C)
rmsd_ratio_per_item = w * rmsd_denoise / (rmsd_noise + self._loss_eps)
global_mse_normalized = (
w
* self.loss_scale
* rmsd_denoise.square()
/ (rmsd_noise.square() + self._loss_eps)
)
losses["rmsd_ratio"] = self._batch_average(rmsd_ratio_per_item, C)
losses["global_mse"] = global_mse_normalized
losses["batch_global_mse"] = self._batch_average(global_mse_normalized, C)
def _loss_pseudoelbo(self, losses, X0_pred, X, C, t, w=None, X_t_2=None):
# Unaligned residual pseudoELBO
unaligned_mse = ((X - X0_pred) / self.loss_scale).square().sum(-1).mean(-1)
losses["elbo_X"], losses["batch_pseudoelbo_X"] = self.noise_perturb.pseudoelbo(
unaligned_mse, C, t
)
def _loss_fragment(self, losses, X0_pred, X, C, t, w=None, X_t_2=None):
# Aligned Fragment MSE loss
mask = (C > 0).float()
rmsd_fragment = self.loss_fragment(X0_pred, X, C)
rmsd_fragment_noise = self.loss_fragment(X_t_2, X, C)
fragment_mse_normalized = (
self.loss_scale
* w
* (
(mask * rmsd_fragment.square()).sum(1)
/ ((mask * rmsd_fragment_noise.square()).sum(1) + self._loss_eps)
)
)
losses["fragment_mse"] = fragment_mse_normalized
losses["batch_fragment_mse"] = self._batch_average(fragment_mse_normalized, C)
def _loss_pair(self, losses, X0_pred, X, C, t, w=None, X_t_2=None):
# Aligned Pair MSE loss
rmsd_pair, mask_ij_pair = self.loss_fragment_pair(X0_pred, X, C)
rmsd_pair_noise, mask_ij_pair = self.loss_fragment_pair(X_t_2, X, C)
pair_mse_normalized = (
self.loss_scale
* w
* (
(mask_ij_pair * rmsd_pair.square()).sum([1, 2])
/ (
(mask_ij_pair * rmsd_pair_noise.square()).sum([1, 2])
+ self._loss_eps
)
)
)
losses["pair_mse"] = pair_mse_normalized
losses["batch_pair_mse"] = self._batch_average(pair_mse_normalized, C)
def _loss_neighborhood(self, losses, X0_pred, X, C, t, w=None, X_t_2=None):
# Neighborhood MSE
rmsd_neighborhood, mask = self.loss_neighborhood(X0_pred, X, C)
rmsd_neighborhood_noise, mask = self.loss_neighborhood(X_t_2, X, C)
neighborhood_mse_normalized = (
self.loss_scale
* w
* (
(mask * rmsd_neighborhood.square()).sum(1)
/ ((mask * rmsd_neighborhood_noise.square()).sum(1) + self._loss_eps)
)
)
losses["neighborhood_mse"] = neighborhood_mse_normalized
losses["batch_neighborhood_mse"] = self._batch_average(
neighborhood_mse_normalized, C
)
def _loss_distance(self, losses, X0_pred, X, C, t, w=None, X_t_2=None):
# Distance MSE
mask = (C > 0).float()
distance_mse = self.loss_distance(X0_pred, X, C)
distance_mse_noise = self.loss_distance(X_t_2, X, C)
distance_mse_normalized = self.loss_scale * (
w
* (mask * distance_mse).sum(1)
/ ((mask * distance_mse_noise).sum(1) + self._loss_eps)
)
losses["distance_mse"] = distance_mse_normalized
losses["batch_distance_mse"] = self._batch_average(distance_mse_normalized, C)
def _loss_hbonds(self, losses, X0_pred, X, C, t, w=None, X_t_2=None):
# HBond recovery
outs = self.loss_hbond(X0_pred, X, C)
hb_local, hb_nonlocal, error_co = [w * o for o in outs]
losses["batch_hb_local"] = self._batch_average(hb_local, C)
losses["hb_local"] = hb_local
losses["batch_hb_nonlocal"] = self._batch_average(hb_nonlocal, C)
losses["hb_nonlocal"] = hb_nonlocal
losses["batch_hb_contact_order"] = self._batch_average(error_co, C)
@torch.no_grad()
@validate_XC(all_atom=False)
def estimate_metrics(
self,
X0_func: Callable,
X: torch.Tensor,
C: torch.LongTensor,
num_samples: int = 50,
deterministic_seed: int = 0,
use_noise: bool = True,
return_samples: bool = False,
tspan: Tuple[float] = (1e-4, 1.0),
):
"""Estimate time-averaged reconstruction losses of protein backbones.
Args:
X0_func (Callable): A denoising function that maps `(X, C, t)` to `X0`.
            X (torch.Tensor): A tensor of protein backbone coordinates with shape
                `(batch_size, num_residues, 4, 3)`.
C (torch.Tensor): A tensor of condition features with shape `(batch_size,
num_residues)`.
num_samples (int, optional): The number of time steps to sample for
estimating the ELBO. Default is 50.
            use_noise (bool): If True, add noise to each structure before denoising.
                Default is True. When False, this can be used to test whether
                structures are fixed points of the denoiser across time.
deterministic_seed (int, optional): The seed for generating random noise.
Default is 0.
return_samples (bool): If True, include intermediate sampled
                values for each metric. Default is False.
tspan (tuple[float]): Tuple of floats indicating the diffusion
times between which to integrate.
Returns:
metrics (dict): A dictionary of reconstruction metrics averaged over
time.
            metrics_samples (dict, optional): A dictionary of the intermediate
                per-timestep metric samples. Only returned if return_samples is True.
"""
#
X = backbone.impute_masked_X(X, C)
with torch.random.fork_rng():
torch.random.manual_seed(deterministic_seed)
            T = np.linspace(tspan[0], tspan[1], num_samples)
losses = []
for t in tqdm(T.tolist(), desc="Integrating diffusion metrics"):
X_noise = self.noise_perturb(X, C, t=t) if use_noise else X
X_denoise = X0_func(X_noise, C, t)
losses_t = self.forward(X_denoise, X, C, t)
# Discard batch estimated objects
losses_t = {
k: v
for k, v in losses_t.items()
if not k.startswith("batch_") and k != "rmsd_ratio"
}
losses.append(losses_t)
# Transpose list of dicts to a dict of lists
metrics_samples = {k: [d[k] for d in losses] for k in losses[0].keys()}
# Average final metrics across time
metrics = {
k: torch.stack(v, 0).mean(0)
for k, v in metrics_samples.items()
if isinstance(v[0], torch.Tensor)
}
if return_samples:
return metrics, metrics_samples
else:
return metrics
@validate_XC()
def forward(
self,
X0_pred: torch.Tensor,
X: torch.Tensor,
C: torch.LongTensor,
t: torch.Tensor,
):
# Collect all losses and tensors for metric tracking
losses = {"t": t, "X": X, "X0_pred": X0_pred}
X_t_2 = self.noise_perturb(X, C, t=t)
# Per complex weights
ssnr = self.noise_perturb.noise_schedule.SSNR(t).to(X.device)
prob_ssnr = self.noise_perturb.noise_schedule.prob_SSNR(ssnr)
importance_weights = 1 / prob_ssnr
for _loss in self.loss_functions.values():
_loss(losses, X0_pred, X, C, t, w=importance_weights, X_t_2=X_t_2)
return losses
def _debug_viz_gradients(
pml_file, X_list, dX_list, C, S, arrow_length=2.0, name="gradient", color="red"
):
""" """
lines = [
"from pymol.cgo import *",
"from pymol import cmd",
f'color_1 = list(pymol.cmd.get_color_tuple("{color}"))',
'color_2 = list(pymol.cmd.get_color_tuple("blue"))',
]
with open(pml_file, "w") as f:
for model_ix, X in enumerate(X_list):
print(model_ix)
lines = lines + ["obj_1 = []"]
dX = dX_list[model_ix]
scale = dX.norm(dim=-1).mean().item()
X_i = X
X_j = X + arrow_length * dX / scale
for a_ix in range(4):
for i in range(X.size(1)):
x_i = X_i[0, i, a_ix, :].tolist()
x_j = X_j[0, i, a_ix, :].tolist()
lines = lines + [
f"obj_1 = obj_1 + [CYLINDER] + {x_i} + {x_j} + [0.15]"
" + color_1 + color_1"
]
lines = lines + [f'cmd.load_cgo(obj_1, "{name}", {model_ix+1})']
f.write("\n" + "\n".join(lines))
lines = []
def _debug_viz_XZC(X, Z, C, rgb=True):
from matplotlib import pyplot as plt
if len(X.shape) > 3:
X = X.reshape(X.shape[0], -1, 3)
if len(Z.shape) > 3:
Z = Z.reshape(Z.shape[0], -1, 3)
if C.shape[1] != X.shape[1]:
C_expand = C.unsqueeze(-1).expand(-1, -1, 4)
C = C_expand.reshape(C.shape[0], -1)
# C_mask = expand_chain_map(torch.abs(C))
# X_expand = torch.einsum('nix,nic->nicx', X, C_mask)
# plt.plot(X_expand[0,:,:,0].data.numpy())
N = X.shape[1]
Ymax = torch.max(X[0, :, 0]).item()
plt.figure(figsize=[12, 4])
plt.subplot(2, 1, 1)
plt.bar(
np.arange(0, N),
(C[0, :].data.numpy() < 0) * Ymax,
width=1.0,
edgecolor=None,
color="lightgrey",
)
if rgb:
plt.plot(X[0, :, 0].data.numpy(), "r", linewidth=0.5)
plt.plot(X[0, :, 1].data.numpy(), "g", linewidth=0.5)
plt.plot(X[0, :, 2].data.numpy(), "b", linewidth=0.5)
plt.xlim([0, N])
plt.grid()
plt.title("X")
plt.xticks([])
plt.subplot(2, 1, 2)
plt.plot(Z[0, :, 0].data.numpy(), "r", linewidth=0.5)
plt.plot(Z[0, :, 1].data.numpy(), "g", linewidth=0.5)
plt.plot(Z[0, :, 2].data.numpy(), "b", linewidth=0.5)
plt.plot(C[0, :].data.numpy(), "orange")
plt.xlim([0, N])
plt.grid()
plt.title("RInverse @ [X]")
plt.xticks([])
plt.savefig("xzc.pdf")
else:
plt.plot(X[0, :, 0].data.numpy(), "k", linewidth=0.5)
plt.xlim([0, N])
plt.grid()
plt.title("X")
plt.xticks([])
plt.subplot(2, 1, 2)
plt.plot(Z[0, :, 0].data.numpy(), "k", linewidth=0.5)
plt.plot(C[0, :].data.numpy(), "orange")
plt.xlim([0, N])
plt.grid()
plt.title("Inverse[X]")
plt.xticks([])
plt.savefig("xzc.pdf")
exit()