from __future__ import annotations

from typing import Tuple, TypeVar

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.cuda.amp import autocast  # type: ignore

from esm.utils import residue_constants
from esm.utils.misc import unbinpack
from esm.utils.structure.affine3d import Affine3D

ArrayOrTensor = TypeVar("ArrayOrTensor", np.ndarray, Tensor)


def index_by_atom_name(
    atom37: ArrayOrTensor, atom_names: str | list[str], dim: int = -2
) -> ArrayOrTensor:
    """
    Select one or more atoms from an atom37-layout array or tensor by name.

    Args:
        - atom37 (np.ndarray | torch.Tensor): Coordinates whose ``dim`` axis
          indexes atoms in the order of ``residue_constants.atom_order``.
        - atom_names (str | list[str]): A single atom name (e.g. "CA") or a
          list of atom names.
        - dim (int): Axis that indexes atoms; negative values count from the end.

    Returns:
        The selected entries. For a single string name the atom axis is
        squeezed away; for a list of names it is kept with length
        ``len(atom_names)``.
    """
    squeeze = isinstance(atom_names, str)
    if squeeze:
        atom_names = [atom_names]  # type: ignore[list-item]
    indices = [residue_constants.atom_order[atom_name] for atom_name in atom_names]
    dim = dim % atom37.ndim
    # Full slices everywhere except the atom axis, which gets the index list.
    index = tuple(slice(None) if dim != i else indices for i in range(atom37.ndim))
    result = atom37[index]  # type: ignore
    if squeeze:
        result = result.squeeze(dim)
    return result


def infer_cbeta_from_atom37(
    atom37: ArrayOrTensor, L: float = 1.522, A: float = 1.927, D: float = -2.143
):
    """
    Inspired by a util in trDesign:
    https://github.com/gjoni/trDesign/blob/f2d5930b472e77bfacc2f437b3966e7a708a8d37/02-GD/utils.py#L92

    Place a 4th atom (a virtual C-beta) from the backbone N, CA and C atoms
    using internal coordinates: bond (L)ength from CA, (A)ngle and (D)ihedral.
    A and D are passed to ``np.cos`` / ``np.sin`` and so are in radians.

    Args:
        - atom37 (np.ndarray | torch.Tensor): Coordinates in atom37 layout with
          the atom axis at -2 and xyz on the last axis. Any number of leading
          (batch / residue) dimensions is supported.
        - L (float): Distance of the placed atom from CA.
        - A (float): Bond angle in radians.
        - D (float): Dihedral angle in radians.

    Returns:
        Coordinates of the placed atom: ``atom37`` without its atom axis, with
        the same array type as the input.
    """
    N = index_by_atom_name(atom37, "N", dim=-2)
    CA = index_by_atom_name(atom37, "CA", dim=-2)
    C = index_by_atom_name(atom37, "C", dim=-2)

    if isinstance(atom37, np.ndarray):

        def normalize(x):  # type: ignore[no-redef]
            return x / np.linalg.norm(x, axis=-1, keepdims=True)

        def cross(a, b):  # type: ignore[no-redef]
            return np.cross(a, b)

    else:
        # Pin both ops to the coordinate axis: F.normalize defaults to dim=1 and
        # torch.cross without ``dim`` uses the first axis of size 3, so either
        # one silently operates on the wrong axis for batched inputs (or when a
        # leading dimension happens to equal 3).
        def normalize(x):  # type: ignore[no-redef]
            return F.normalize(x, dim=-1)

        def cross(a, b):  # type: ignore[no-redef]
            return torch.cross(a, b, dim=-1)

    with np.errstate(invalid="ignore"):  # inf - inf = nan is ok here
        vec_nca = N - CA
        vec_nc = N - C
    nca = normalize(vec_nca)
    n = normalize(cross(vec_nc, nca))  # type: ignore
    # Local frame around CA and the coefficients of the new atom in that frame.
    basis = [nca, cross(n, nca), n]
    coeffs = [L * np.cos(A), L * np.sin(A) * np.cos(D), -L * np.sin(A) * np.sin(D)]
    return CA + sum(axis_vec * coeff for axis_vec, coeff in zip(basis, coeffs))


@torch.no_grad()
@autocast(enabled=False)
def compute_alignment_tensors(
    mobile: torch.Tensor,
    target: torch.Tensor,
    atom_exists_mask: torch.Tensor | None = None,
    sequence_id: torch.Tensor | None = None,
):
    """
    Align two batches of structures with support for masking invalid atoms using PyTorch.

    Computes the optimal least-squares (Kabsch) superposition of ``mobile`` onto
    ``target``. The rotation acts on row vectors:
    ``centered_mobile @ rotation_matrix`` is aligned with ``centered_target``.
    The rotation is always proper (determinant +1); reflections are excluded,
    so mirror-image structures are not reported as superimposable.

    Args:
        - mobile (torch.Tensor): Batch of coordinates of structure to be superimposed
          in shape (B, N, 3) or (B, Nres, Natom, 3)
        - target (torch.Tensor): Batch of coordinates of structure that is fixed,
          same shape as ``mobile``
        - atom_exists_mask (torch.Tensor, optional): Mask for whether an atom exists,
          of shape (B, N) or (B, Nres, Natom). Non-boolean masks are converted
          with ``.bool()``.
        - sequence_id (torch.Tensor, optional): Sequence id tensor for binpacking.
          When given, inputs are unpacked with ``unbinpack`` (coordinates padded
          with NaN). If no mask is given, the mask becomes "target coordinates
          are finite".

    Returns:
        - centered_mobile (torch.Tensor): Centered mobile coordinates, masked atoms zeroed (B, N, 3)
        - centroid_mobile (torch.Tensor): Mobile centroid (B, 1, 3)
        - centered_target (torch.Tensor): Centered target coordinates, masked atoms zeroed (B, N, 3)
        - centroid_target (torch.Tensor): Target centroid (B, 1, 3)
        - rotation_matrix (torch.Tensor): Proper rotation matrix (B, 3, 3)
        - num_valid_atoms (torch.Tensor): Number of valid atoms per structure (B, 1)

    Raises:
        AssertionError: If ``mobile`` and ``target`` shapes differ.
    """
    if sequence_id is not None:
        mobile = unbinpack(mobile, sequence_id, pad_value=torch.nan)
        target = unbinpack(target, sequence_id, pad_value=torch.nan)
        if atom_exists_mask is not None:
            atom_exists_mask = unbinpack(atom_exists_mask, sequence_id, pad_value=0)
        else:
            atom_exists_mask = torch.isfinite(target).all(-1)

    # Ensure both batches have the same number of structures, atoms, and dimensions
    assert mobile.shape == target.shape, "Batch structure shapes do not match!"

    # Number of structures in the batch
    batch_size = mobile.shape[0]

    # Flatten [B, Nres, Natom, 3] to [B, Nres * Natom, 3]. ``reshape`` (rather
    # than ``view``) also accepts non-contiguous inputs.
    if mobile.dim() == 4:
        mobile = mobile.reshape(batch_size, -1, 3)
    if target.dim() == 4:
        target = target.reshape(batch_size, -1, 3)
    if atom_exists_mask is not None and atom_exists_mask.dim() == 3:
        atom_exists_mask = atom_exists_mask.reshape(batch_size, -1)

    # Number of atoms
    num_atoms = mobile.shape[1]

    # Apply masks if provided; zeroing also removes NaN padding from unbinpack.
    if atom_exists_mask is not None:
        # ``~`` and ``masked_fill`` require a boolean mask.
        atom_exists_mask = atom_exists_mask.bool()
        mobile = mobile.masked_fill(~atom_exists_mask.unsqueeze(-1), 0)
        target = target.masked_fill(~atom_exists_mask.unsqueeze(-1), 0)
    else:
        atom_exists_mask = torch.ones(
            batch_size, num_atoms, dtype=torch.bool, device=mobile.device
        )

    num_valid_atoms = atom_exists_mask.sum(dim=-1, keepdim=True)

    # Compute centroids over valid atoms only
    centroid_mobile = mobile.sum(dim=-2, keepdim=True) / num_valid_atoms.unsqueeze(-1)
    centroid_target = target.sum(dim=-2, keepdim=True) / num_valid_atoms.unsqueeze(-1)

    # Handle division by zero if all atoms are invalid in a structure
    centroid_mobile[num_valid_atoms == 0] = 0
    centroid_target[num_valid_atoms == 0] = 0

    # Center structures, keeping invalid atoms at zero so they do not contribute
    # to the covariance.
    centered_mobile = mobile - centroid_mobile
    centered_target = target - centroid_target
    centered_mobile = centered_mobile.masked_fill(~atom_exists_mask.unsqueeze(-1), 0)
    centered_target = centered_target.masked_fill(~atom_exists_mask.unsqueeze(-1), 0)

    # Cross-covariance H = mobile^T @ target for each batch
    covariance_matrix = torch.matmul(centered_mobile.transpose(1, 2), centered_target)

    # Kabsch: H = U S V^T and R = U diag(1, 1, d) V^T with d = sign(det(U V^T)).
    # Without the d correction, U V^T can be a reflection (det -1).
    u, _, vh = torch.linalg.svd(covariance_matrix)
    det = torch.linalg.det(torch.matmul(u, vh))
    last_col_sign = 1.0 - 2.0 * (det < 0).to(u.dtype)  # -1 where reflected, else +1
    u = torch.cat([u[..., :-1], u[..., -1:] * last_col_sign[..., None, None]], dim=-1)

    # Calculate rotation matrices for each batch
    rotation_matrix = torch.matmul(u, vh)

    return (
        centered_mobile,
        centroid_mobile,
        centered_target,
        centroid_target,
        rotation_matrix,
        num_valid_atoms,
    )


@torch.no_grad()
@autocast(enabled=False)
def compute_rmsd_no_alignment(
    aligned: torch.Tensor,
    target: torch.Tensor,
    num_valid_atoms: torch.Tensor,
    reduction: str = "batch",
) -> torch.Tensor:
    """
    Compute RMSD between two batches of structures without alignment.

    Args:
        - aligned (torch.Tensor): Batch of (already superimposed) coordinates in shape (B, N, 3)
        - target (torch.Tensor): Batch of coordinates of structure that is fixed in shape (B, N, 3)
        - num_valid_atoms (torch.Tensor): Number of valid atoms per structure, shape (B, 1)
          (as returned by ``compute_alignment_tensors``). Invalid atoms are expected
          to be zeroed in both inputs so they contribute no error.
        - reduction (str): One of "batch", "per_sample", "per_residue".

    Returns:
        If reduction == "batch":
            (torch.Tensor): 0-dim, average RMSD over structures that have at least one valid atom
        If reduction == "per_sample":
            (torch.Tensor): (B,)-dim, RMSD of each structure (NaN if it has no valid atoms)
        If reduction == "per_residue":
            (torch.Tensor): (B, N // 3)-dim, RMSD over each consecutive group of 3 atoms
            (e.g. backbone N, CA, C). Requires N divisible by 3 and does not use
            ``num_valid_atoms``.

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """
    if reduction not in ("per_residue", "per_sample", "batch"):
        raise ValueError(f"Unrecognized reduction: '{reduction}'")

    diff = aligned - target
    if reduction == "per_residue":
        # 9 = 3 atoms x 3 coordinates per residue.
        # NOTE(review): masked atoms (if zeroed) are still counted in this mean.
        mean_squared_error = diff.square().view(diff.size(0), -1, 9).mean(dim=-1)
    else:
        mean_squared_error = diff.square().sum(dim=(1, 2)) / (
            num_valid_atoms.squeeze(-1) * 3
        )

    rmsd = torch.sqrt(mean_squared_error)
    if reduction in ("per_sample", "per_residue"):
        return rmsd

    # reduction == "batch": structures without valid atoms (RMSD 0/0 = NaN) are
    # excluded from both the sum and the count.
    avg_rmsd = rmsd.masked_fill(num_valid_atoms.squeeze(-1) == 0, 0).sum() / (
        (num_valid_atoms > 0).sum() + 1e-8
    )
    return avg_rmsd


@torch.no_grad()
@autocast(enabled=False)
def compute_affine_and_rmsd(
    mobile: torch.Tensor,
    target: torch.Tensor,
    atom_exists_mask: torch.Tensor | None = None,
    sequence_id: torch.Tensor | None = None,
) -> Tuple[Affine3D, torch.Tensor]:
    """
    Compute RMSD between two batches of structures with support for masking invalid atoms using PyTorch.

    Args:
        - mobile (torch.Tensor): Batch of coordinates of structure to be superimposed in shape (B, N, 3)
        - target (torch.Tensor): Batch of coordinates of structure that is fixed in shape (B, N, 3)
        - atom_exists_mask (torch.Tensor, optional): Mask for whether an atom exists of shape (B, N)
        - sequence_id (torch.Tensor, optional): Sequence id tensor for binpacking.

    Returns:
        - affine (Affine3D): Transformation between mobile and target structure, built
          from translation ``-centroid_mobile @ R + centroid_target`` (B, 1, 3) and
          rotation ``R^T`` (B, 1, 3, 3), where R is the row-vector rotation from
          ``compute_alignment_tensors``.
        - avg_rmsd (torch.Tensor): 0-dim average RMSD after superposition, over
          structures that have at least one valid atom.
    """
    (
        centered_mobile,
        centroid_mobile,
        centered_target,
        centroid_target,
        rotation_matrix,
        num_valid_atoms,
    ) = compute_alignment_tensors(
        mobile=mobile,
        target=target,
        atom_exists_mask=atom_exists_mask,
        sequence_id=sequence_id,
    )

    # Apply rotation to mobile centroid
    translation = torch.matmul(-centroid_mobile, rotation_matrix) + centroid_target
    affine = Affine3D.from_tensor_pair(
        translation, rotation_matrix.unsqueeze(dim=-3).transpose(-2, -1)
    )

    # Apply transformation to centered structure to compute rmsd
    rotated_mobile = torch.matmul(centered_mobile, rotation_matrix)
    avg_rmsd = compute_rmsd_no_alignment(
        rotated_mobile,
        centered_target,
        num_valid_atoms,
        reduction="batch",
    )

    return affine, avg_rmsd