Spaces:
Running
Running
from __future__ import annotations | |
from dataclasses import replace | |
from typing import TYPE_CHECKING | |
import numpy as np | |
import torch | |
from esm.utils.structure.protein_structure import ( | |
compute_affine_and_rmsd, | |
) | |
if TYPE_CHECKING: | |
from esm.utils.structure.protein_chain import ProteinChain | |
class Aligner:
    """Compute an alignment between two protein chains and re-apply it.

    The affine transform and the RMSD are obtained once, at construction
    time, from ``compute_affine_and_rmsd`` using only atoms that are flagged
    as present in both chains. The stored transform can afterwards be applied
    to any chain through :meth:`apply`.
    """

    def __init__(
        self,
        mobile: ProteinChain,
        target: ProteinChain,
        only_use_backbone: bool = False,
        use_reflection: bool = False,
    ):
        """
        Aligns a mobile protein chain against a target protein chain.

        Args:
            mobile (ProteinChain): Protein chain to be aligned.
            target (ProteinChain): Protein chain target.
            only_use_backbone (bool): Whether to only use backbone atoms.
            use_reflection (bool): Whether to align to target reflection.

        Raises:
            AssertionError: If ``len(mobile) != len(target)``.
        """
        # Proteins must have the same number of residues. Raised explicitly
        # (instead of a bare ``assert``) so the check survives ``python -O``
        # while keeping the exception type existing callers may rely on.
        if len(mobile) != len(target):
            raise AssertionError(
                "mobile and target must have the same number of residues "
                f"(got {len(mobile)} and {len(target)})"
            )

        # Determine overlapping atoms: only atoms present in both chains count.
        joint_atom37_mask = mobile.atom37_mask.astype(bool) & target.atom37_mask.astype(
            bool
        )

        # Backbone atoms are first sites in atom37 representation
        if only_use_backbone:
            joint_atom37_mask[:, 3:] = False

        # Extract atom positions and convert to batched (leading dim 1)
        # double-precision tensors.
        mobile_atom_tensor = (
            torch.from_numpy(mobile.atom37_positions).type(torch.double).unsqueeze(0)
        )
        target_atom_tensor = (
            torch.from_numpy(target.atom37_positions).type(torch.double).unsqueeze(0)
        )
        joint_atom37_mask = (
            torch.from_numpy(joint_atom37_mask).type(torch.bool).unsqueeze(0)
        )

        # If using reflection, flip the target by negating every coordinate.
        if use_reflection:
            target_atom_tensor = -target_atom_tensor

        # Compute alignment and rmsd
        affine3D, rmsd = compute_affine_and_rmsd(
            mobile_atom_tensor, target_atom_tensor, atom_exists_mask=joint_atom37_mask
        )
        self._affine3D = affine3D
        self._rmsd = rmsd.item()

    def rmsd(self) -> float:
        """Return the RMSD value computed when the aligner was constructed."""
        return self._rmsd

    def apply(self, mobile: ProteinChain) -> ProteinChain:
        """Apply alignment to a protein chain.

        Args:
            mobile (ProteinChain): Chain whose atoms should be transformed.

        Returns:
            ProteinChain: Copy of ``mobile`` (via ``dataclasses.replace``)
            whose ``atom37_positions`` hold the transformed coordinates for
            atoms selected by ``atom37_mask`` and NaN everywhere else.
        """
        # Cast the mask to bool exactly as ``__init__`` does: indexing with a
        # 0/1 integer array would trigger integer fancy indexing rather than
        # boolean masking and select the wrong coordinates.
        atom_mask = mobile.atom37_mask.astype(bool)

        # Extract present atom positions and convert to a batched tensor
        mobile_atom_tensor = (
            torch.from_numpy(mobile.atom37_positions[atom_mask])
            .type(torch.float32)
            .unsqueeze(0)
        )

        # Transform atom arrays
        aligned_atom_tensor = self._affine3D.apply(mobile_atom_tensor).squeeze(0)

        # Rebuild atom37 positions; atoms outside the mask stay NaN.
        aligned_atom37_positions = np.full_like(mobile.atom37_positions, np.nan)
        aligned_atom37_positions[atom_mask] = aligned_atom_tensor.numpy()

        return replace(mobile, atom37_positions=aligned_atom37_positions)