File size: 5,939 Bytes
0fdcb79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import json
import logging
import os
import re
import time
from typing import List, Tuple

import numpy
import torch
from rdkit import Chem

from dockformerpp.model.model import AlphaFold
from dockformerpp.utils import residue_constants, protein
from dockformerpp.utils.consts import POSSIBLE_ATOM_TYPES, POSSIBLE_BOND_TYPES, POSSIBLE_CHARGES, POSSIBLE_CHIRALITIES

logging.basicConfig()
logger = logging.getLogger(__file__)
logger.setLevel(level=logging.INFO)


def count_models_to_evaluate(model_checkpoint_path):
    """Return how many checkpoints are named in a comma-separated path string.

    A falsy value (``None`` or ``""``) counts as zero models; otherwise the
    result is the number of comma-separated fields, empty fields included.
    """
    return len(model_checkpoint_path.split(",")) if model_checkpoint_path else 0


def get_model_basename(model_path):
    """Return the final path component of *model_path* without its extension.

    The path is normalised first, so a trailing separator is ignored
    (``"runs/model.ckpt"`` -> ``"model"``, ``"runs/dir/"`` -> ``"dir"``).
    Only the last extension is removed.
    """
    normalized = os.path.normpath(model_path)
    stem, _extension = os.path.splitext(os.path.basename(normalized))
    return stem


def make_output_directory(output_dir, model_name, multiple_model_mode):
    """Create (if needed) and return the directory predictions are written to.

    Predictions go under ``<output_dir>/predictions``. When
    ``multiple_model_mode`` is true, a ``<model_name>`` subfolder is added so
    each evaluated model writes to its own directory.
    """
    path_parts = [output_dir, "predictions"]
    if multiple_model_mode:
        path_parts.append(model_name)
    prediction_dir = os.path.join(*path_parts)
    os.makedirs(prediction_dir, exist_ok=True)
    return prediction_dir


def get_latest_checkpoint(checkpoint_dir):
    """Return the path of the newest ``*.ckpt`` entry in *checkpoint_dir*.

    "Newest" means the largest ``os.path.getctime`` value; on ties the entry
    listed first by ``os.listdir`` wins. Returns ``None`` when the directory
    does not exist or contains no ``.ckpt`` entries.

    NOTE(review): on POSIX, ctime is the inode-change time, not the creation
    time, so renames/chmods also count as "newer" — confirm this is intended.
    """
    if not os.path.exists(checkpoint_dir):
        return None
    candidates = [
        os.path.join(checkpoint_dir, name)
        for name in os.listdir(checkpoint_dir)
        if name.endswith('.ckpt')
    ]
    if not candidates:
        return None
    return max(candidates, key=os.path.getctime)


def load_models_from_command_line(config, model_device, model_checkpoint_path, output_dir):
    """Yield ``(model, output_directory)`` for each checkpoint in a comma-separated list.

    Args:
        config: Model configuration passed to ``AlphaFold(config)``.
        model_device: Device each loaded model is moved to (passed to ``Module.to``).
        model_checkpoint_path: Comma-separated checkpoint file paths.
        output_dir: Root output directory; a prediction directory for each
            model is created under it via ``make_output_directory``.

    Yields:
        One ``(model, output_directory)`` pair per checkpoint, with the model
        in eval mode and moved to ``model_device``.

    Raises:
        ValueError: if ``model_checkpoint_path`` is empty or ``None`` (raised
            on first iteration, since this is a generator).
        FileNotFoundError: if a listed checkpoint file does not exist.
    """
    if not model_checkpoint_path:
        raise ValueError("model_checkpoint_path must be specified.")

    multiple_model_mode = count_models_to_evaluate(model_checkpoint_path) > 1
    if multiple_model_mode:
        logger.info("evaluating multiple models")

    for path in model_checkpoint_path.split(","):
        # Explicit check (not `assert`, which -O strips), done before building
        # the model so a bad path fails fast.
        if not os.path.isfile(path):
            raise FileNotFoundError(f"Model checkpoint not found at {path}")
        checkpoint_basename = get_model_basename(path)

        model = AlphaFold(config).eval()
        # Deserialize onto the CPU so checkpoints saved from a GPU also load on
        # CPU-only hosts; the whole model is moved to model_device afterwards.
        state_dict = torch.load(path, map_location="cpu")
        if "ema" in state_dict:
            # The public weights have had this done to them already
            state_dict = state_dict["ema"]["params"]
        model.load_state_dict(state_dict)
        model = model.to(model_device)
        logger.info(f"Loaded Model parameters at {path}...")

        output_directory = make_output_directory(output_dir, checkpoint_basename, multiple_model_mode)
        yield model, output_directory


def parse_fasta(data):
    """Parse FASTA text into parallel lists of tags and sequences.

    Each record's header is cut at its first non-word character (so
    ``">1ABC_A|chain A"`` gives the tag ``"1ABC_A"``). The record's sequence
    lines are concatenated with newlines removed.

    Args:
        data: FASTA-formatted text. It is assumed to start with a ``>`` header
            and to have at least one sequence line per record; otherwise tags
            and sequences can fall out of alignment.

    Returns:
        ``(tags, seqs)``: two lists in file order.
    """
    # Remove any '>' that ends a line (e.g. an empty header line).
    data = re.sub(r'>$', '', data, flags=re.M)
    lines = [
        line.replace('\n', '')
        for record in data.split('>')
        for line in record.strip().split('\n', 1)
    ][1:]  # drop the empty entry produced by the text before the first '>'
    tags, seqs = lines[::2], lines[1::2]

    # Raw string: '\W' in a plain literal is an invalid escape sequence
    # (DeprecationWarning, and SyntaxWarning from Python 3.12).
    tags = [re.split(r'\W| \|', tag)[0] for tag in tags]

    return tags, seqs


def update_timings(timing_dict, output_file=None):
    """Merge one or more run-step timings into a JSON file and return its path.

    Args:
        timing_dict: Mapping merged into the file's top-level object; its keys
            overwrite existing ones.
        output_file: JSON file to update. Defaults to ``timings.json`` in the
            current working directory, resolved at call time.

    Returns:
        The path of the file that was written.
    """
    if output_file is None:
        # Resolved per call: a default computed in the signature would freeze
        # the working directory at import time.
        output_file = os.path.join(os.getcwd(), "timings.json")

    timings = {}
    if os.path.exists(output_file):
        with open(output_file, "r") as f:
            try:
                loaded = json.load(f)
            except json.JSONDecodeError:
                loaded = None
        if isinstance(loaded, dict):
            timings = loaded
        else:
            # Unparseable JSON, or JSON that is not an object, cannot be merged
            # into, so it is replaced.
            logger.info(f"Overwriting non-standard JSON in {output_file}.")

    timings.update(timing_dict)
    with open(output_file, "w") as f:
        json.dump(timings, f)
    return output_file


def run_model(model, batch, tag, output_dir):
    """Run one forward pass with autograd disabled and record its wall-clock time.

    The elapsed time is logged and merged into ``<output_dir>/timings.json``
    as ``{tag: {"inference": seconds}}`` via ``update_timings``.

    NOTE(review): the timer does not call ``torch.cuda.synchronize()``. If the
    model runs asynchronously on a GPU, the measured time may not include all
    queued work — confirm.

    Returns:
        Whatever ``model(batch)`` returns.
    """
    with torch.no_grad():
        logger.info(f"Running inference for {tag}...")
        started_at = time.perf_counter()
        prediction = model(batch)
        inference_time = time.perf_counter() - started_at
        logger.info(f"Inference time: {inference_time}")
        timing_record = {tag: {"inference": inference_time}}
        update_timings(timing_record, os.path.join(output_dir, "timings.json"))
    return prediction


def get_molecule_from_output(atoms_atype: List[int], atom_chiralities: List[int], atom_charges: List[int],
                             bonds: List[Tuple[int, int, int]], atom_positions: List[Tuple[float, float, float]]):
    """Build an RDKit molecule with a single conformer from per-atom class indices.

    Args:
        atoms_atype: Per-atom indices into ``POSSIBLE_ATOM_TYPES``.
        atom_chiralities: Per-atom indices into ``POSSIBLE_CHIRALITIES``.
        atom_charges: Per-atom indices into ``POSSIBLE_CHARGES``.
        bonds: ``(atom1, atom2, bond_type_idx)`` triples; ``bond_type_idx``
            indexes ``POSSIBLE_BOND_TYPES``.
        atom_positions: One ``(x, y, z)`` per atom, given either as a sequence
            of tuples (as annotated) or as a numpy array.

    Returns:
        A ``Chem.RWMol`` with atoms, bonds and one conformer. No sanitization
        is performed here.

    Raises:
        ValueError: if the per-atom inputs do not all have the same length.
    """
    # Explicit check rather than `assert`: under -O the assert vanished, and
    # zip() below would then silently truncate mismatched inputs.
    if not (len(atoms_atype) == len(atom_chiralities) == len(atom_charges) == len(atom_positions)):
        raise ValueError(
            "atoms_atype, atom_chiralities, atom_charges and atom_positions must have the same length"
        )

    mol = Chem.RWMol()
    for atype_idx, chirality_idx, charge_idx in zip(atoms_atype, atom_chiralities, atom_charges):
        new_atom = Chem.Atom(POSSIBLE_ATOM_TYPES[atype_idx])
        new_atom.SetChiralTag(POSSIBLE_CHIRALITIES[chirality_idx])
        new_atom.SetFormalCharge(POSSIBLE_CHARGES[charge_idx])
        mol.AddAtom(new_atom)

    for atom1, atom2, bond_type_idx in bonds:
        # Presumably the indices can arrive as numpy/torch scalars; RDKit's
        # AddBond wants built-in ints.
        mol.AddBond(int(atom1), int(atom2), POSSIBLE_BOND_TYPES[bond_type_idx])

    # asarray accepts both the annotated list-of-tuples form and numpy arrays;
    # calling `.astype` directly only worked for arrays.
    positions = numpy.asarray(atom_positions, dtype=float)
    conf = Chem.Conformer(len(atoms_atype))
    for i, pos in enumerate(positions):
        conf.SetAtomPosition(i, pos)
    mol.AddConformer(conf)
    return mol


def save_output_structure(aatype, residue_index, chain_index, plddt, final_atom_protein_positions, final_atom_mask,
                          output_path):
    """Write a predicted protein structure to *output_path* as PDB text.

    ``plddt`` (presumably one confidence value per residue — confirm) gets a
    new trailing axis. Its values are repeated ``residue_constants.atom_type_num``
    times along that axis and used as per-atom B-factors. The structure is
    built with ``protein.from_prediction`` and serialized with
    ``protein.to_pdb``. A confirmation line is printed to stdout.
    """
    per_atom_plddt = numpy.repeat(plddt[..., None], residue_constants.atom_type_num, axis=-1)

    predicted_protein = protein.from_prediction(
        aatype=aatype,
        residue_index=residue_index,
        chain_index=chain_index,
        atom_mask=final_atom_mask,
        atom_positions=final_atom_protein_positions,
        b_factors=per_atom_plddt,
        remove_leading_feature_dimension=False,
    )

    with open(output_path, 'w') as handle:
        handle.write(protein.to_pdb(predicted_protein))

    print("Output written to", output_path)