|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
import torch.nn.functional as F |
|
|
|
from protenix.model.utils import batched_gather |
|
|
|
|
|
def expressCoordinatesInFrame(
    coordinate: torch.Tensor, frames: torch.Tensor, eps: float = 1e-8
) -> torch.Tensor:
    """Algorithm 29: express coordinates in the local basis of each frame.

    Every frame is given by three points (a, b, c). Point ``b`` is the
    origin; an orthogonal basis (e1, e2, e3) is built from the unit
    directions b->a and b->c, and every coordinate is projected onto it.

    Args:
        coordinate (torch.Tensor): the coordinates to transform
            [..., N_atom, 3]
        frames (torch.Tensor): the three points defining each frame
            [..., N_frame, 3, 3]
        eps (float): epsilon forwarded to ``F.normalize`` to guard the
            division when a vector has (near) zero length

    Returns:
        torch.Tensor: coordinates expressed in each frame's basis
            [..., N_frame, N_atom, 3]
    """
    # Split the "three points" axis; the middle point is the frame origin.
    point_a, origin, point_c = frames.unbind(dim=-2)

    # Unit directions from the origin towards the two outer points.
    dir_a = F.normalize(point_a - origin, dim=-1, eps=eps)
    dir_c = F.normalize(point_c - origin, dim=-1, eps=eps)

    # e1 follows the bisector of the two directions, e2 their difference,
    # and e3 completes the basis via the cross product.
    axis_1 = F.normalize(dir_a + dir_c, dim=-1, eps=eps)
    axis_2 = F.normalize(dir_c - dir_a, dim=-1, eps=eps)
    axis_3 = torch.cross(axis_1, axis_2, dim=-1)

    # Offset of every atom from every frame origin: [..., N_frame, N_atom, 3].
    offset = coordinate.unsqueeze(-3) - origin.unsqueeze(-2)

    # Dot product of the offsets with each basis vector, one component per axis.
    components = [
        (offset * axis.unsqueeze(-2)).sum(dim=-1)
        for axis in (axis_1, axis_2, axis_3)
    ]
    return torch.stack(components, dim=-1)
|
|
|
|
|
def gather_frame_atom_by_indices(
    coordinate: torch.Tensor, frame_atom_index: torch.Tensor, dim: int = -2
) -> torch.Tensor:
    """Build frames by picking three atoms per frame out of ``coordinate``.

    Args:
        coordinate (torch.Tensor): the input coordinate
            [..., N_atom, 3]
        frame_atom_index (torch.Tensor): indices of the three atoms of each frame
            [..., N_frame, 3] or [N_frame, 3]
        dim (int): dimension of ``coordinate`` along which atoms are selected;
            the three gathered atoms are also stacked along this dimension

    Returns:
        torch.Tensor: the constructed frames
            [..., N_frame, 3[three atom], 3[three coordinate]]
    """
    if frame_atom_index.ndim == 2:
        # A single [N_frame, 3] index table is shared by all leading
        # dimensions of `coordinate`, so a plain index_select is enough.
        frame_atoms = [
            torch.index_select(coordinate, dim=dim, index=frame_atom_index[:, k])
            for k in range(3)
        ]
        return torch.stack(frame_atoms, dim=dim)

    # Otherwise the index tensor carries its own leading dimensions, which
    # must line up with those of `coordinate`.
    assert (
        frame_atom_index.shape[:dim] == coordinate.shape[:dim]
    ), "batch size dims should match"

    n_batch_dims = len(coordinate.shape[:dim])
    frame_atoms = [
        batched_gather(
            data=coordinate,
            inds=frame_atom_index[..., k],
            dim=dim,
            no_batch_dims=n_batch_dims,
        )
        for k in range(3)
    ]
    return torch.stack(frame_atoms, dim=dim)
|
|