# Copyright 2024 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch

from protenix.metrics.rmsd import align_pred_to_true


def get_optimal_transform(
    src_atoms: torch.Tensor,
    tgt_atoms: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
A function that obtain the transformation that optimally align
src_atoms to tgt_atoms.
Args:
src_atoms: ground-truth centre atom positions, shape: [N, 3]
tgt_atoms: predicted centre atom positions, shape: [N, 3]
mask: a vector of boolean values, shape: [N]
Returns:
tuple[torch.Tensor]: A rotation matrix that records the optimal rotation
that will best align src_atoms to tgt_atoms.
A tanslation matrix records how the atoms should be shifted after applying r.
"""
    assert src_atoms.shape == tgt_atoms.shape, (src_atoms.shape, tgt_atoms.shape)
    assert src_atoms.shape[-1] == 3
    if mask is not None:
        mask = mask.bool()
        assert mask.dim() == 1, "mask should have the shape of [N]"
        assert mask.shape[-1] == src_atoms.shape[-2]
        # Restrict the alignment to the masked (selected) atoms.
        src_atoms = src_atoms[mask, :]
        tgt_atoms = tgt_atoms[mask, :]
    with torch.cuda.amp.autocast(enabled=False):
        # The SVD-based alignment does not support BF16, so run it in float32.
        _, rot, trans = align_pred_to_true(
            pred_pose=src_atoms.to(dtype=torch.float32),
            true_pose=tgt_atoms.to(dtype=torch.float32),
            allowing_reflection=False,
        )
    return rot, trans
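
# Note: with allowing_reflection=False, align_pred_to_true appears to perform a
# Kabsch-style SVD superposition, i.e. it finds the proper rotation R and
# translation t minimising || src_atoms @ R^T + t - tgt_atoms ||^2; this
# description of its internals is an assumption based on the SVD comment above.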


def apply_transform(pose, rot, trans):
    # Rotate (pose @ rot^T) and then translate by trans.
    return torch.matmul(pose, rot.transpose(-1, -2)) + trans
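
# Illustrative usage (hypothetical tensors; shapes chosen for demonstration):
#
#     src = torch.randn(10, 3)                 # coordinates to be aligned
#     tgt = torch.randn(10, 3)                 # reference coordinates
#     mask = torch.ones(10, dtype=torch.bool)  # keep every atom
#     rot, trans = get_optimal_transform(src, tgt, mask)
#     src_aligned = apply_transform(src, rot, trans)  # src superposed onto tgt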


def num_unique_matches(match_list: list[dict]) -> int:
    # Count distinct match dicts by comparing their sorted (key, value) pairs.
    return len({tuple(sorted(match.items())) for match in match_list})
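
# Illustrative example (assumed data): dicts with the same items compare equal
# once sorted, so the first two matches below are counted once.
#
#     matches = [{"i": 1, "j": 2}, {"j": 2, "i": 1}, {"i": 3, "j": 4}]
#     num_unique_matches(matches)  # -> 2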