|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Optional |
|
|
|
import torch |
|
|
|
|
|
def rmsd( |
|
pred_pose: torch.Tensor, |
|
true_pose: torch.Tensor, |
|
mask: torch.Tensor = None, |
|
eps: float = 0.0, |
|
reduce: bool = True, |
|
): |
|
""" |
|
compute rmsd between two poses, with the same shape |
|
Arguments: |
|
pred_pose/true_pose: [...,N,3], two poses with the same shape |
|
mask: [..., N], mask to indicate which atoms/pseudo_betas/etc to compute |
|
eps: add a tolerance to avoid floating number issue |
|
reduce: decide the return shape of rmsd; |
|
Return: |
|
rmsd: if reduce = true, return the mean of rmsd over batches; |
|
else return a tensor containing each rmsd separately |
|
""" |
|
|
|
|
|
assert pred_pose.shape == true_pose.shape |
|
|
|
if mask is None: |
|
mask = torch.ones(true_pose.shape[:-1], device=true_pose.device) |
|
|
|
|
|
err2 = (torch.square(pred_pose - true_pose).sum(dim=-1) * mask).sum( |
|
dim=-1 |
|
) / mask.sum(dim=-1) |
|
rmsd = err2.add(eps).sqrt() |
|
if reduce: |
|
rmsd = rmsd.mean() |
|
return rmsd |
|
|
|
|
|
def align_pred_to_true( |
|
pred_pose: torch.Tensor, |
|
true_pose: torch.Tensor, |
|
atom_mask: Optional[torch.Tensor] = None, |
|
weight: Optional[torch.Tensor] = None, |
|
allowing_reflection: bool = False, |
|
): |
|
"""Find optimal transformation, rotation (and reflection) of two poses. |
|
Arguments: |
|
pred_pose: [...,N,3] the pose to perform transformation on |
|
true_pose: [...,N,3] the target pose to align pred_pose to |
|
atom_mask: [..., N] a mask for atoms |
|
weight: [..., N] a weight vector to be applied. |
|
allow_reflection: whether to allow reflection when finding optimal alignment |
|
return: |
|
aligned_pose: [...,N,3] the transformed pose |
|
rot: optimal rotation |
|
translate: optimal translation |
|
""" |
|
if atom_mask is not None: |
|
pred_pose = pred_pose * atom_mask.unsqueeze(-1) |
|
true_pose = true_pose * atom_mask.unsqueeze(-1) |
|
else: |
|
atom_mask = torch.ones(*pred_pose.shape[:-1]).to(pred_pose.device) |
|
|
|
if weight is None: |
|
weight = atom_mask |
|
else: |
|
weight = weight * atom_mask |
|
|
|
weighted_n_atoms = torch.sum(weight, dim=-1, keepdim=True).unsqueeze(-1) |
|
pred_pose_centroid = ( |
|
torch.sum(pred_pose * weight.unsqueeze(-1), dim=-2, keepdim=True) |
|
/ weighted_n_atoms |
|
) |
|
pred_pose_centered = pred_pose - pred_pose_centroid |
|
true_pose_centroid = ( |
|
torch.sum(true_pose * weight.unsqueeze(-1), dim=-2, keepdim=True) |
|
/ weighted_n_atoms |
|
) |
|
true_pose_centered = true_pose - true_pose_centroid |
|
H_mat = torch.matmul( |
|
(pred_pose_centered * weight.unsqueeze(-1)).transpose(-2, -1), |
|
true_pose_centered * atom_mask.unsqueeze(-1), |
|
) |
|
u, s, v = torch.svd(H_mat) |
|
u = u.transpose(-1, -2) |
|
|
|
if not allowing_reflection: |
|
|
|
det = torch.linalg.det(torch.matmul(v, u)) |
|
|
|
diagonal = torch.stack( |
|
[torch.ones_like(det), torch.ones_like(det), det], dim=-1 |
|
) |
|
rot = torch.matmul( |
|
torch.diag_embed(diagonal).to(u.device), |
|
u, |
|
) |
|
rot = torch.matmul(v, rot) |
|
else: |
|
rot = torch.matmul(v, u) |
|
translate = true_pose_centroid - torch.matmul( |
|
pred_pose_centroid, rot.transpose(-1, -2) |
|
) |
|
|
|
pred_pose_translated = ( |
|
torch.matmul(pred_pose_centered, rot.transpose(-1, -2)) + true_pose_centroid |
|
) |
|
|
|
return pred_pose_translated, rot, translate |
|
|
|
|
|
def partially_aligned_rmsd( |
|
pred_pose: torch.Tensor, |
|
true_pose: torch.Tensor, |
|
align_mask: torch.Tensor, |
|
atom_mask: torch.Tensor, |
|
weight: Optional[torch.Tensor] = None, |
|
eps: float = 0.0, |
|
reduce: bool = True, |
|
allowing_reflection: bool = False, |
|
): |
|
"""RMSD when aligning parts of the complex coordinate, does NOT take permutation symmetricity into consideration |
|
Arguments: |
|
pred_pose: native predicted pose, [..., N,3] |
|
true_pose: ground truth pose, [..., N, 3] |
|
align_mask: a mask representing which coordinates to align [..., N] |
|
atom_mask: a mask representing which coordinates to compute loss [..., N] |
|
weight: a weight tensor assining weights in alignment for each atom [..., N] |
|
eps: add a tolerance to avoid floating number issue in sqrt |
|
reduce: decide the return shape of rmsd; |
|
allowing_reflection: whether to allow reflection when finding optimal alignment |
|
return: |
|
aligned_part_rmsd: the rmsd of part being align_masked |
|
unaligned_part_rmsd: the rmsd of unaligned part |
|
transformed_pred_pose: |
|
rot: optimal rotation |
|
trans: optimal translation |
|
""" |
|
_, rot, translate = align_pred_to_true( |
|
pred_pose, |
|
true_pose, |
|
atom_mask=atom_mask * align_mask, |
|
weight=weight, |
|
allowing_reflection=allowing_reflection, |
|
) |
|
transformed_pose = torch.matmul(pred_pose, rot.transpose(-1, -2)) + translate |
|
err_atom = torch.square(transformed_pose - true_pose).sum(dim=-1) * atom_mask |
|
aligned_mask, unaligned_mask = atom_mask * align_mask.float(), atom_mask * ( |
|
1 - align_mask.float() |
|
) |
|
aligned_part_err_square = (err_atom * aligned_mask).sum(dim=-1) / aligned_mask.sum( |
|
dim=-1 |
|
) |
|
unaligned_part_err_square = (err_atom * unaligned_mask).sum( |
|
dim=-1 |
|
) / unaligned_mask.sum(dim=-1) |
|
aligned_part_rmsd = aligned_part_err_square.add(eps).sqrt() |
|
unaligned_part_rmsd = unaligned_part_err_square.add(eps).sqrt() |
|
if reduce: |
|
aligned_part_rmsd = aligned_part_rmsd.mean() |
|
unaligned_part_rmsd = unaligned_part_rmsd.mean() |
|
return aligned_part_rmsd, unaligned_part_rmsd, transformed_pose, rot, translate |
|
|
|
|
|
def self_aligned_rmsd( |
|
pred_pose: torch.Tensor, |
|
true_pose: torch.Tensor, |
|
atom_mask: torch.Tensor, |
|
eps: float = 0.0, |
|
reduce: bool = True, |
|
allowing_reflection: bool = False, |
|
): |
|
"""RMSD when aligning one molecule with ground truth and compute rmsd. |
|
Arguments: |
|
pred_pose: native predicted pose, [..., N,3] |
|
true_pose: ground truth pose, [..., N, 3] |
|
atom_mask: a mask representing which coordinates to compute loss [..., N] |
|
eps: add a tolerance to avoid floating number issue in sqrt |
|
reduce: decide the return shape of rmsd; |
|
allowing_reflection: whether to allow reflection when finding optimal alignment |
|
return: |
|
aligned_rmsd: the rmsd of part being align_masked |
|
transformed_pred_pose: the aligned pose |
|
rot: optimal rotation matrix |
|
trans: optimal translation |
|
""" |
|
aligned_rmsd, _, transformed_pred_pose, rot, trans = partially_aligned_rmsd( |
|
pred_pose=pred_pose, |
|
true_pose=true_pose, |
|
align_mask=atom_mask, |
|
atom_mask=atom_mask, |
|
eps=eps, |
|
reduce=reduce, |
|
allowing_reflection=allowing_reflection, |
|
) |
|
return aligned_rmsd, transformed_pred_pose, rot, trans |
|
|
|
|
|
def weighted_rigid_align( |
|
x: torch.Tensor, |
|
x_target: torch.Tensor, |
|
atom_weight: torch.Tensor, |
|
stop_gradient: bool = True, |
|
) -> tuple[torch.Tensor]: |
|
"""Implements Algorithm 28 in AF3. Wrap `align_pred_to_true`. |
|
|
|
Args: |
|
x (torch.Tensor): input coordinates, it will be moved to match x_target. |
|
[..., N_atom, 3] |
|
x_target (torch.Tensor): target coordinates for the input to match. |
|
[..., N_atom, 3] |
|
atom_weight (torch.Tensor): weights for each atom. |
|
[..., N_atom] or [N_atom] |
|
stop_gradient (bool): whether to detach the output. If true, will run it with no_grad() ctx. |
|
|
|
Returns: |
|
x_aligned (torch.Tensor): rotated, translated x which should be closer to x_target. |
|
[..., N_atom, 3] |
|
""" |
|
|
|
if len(atom_weight.shape) == len(x.shape) - 1: |
|
assert atom_weight.shape[:-1] == x.shape[:-2] |
|
else: |
|
assert len(atom_weight.shape) == 1 and atom_weight.shape[-1] == x.shape[-2] |
|
|
|
if stop_gradient: |
|
with torch.no_grad(): |
|
x_aligned, rot, trans = align_pred_to_true( |
|
pred_pose=x, |
|
true_pose=x_target, |
|
atom_mask=None, |
|
weight=atom_weight, |
|
allowing_reflection=False, |
|
) |
|
return x_aligned.detach() |
|
else: |
|
x_aligned, rot, trans = align_pred_to_true( |
|
pred_pose=x, |
|
true_pose=x_target, |
|
atom_mask=None, |
|
weight=atom_weight, |
|
allowing_reflection=False, |
|
) |
|
return x_aligned |
|
|