File size: 2,306 Bytes
89c0b51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# Copyright 2024 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from protenix.metrics.rmsd import align_pred_to_true


def get_optimal_transform(
    src_atoms: torch.Tensor,
    tgt_atoms: torch.Tensor,
    mask: torch.Tensor = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Obtain the rigid transformation (rotation + translation) that optimally
    aligns src_atoms onto tgt_atoms.

    The fit is delegated to ``align_pred_to_true``: ``src_atoms`` is passed as
    the moving pose (``pred_pose``) and ``tgt_atoms`` as the fixed reference
    (``true_pose``), with reflections disallowed. Inputs are upcast to float32
    and CUDA autocast is disabled for the call, because the SVD-based
    alignment does not support BF16.

    Args:
        src_atoms: atom positions to be moved (e.g. ground-truth centre atom
            positions), shape: [N, 3]
        tgt_atoms: reference atom positions to align onto (e.g. predicted
            centre atom positions), shape: [N, 3]
        mask: optional vector whose values are interpreted as booleans,
            shape: [N]. Only atoms where the mask is True take part in the fit.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``(rot, trans)`` computed in
            float32. ``rot`` is the rotation matrix that best aligns src_atoms
            to tgt_atoms; ``trans`` records how the atoms should be shifted
            after applying ``rot``. (Presumably consumed together with
            ``apply_transform``, i.e. ``pose @ rot.T + trans`` — confirm
            against ``align_pred_to_true``'s convention.)
    """
    assert src_atoms.shape == tgt_atoms.shape, (src_atoms.shape, tgt_atoms.shape)
    assert src_atoms.shape[-1] == 3
    if mask is not None:
        mask = mask.bool()
        assert mask.dim() == 1, "mask should have the shape of [N]"
        assert mask.shape[-1] == src_atoms.shape[-2]
        # Keep only the atoms selected by the mask for the fit.
        src_atoms = src_atoms[mask, :]
        tgt_atoms = tgt_atoms[mask, :]

    # SVD alignment does not support BF16: disable CUDA autocast and upcast
    # explicitly to float32. `torch.autocast(device_type="cuda", ...)` is the
    # non-deprecated equivalent of `torch.cuda.amp.autocast(...)`.
    with torch.autocast(device_type="cuda", enabled=False):
        _, rot, trans = align_pred_to_true(
            pred_pose=src_atoms.to(dtype=torch.float32),
            true_pose=tgt_atoms.to(dtype=torch.float32),
            allowing_reflection=False,
        )

    return rot, trans


def apply_transform(pose, rot, trans):
    """
    Apply a rigid transformation to a set of coordinates.

    The rotation is applied first and the translation afterwards, i.e. the
    result is ``pose @ rot^T + trans`` computed with ``torch.matmul`` and
    broadcasting addition.

    Args:
        pose: coordinates to transform; xyz presumably lives in the last
            dimension (e.g. shape [..., N, 3] — confirm against callers).
        rot: rotation matrix, presumably shape [..., 3, 3].
        trans: translation added after the rotation; must broadcast against
            the rotated coordinates.

    Returns:
        torch.Tensor: the transformed coordinates.
    """
    # Row-vector convention: x' = x @ R^T is the same as R @ x per point.
    rot_t = rot.transpose(-2, -1)
    rotated = torch.matmul(pose, rot_t)
    return rotated + trans


def num_unique_matches(match_list: list[dict]) -> int:
    """
    Count the distinct dictionaries in ``match_list``.

    Two matches are considered the same when they hold exactly the same
    key/value pairs, regardless of insertion order. Each dict is reduced to a
    ``frozenset`` of its items, which is order-independent and hashable, so
    the keys do not need to be mutually comparable (sorting the items would
    raise ``TypeError`` for e.g. mixed ``int``/``str`` keys).

    Args:
        match_list: list of dicts; all keys and values must be hashable.

    Returns:
        int: number of distinct dicts in ``match_list`` (0 for an empty list).
    """
    return len({frozenset(match.items()) for match in match_list})