Spaces:
Running
Running
from typing import TypeVar | |
import numpy as np | |
import torch | |
from torch import Tensor | |
from esm.utils import residue_constants as RC | |
from esm.utils.structure.affine3d import Affine3D | |
ArrayOrTensor = TypeVar("ArrayOrTensor", np.ndarray, Tensor) | |
def atom3_to_backbone_frames(bb_positions: torch.Tensor) -> Affine3D:
    """Build backbone frames from stacked N, CA, C positions.

    Args:
        bb_positions: tensor whose second-to-last dimension holds exactly
            three entries, unpacked in N, CA, C order.

    Returns:
        Affine3D: the result of ``Affine3D.from_graham_schmidt(C, CA, N)``.
    """
    n_pos, ca_pos, c_pos = torch.unbind(bb_positions, dim=-2)
    # Note the argument order: the constructor receives C first and N last.
    return Affine3D.from_graham_schmidt(c_pos, ca_pos, n_pos)
def index_by_atom_name(
    atom37: ArrayOrTensor, atom_names: str | list[str], dim: int = -2
) -> ArrayOrTensor:
    """Select entries along one axis by atom name.

    Atom names are translated to integer positions through ``RC.atom_order``
    and used to index ``atom37`` along ``dim``; all other axes are kept whole.

    Args:
        atom37: numpy array or torch tensor to index (presumably laid out
            with one slot per atom type along ``dim`` -- confirm with caller).
        atom_names: a single atom name, or a list of atom names.
        dim: axis to index along; negative values count from the end.

    Returns:
        The selected entries. When ``atom_names`` is a single string the
        indexed axis is squeezed away; when it is a list the axis is kept
        with one entry per requested name.
    """
    single_name = isinstance(atom_names, str)
    names = [atom_names] if single_name else atom_names
    atom_positions = [RC.atom_order[name] for name in names]

    # Normalise a negative axis so it can be used as a list position below.
    axis = dim % atom37.ndim
    selector = [slice(None)] * atom37.ndim
    selector[axis] = atom_positions
    picked = atom37[tuple(selector)]  # type: ignore

    return picked.squeeze(axis) if single_name else picked
def get_protein_normalization_frame(coords: Tensor) -> Affine3D:
    """Compute a single frame that can be used to normalize a protein's coordinates.

    The N, CA, and C positions are averaged over all residues whose backbone
    coordinates are entirely finite, and those three averaged points are
    passed to ``atom3_to_backbone_frames`` to construct the frame.

    Args:
        coords (torch.FloatTensor): [L, 37, 3] tensor of coordinates

    Returns:
        Affine3D: frame built from the averaged N, CA, C positions
    """
    backbone = index_by_atom_name(coords, ["N", "CA", "C"], dim=-2)

    # A residue contributes only if every backbone coordinate is finite.
    residue_ok = torch.isfinite(backbone).all(dim=-1).all(dim=-1)

    # Zero out excluded residues so they add nothing to the sum.
    zeroed = backbone.masked_fill(~residue_ok[..., None, None], 0)
    n_valid = residue_ok.sum(-1)[..., None, None]

    # The small epsilon keeps the division defined when no residue is valid.
    mean_backbone = zeroed.sum(-3) / (n_valid + 1e-8)

    return atom3_to_backbone_frames(mean_backbone.float())
def apply_frame_to_coords(coords: Tensor, frame: Affine3D) -> Tensor:
    """Express coordinates in the given frame by applying the frame's inverse.

    Args:
        coords (torch.FloatTensor): [L, 37, 3] tensor of coordinates
        frame (Affine3D): Affine3D frame

    Returns:
        torch.FloatTensor: [L, 37, 3] tensor of transformed coordinates.
        Entries that were infinite in ``coords`` are set to ``torch.inf`` in
        the output. The input tensor itself is not modified.
    """
    inverse_frame = frame[..., None, None].invert()
    transformed = inverse_frame.apply(coords)

    # Keep the transformed values only where the frame's translation is
    # non-zero; elsewhere fall back to the untouched input coordinates.
    # NOTE(review): the original comment described this as a "valid rotation"
    # check, but the test is on the translation norm -- confirm intent.
    frame_is_valid = frame.trans.norm(dim=-1) > 0
    inf_mask = torch.isinf(coords)

    result = torch.where(frame_is_valid[..., None, None, None], transformed, coords)
    # Overwrite positions that were infinite in the input with torch.inf.
    result.masked_fill_(inf_mask, torch.inf)
    return result
def normalize_coordinates(coords: Tensor) -> Tensor:
    """Transform coordinates into the protein's own normalization frame.

    Args:
        coords: coordinate tensor, passed both to
            ``get_protein_normalization_frame`` and ``apply_frame_to_coords``.

    Returns:
        Tensor: ``coords`` with the computed normalization frame applied.
    """
    normalization_frame = get_protein_normalization_frame(coords)
    return apply_frame_to_coords(coords, normalization_frame)