File size: 5,177 Bytes
89c0b51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# Copyright 2024 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import traceback
from typing import Union

import torch

from protenix.utils.logger import get_logger
from protenix.utils.permutation.utils import save_permutation_error

from .heuristic import correct_symmetric_chains
from .pocket_based_permutation import permute_pred_to_optimize_pocket_aligned_rmsd

logger = get_logger(__name__)


def run(
    pred_coord: torch.Tensor,
    input_feature_dict: dict[str, Union[torch.Tensor, int, float, dict]],
    label_full_dict: dict[str, Union[torch.Tensor, int, float, dict]],
    max_num_chains: int = -1,
    permute_label: bool = True,
    permute_by_pocket: bool = False,
    error_dir: str = None,
    **kwargs,
) -> tuple[dict]:
    """
    Run chain permutation.


    Args:
        pred_coord (torch.Tensor): The predicted coordinates. Shape: [N_atoms, 3].
        input_feature_dict (dict[str, Union[torch.Tensor, int, float, dict]]): A dictionary containing input features.
        label_full_dict (dict[str, Union[torch.Tensor, int, float, dict]]): A dictionary containing full label information.
        max_num_chains (int, optional): The maximum number of chains to consider. Defaults to -1 (no limit).
        permute_label (bool, optional): Whether to permute the label. Defaults to True.
        permute_by_pocket (bool, optional): Whether to permute by pocket (for PoseBusters dataset). Defaults to False.
        error_dir (str, optional): Directory to save error data. Defaults to None.
        **kwargs: Additional keyword arguments.

    Returns:
        tuple[dict]: A tuple containing the output dictionary, log dictionary, permuted prediction indices, and permuted label indices.
    """

    if pred_coord.dim() > 2:
        assert (
            permute_label is False
        ), "Only supports prediction permutations in batch mode."

    try:

        if permute_by_pocket:
            """Optimize the chain assignment on pocket-ligand interface"""
            assert not permute_label

            if label_full_dict["pocket_mask"].dim() == 2:
                # first pocket is the `main` pocket
                pocket_mask = label_full_dict["pocket_mask"][0]
                ligand_mask = label_full_dict["interested_ligand_mask"][0]
            else:
                pocket_mask = label_full_dict["pocket_mask"]
                ligand_mask = label_full_dict["interested_ligand_mask"]

            permute_pred_indices, permuted_aligned_pred_coord, log_dict = (
                permute_pred_to_optimize_pocket_aligned_rmsd(
                    pred_coord=pred_coord,
                    true_coord=label_full_dict["coordinate"],
                    true_coord_mask=label_full_dict["coordinate_mask"],
                    true_pocket_mask=pocket_mask,
                    true_ligand_mask=ligand_mask,
                    atom_entity_id=input_feature_dict["entity_mol_id"],
                    atom_asym_id=input_feature_dict["mol_id"],
                    mol_atom_index=input_feature_dict["mol_atom_index"],
                    use_center_rmsd=kwargs.get("use_center_rmsd", False),
                )
            )
            output_dict = {"coordinate": permuted_aligned_pred_coord}
            permute_label_indices = []

        else:
            """Optimize the chain assignment on all chains"""
            output_dict, log_dict, permute_pred_indices, permute_label_indices = (
                correct_symmetric_chains(
                    pred_dict={**input_feature_dict, "coordinate": pred_coord},
                    label_full_dict=label_full_dict,
                    max_num_chains=max_num_chains,
                    permute_label=permute_label,
                    **kwargs,
                )
            )

    except Exception as e:
        error_message = f"{e}:\n{traceback.format_exc()}"
        logger.warning(error_message)
        save_permutation_error(
            data={
                "error_message": error_message,
                "pred_dict": {**input_feature_dict, "coordinate": pred_coord},
                "label_full_dict": label_full_dict,
                "max_num_chains": max_num_chains,
                "permute_label": permute_label,
                "dataset_name": input_feature_dict.get("dataset_name", None),
                "pdb_id": input_feature_dict.get("pdb_id", None),
            },
            error_dir=error_dir,
        )
        output_dict, log_dict, permute_pred_indices, permute_label_indices = (
            {},
            {},
            [],
            [],
        )

    return output_dict, log_dict, permute_pred_indices, permute_label_indices