# M3Site / esm/utils/structure/protein_structure.py
# Uploaded by anonymousforpaper ("Upload 103 files"), commit 224a33f (verified).
from __future__ import annotations
from typing import Tuple, TypeVar
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.cuda.amp import autocast # type: ignore
from esm.utils import residue_constants
from esm.utils.misc import unbinpack
from esm.utils.structure.affine3d import Affine3D
# Type variable for helpers that accept either a NumPy array or a torch tensor
# and return the same container type they were given.
ArrayOrTensor = TypeVar("ArrayOrTensor", np.ndarray, torch.Tensor)
def index_by_atom_name(
    atom37: ArrayOrTensor, atom_names: str | list[str], dim: int = -2
) -> ArrayOrTensor:
    """
    Select atoms by name along the atom axis of an atom37-style array or tensor.

    Args:
    - atom37 (np.ndarray | torch.Tensor): Coordinates whose ``dim`` axis is indexed
      by ``residue_constants.atom_order``.
    - atom_names (str | list[str]): A single atom name, or a list of names.
    - dim (int): Axis that holds the atoms; negative values count from the end.
    Returns:
    - The selected entries, of the same container type as ``atom37``. When a single
      name is given, the atom axis is dropped; for a list it is kept with one entry
      per requested name.
    """
    single_name = isinstance(atom_names, str)
    names = [atom_names] if single_name else atom_names
    atom_indices = [residue_constants.atom_order[name] for name in names]

    # Normalise a possibly negative axis so it can be compared against positions.
    axis = dim % atom37.ndim
    selector = [slice(None)] * atom37.ndim
    selector[axis] = atom_indices

    picked = atom37[tuple(selector)]  # type: ignore
    return picked.squeeze(axis) if single_name else picked
def infer_cbeta_from_atom37(
    atom37: ArrayOrTensor, L: float = 1.522, A: float = 1.927, D: float = -2.143
):
    """
    Place a virtual C-beta atom from the backbone N, CA and C coordinates.

    Inspired by a util in trDesign:
    https://github.com/gjoni/trDesign/blob/f2d5930b472e77bfacc2f437b3966e7a708a8d37/02-GD/utils.py#L92

    Args:
    - atom37 (np.ndarray | torch.Tensor): Coordinates with atoms on dim -2 and xyz
      on the last dim (presumably ``(..., 37, 3)`` -- TODO confirm for all callers).
    - L (float): Distance of the placed atom from CA.
    - A (float): Angle, in radians, between the CA->N direction and CA->CB.
    - D (float): Dihedral, in radians, as in the trDesign reference.
    Returns:
    - Coordinates of the placed 4th atom, shaped like the CA slice of ``atom37``
      and of the same container type.
    """
    N = index_by_atom_name(atom37, "N", dim=-2)
    CA = index_by_atom_name(atom37, "CA", dim=-2)
    C = index_by_atom_name(atom37, "C", dim=-2)

    if isinstance(atom37, np.ndarray):

        def normalize(x):
            return x / np.linalg.norm(x, axis=-1, keepdims=True)

        def cross(a, b):
            return np.cross(a, b, axis=-1)

    else:
        # Pin dim=-1 explicitly so both branches operate on the xyz axis:
        # F.normalize defaults to dim=1 (wrong for batched inputs) and
        # torch.cross without dim picks the first size-3 dimension (wrong when
        # there happen to be exactly 3 residues).
        def normalize(x):
            return F.normalize(x, dim=-1)

        def cross(a, b):
            return torch.cross(a, b, dim=-1)

    with np.errstate(invalid="ignore"):  # inf - inf = nan is ok here
        vec_nca = N - CA
        vec_nc = N - C
    nca = normalize(vec_nca)
    n = normalize(cross(vec_nc, nca))
    # Orthonormal frame at CA: along CA->N, in the N-CA-C plane, and its normal.
    basis = [nca, cross(n, nca), n]
    coeffs = [L * np.cos(A), L * np.sin(A) * np.cos(D), -L * np.sin(A) * np.sin(D)]
    return CA + sum(axis_vec * coeff for axis_vec, coeff in zip(basis, coeffs))
@torch.no_grad()
@autocast(enabled=False)
def compute_alignment_tensors(
    mobile: torch.Tensor,
    target: torch.Tensor,
    atom_exists_mask: torch.Tensor | None = None,
    sequence_id: torch.Tensor | None = None,
):
    """
    Compute the optimal rigid superposition (Kabsch) of ``mobile`` onto ``target``
    for a batch of structures, with support for masking invalid atoms.

    Args:
    - mobile (torch.Tensor): Batch of coordinates of structure to be superimposed,
      shape (B, N, 3) or (B, Nres, Natom, 3).
    - target (torch.Tensor): Batch of coordinates of structure that is fixed, same
      shape as ``mobile``.
    - atom_exists_mask (torch.Tensor, optional): Whether each atom exists, shape
      (B, N) or (B, Nres, Natom). Non-bool masks are converted with ``.bool()``.
    - sequence_id (torch.Tensor, optional): Sequence id tensor for binpacking; when
      given, inputs are unpacked with ``unbinpack`` (padding is NaN / 0).
    Returns:
    - centered_mobile (torch.Tensor): (B, N, 3) centered mobile, invalid atoms zeroed.
    - centroid_mobile (torch.Tensor): (B, 1, 3) mobile centroid over valid atoms.
    - centered_target (torch.Tensor): (B, N, 3) centered target, invalid atoms zeroed.
    - centroid_target (torch.Tensor): (B, 1, 3) target centroid over valid atoms.
    - rotation_matrix (torch.Tensor): (B, 3, 3) proper rotation (det = +1) R such
      that ``centered_mobile @ R`` best matches ``centered_target``.
    - num_valid_atoms (torch.Tensor): (B, 1) number of valid atoms per structure.
    """
    if sequence_id is not None:
        mobile = unbinpack(mobile, sequence_id, pad_value=torch.nan)
        target = unbinpack(target, sequence_id, pad_value=torch.nan)
        if atom_exists_mask is not None:
            atom_exists_mask = unbinpack(atom_exists_mask, sequence_id, pad_value=0)
        else:
            # unbinpack pads coordinates with NaN, so validity follows finiteness.
            atom_exists_mask = torch.isfinite(target).all(-1)

    assert mobile.shape == target.shape, "Batch structure shapes do not match!"

    batch_size = mobile.shape[0]

    # Flatten [B, Nres, Natom, 3] to [B, Nres * Natom, 3]; reshape (unlike view)
    # also accepts non-contiguous inputs.
    if mobile.dim() == 4:
        mobile = mobile.reshape(batch_size, -1, 3)
    if target.dim() == 4:
        target = target.reshape(batch_size, -1, 3)
    if atom_exists_mask is not None and atom_exists_mask.dim() == 3:
        atom_exists_mask = atom_exists_mask.reshape(batch_size, -1)

    num_atoms = mobile.shape[1]

    if atom_exists_mask is not None:
        # `~` below is only a logical NOT for bool tensors.
        atom_exists_mask = atom_exists_mask.bool()
        mobile = mobile.masked_fill(~atom_exists_mask.unsqueeze(-1), 0)
        target = target.masked_fill(~atom_exists_mask.unsqueeze(-1), 0)
    else:
        atom_exists_mask = torch.ones(
            batch_size, num_atoms, dtype=torch.bool, device=mobile.device
        )

    num_valid_atoms = atom_exists_mask.sum(dim=-1, keepdim=True)  # (B, 1)

    centroid_mobile = mobile.sum(dim=-2, keepdim=True) / num_valid_atoms.unsqueeze(-1)
    centroid_target = target.sum(dim=-2, keepdim=True) / num_valid_atoms.unsqueeze(-1)
    # Structures with no valid atoms produce 0/0 = NaN centroids; reset to origin.
    centroid_mobile[num_valid_atoms == 0] = 0
    centroid_target[num_valid_atoms == 0] = 0

    centered_mobile = (mobile - centroid_mobile).masked_fill(
        ~atom_exists_mask.unsqueeze(-1), 0
    )
    centered_target = (target - centroid_target).masked_fill(
        ~atom_exists_mask.unsqueeze(-1), 0
    )

    # Cross-covariance H = X^T Y, shape (B, 3, 3).
    covariance_matrix = torch.matmul(centered_mobile.transpose(1, 2), centered_target)

    # H = U S Vh; the unconstrained optimum U @ Vh may be a reflection.
    u, _, vh = torch.linalg.svd(covariance_matrix)

    # Kabsch reflection correction: flip the axis of the smallest singular value
    # when det(U @ Vh) < 0 so the result is always a proper rotation.
    is_reflection = torch.linalg.det(torch.matmul(u, vh)) < 0  # (B,)
    axis_signs = torch.ones_like(u[..., 0, :])  # (B, 3)
    axis_signs[..., -1] = 1.0 - 2.0 * is_reflection.to(u.dtype)
    rotation_matrix = torch.matmul(u * axis_signs.unsqueeze(-2), vh)

    return (
        centered_mobile,
        centroid_mobile,
        centered_target,
        centroid_target,
        rotation_matrix,
        num_valid_atoms,
    )
@torch.no_grad()
@autocast(enabled=False)
def compute_rmsd_no_alignment(
    aligned: torch.Tensor,
    target: torch.Tensor,
    num_valid_atoms: torch.Tensor,
    reduction: str = "batch",
) -> torch.Tensor:
    """
    Compute RMSD between two batches of already-superimposed structures.

    Args:
    - aligned (torch.Tensor): Batch of superimposed coordinates in shape (B, N, 3)
    - target (torch.Tensor): Batch of coordinates of structure that is fixed in shape (B, N, 3)
    - num_valid_atoms (torch.Tensor): Number of valid atoms per structure, shape (B, 1)
      (as returned by ``compute_alignment_tensors``) or (B,). It is the denominator
      for "per_sample" and "batch"; invalid atoms should be zero in both inputs so
      they add nothing to the squared-error sum.
    - reduction (str): One of "batch", "per_sample", "per_residue".
    Returns:
        If reduction == "batch":
            (torch.Tensor): 0-dim, RMSD averaged over structures that have at least
            one valid atom.
        If reduction == "per_sample":
            (torch.Tensor): (B,)-dim, RMSD of each structure (NaN when it has no
            valid atoms).
        If reduction == "per_residue":
            (torch.Tensor): (B, N // 3)-dim, RMSD over each consecutive group of
            3 atoms (9 values); ``num_valid_atoms`` is not used.
    Raises:
        ValueError: if ``reduction`` is not one of the values above.
    """
    if reduction not in ("per_residue", "per_sample", "batch"):
        raise ValueError(f"Unrecognized reduction: '{reduction}'")

    diff = aligned - target

    if reduction == "per_residue":
        # 9 = 3 atoms x 3 coordinates per residue (presumably backbone N/CA/C --
        # TODO confirm). reshape tolerates non-contiguous inputs, unlike view.
        per_residue_mse = diff.square().reshape(diff.size(0), -1, 9).mean(dim=-1)
        return torch.sqrt(per_residue_mse)

    valid_counts = num_valid_atoms.squeeze(-1)
    mean_squared_error = diff.square().sum(dim=(1, 2)) / (valid_counts * 3)
    rmsd = torch.sqrt(mean_squared_error)

    if reduction == "per_sample":
        return rmsd

    # "batch": empty structures yield 0/0 = NaN; zero them and exclude them from
    # the count so they do not affect the average.
    return rmsd.masked_fill(valid_counts == 0, 0).sum() / (
        (num_valid_atoms > 0).sum() + 1e-8
    )
@torch.no_grad()
@autocast(enabled=False)
def compute_affine_and_rmsd(
    mobile: torch.Tensor,
    target: torch.Tensor,
    atom_exists_mask: torch.Tensor | None = None,
    sequence_id: torch.Tensor | None = None,
) -> Tuple[Affine3D, torch.Tensor]:
    """
    Superimpose ``mobile`` onto ``target`` and report the resulting batch RMSD.

    Args:
    - mobile (torch.Tensor): Batch of coordinates of structure to be superimposed in shape (B, N, 3)
    - target (torch.Tensor): Batch of coordinates of structure that is fixed in shape (B, N, 3)
    - atom_exists_mask (torch.Tensor, optional): Mask for whether an atom exists, shape (B, N)
    - sequence_id (torch.Tensor, optional): Sequence id tensor for binpacking.
    Returns:
    - affine (Affine3D): Rigid transform taking mobile onto target.
    - avg_rmsd (torch.Tensor): 0-dim RMSD after superposition, averaged over the
      structures that have at least one valid atom.
    """
    (
        mob_centered,
        mob_centroid,
        tgt_centered,
        tgt_centroid,
        rot,
        n_valid,
    ) = compute_alignment_tensors(
        mobile=mobile,
        target=target,
        atom_exists_mask=atom_exists_mask,
        sequence_id=sequence_id,
    )

    # Row-vector convention: x -> x @ rot + t, hence t = c_target - c_mobile @ rot.
    shift = -mob_centroid @ rot + tgt_centroid
    # Affine3D receives the transposed rotation, with an extra axis at dim -3.
    rot_for_affine = rot.unsqueeze(dim=-3).transpose(-2, -1)
    affine = Affine3D.from_tensor_pair(shift, rot_for_affine)

    # Score the superposition on the centered coordinates.
    superposed = mob_centered @ rot
    batch_rmsd = compute_rmsd_no_alignment(
        superposed,
        tgt_centered,
        n_valid,
        reduction="batch",
    )
    return affine, batch_rmsd