# Copyright 2024 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
from protenix.model.utils import batched_gather


def expressCoordinatesInFrame(
coordinate: torch.Tensor, frames: torch.Tensor, eps: float = 1e-8
) -> torch.Tensor:
"""Algorithm 29 Express coordinate in frame
Args:
coordinate (torch.Tensor): the input coordinate
[..., N_atom, 3]
frames (torch.Tensor): the input frames
[..., N_frame, 3, 3]
eps (float): Small epsilon value
Returns:
torch.Tensor: the transformed coordinate projected onto frame basis
[..., N_frame, N_atom, 3]
"""
    # Extract the three frame atoms (a, b, c); b serves as the frame origin
    a, b, c = torch.unbind(frames, dim=-2)  # a, b, c shape: [..., N_frame, 3]
    # Unit vectors pointing from the origin b towards a and c
    w1 = F.normalize(a - b, dim=-1, eps=eps)
    w2 = F.normalize(c - b, dim=-1, eps=eps)
    # Build an orthonormal basis: e1 bisects the angle at b, e2 is the
    # in-plane direction orthogonal to e1 (since |w1| = |w2| = 1), and
    # e3 = e1 x e2 completes the right-handed frame
    e1 = F.normalize(w1 + w2, dim=-1, eps=eps)
    e2 = F.normalize(w2 - w1, dim=-1, eps=eps)
    e3 = torch.cross(e1, e2, dim=-1)  # [..., N_frame, 3]
    # Displacement of each atom from each frame origin b
    d = coordinate[..., None, :, :] - b[..., None, :]  # [..., N_frame, N_atom, 3]
    # Project the displacement onto the frame basis via dot products
x_transformed = torch.cat(
[
torch.sum(d * e1[..., None, :], dim=-1, keepdim=True),
torch.sum(d * e2[..., None, :], dim=-1, keepdim=True),
torch.sum(d * e3[..., None, :], dim=-1, keepdim=True),
],
dim=-1,
) # [..., N_frame, N_atom, 3]
return x_transformed
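
# A minimal usage sketch (illustrative shapes and random data, not library
# fixtures): express a batch of atom coordinates in a batch of frames.
#
#     coords = torch.randn(2, 16, 3)    # [batch, N_atom, 3]
#     frames = torch.randn(2, 4, 3, 3)  # [batch, N_frame, 3 atoms, xyz]
#     x = expressCoordinatesInFrame(coords, frames)
#     x.shape                           # torch.Size([2, 4, 16, 3])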


def gather_frame_atom_by_indices(
    coordinate: torch.Tensor, frame_atom_index: torch.Tensor, dim: int = -2
) -> torch.Tensor:
    """Construct frames by gathering the three frame atoms from the coordinates.

    Args:
        coordinate (torch.Tensor): the input coordinates
            [..., N_atom, 3]
        frame_atom_index (torch.Tensor): indices of the three atoms in each frame
            [..., N_frame, 3] or [N_frame, 3]
        dim (int): the dimension along which to select the frame atoms

    Returns:
        torch.Tensor: the constructed frames
            [..., N_frame, 3 (three atoms), 3 (xyz coordinates)]
    """
if len(frame_atom_index.shape) == 2:
        # the simple case: the frame indices are shared across all batch dims
x1 = torch.index_select(
coordinate, dim=dim, index=frame_atom_index[:, 0]
) # [..., N_frame, 3]
x2 = torch.index_select(
coordinate, dim=dim, index=frame_atom_index[:, 1]
) # [..., N_frame, 3]
x3 = torch.index_select(
coordinate, dim=dim, index=frame_atom_index[:, 2]
) # [..., N_frame, 3]
        return torch.stack([x1, x2, x3], dim=dim)  # [..., N_frame, 3, 3]
    else:
        # per-sample frame indices: gather along `dim`, matching leading batch dims
        assert (
            frame_atom_index.shape[:dim] == coordinate.shape[:dim]
        ), "batch size dims should match"
x1 = batched_gather(
data=coordinate,
inds=frame_atom_index[..., 0],
dim=dim,
no_batch_dims=len(coordinate.shape[:dim]),
) # [..., N_frame, 3]
x2 = batched_gather(
data=coordinate,
inds=frame_atom_index[..., 1],
dim=dim,
no_batch_dims=len(coordinate.shape[:dim]),
) # [..., N_frame, 3]
x3 = batched_gather(
data=coordinate,
inds=frame_atom_index[..., 2],
dim=dim,
no_batch_dims=len(coordinate.shape[:dim]),
) # [..., N_frame, 3]
return torch.stack([x1, x2, x3], dim=dim)
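

if __name__ == "__main__":
    # Minimal smoke test (a sketch with made-up shapes and random data, not a
    # library fixture): build frames from shared per-frame atom indices, then
    # express every coordinate in every frame.
    coords = torch.randn(2, 16, 3)  # [batch, N_atom, 3]
    frame_atom_index = torch.randint(0, 16, (4, 3))  # [N_frame, 3]
    frames = gather_frame_atom_by_indices(coords, frame_atom_index)
    x = expressCoordinatesInFrame(coords, frames)
    print(x.shape)  # torch.Size([2, 4, 16, 3])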