import json
import logging
import os
import re
import time
from typing import List, Tuple
import numpy
import torch
from rdkit import Chem
from dockformerpp.model.model import AlphaFold
from dockformerpp.utils import residue_constants, protein
from dockformerpp.utils.consts import POSSIBLE_ATOM_TYPES, POSSIBLE_BOND_TYPES, POSSIBLE_CHARGES, POSSIBLE_CHIRALITIES
logging.basicConfig()
logger = logging.getLogger(__file__)
logger.setLevel(level=logging.INFO)

def count_models_to_evaluate(model_checkpoint_path):
    model_count = 0
    if model_checkpoint_path:
        model_count += len(model_checkpoint_path.split(","))
    return model_count
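
# Example (hypothetical paths): count_models_to_evaluate("m1.ckpt,m2.ckpt")
# returns 2, while an empty string or None returns 0.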

def get_model_basename(model_path):
    return os.path.splitext(
        os.path.basename(
            os.path.normpath(model_path)
        )
    )[0]
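
# Example (hypothetical path): get_model_basename("/ckpts/run3/model_1.ckpt")
# returns "model_1"; normpath also tolerates a trailing separator.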

def make_output_directory(output_dir, model_name, multiple_model_mode):
    if multiple_model_mode:
        prediction_dir = os.path.join(output_dir, "predictions", model_name)
    else:
        prediction_dir = os.path.join(output_dir, "predictions")
    os.makedirs(prediction_dir, exist_ok=True)
    return prediction_dir

# Return the most recently created .ckpt file in checkpoint_dir, or None
def get_latest_checkpoint(checkpoint_dir):
    if not os.path.exists(checkpoint_dir):
        return None
    checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith('.ckpt')]
    if not checkpoints:
        return None
    latest_checkpoint = max(checkpoints, key=lambda x: os.path.getctime(os.path.join(checkpoint_dir, x)))
    return os.path.join(checkpoint_dir, latest_checkpoint)
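
# Usage sketch (hypothetical directory; assumes checkpoints carry a .ckpt
# suffix, as above):
#
#   ckpt = get_latest_checkpoint("out/checkpoints")
#   if ckpt is None:
#       logger.info("No checkpoint found; starting from scratch")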

def load_models_from_command_line(config, model_device, model_checkpoint_path, output_dir):
    if not model_checkpoint_path:
        raise ValueError("model_checkpoint_path must be specified.")

    multiple_model_mode = count_models_to_evaluate(model_checkpoint_path) > 1
    if multiple_model_mode:
        logger.info("evaluating multiple models")

    for path in model_checkpoint_path.split(","):
        model = AlphaFold(config)
        model = model.eval()
        checkpoint_basename = get_model_basename(path)
        assert os.path.isfile(path), f"Model checkpoint not found at {path}"
        ckpt_path = path
        # Load onto CPU first, then move the model to the target device
        d = torch.load(ckpt_path, map_location="cpu")
        if "ema" in d:
            # The public weights have had this done to them already
            d = d["ema"]["params"]
        model.load_state_dict(d)
        model = model.to(model_device)
        logger.info(
            f"Loaded model parameters at {path}..."
        )
        output_directory = make_output_directory(output_dir, checkpoint_basename, multiple_model_mode)
        yield model, output_directory
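
# Usage sketch (config, batch and tag are hypothetical): the generator yields
# one (model, output_directory) pair per comma-separated checkpoint, so
# callers never hold more than one model in memory at a time:
#
#   for model, out_dir in load_models_from_command_line(
#           config, "cuda:0", "model_1.ckpt,model_2.ckpt", "results"):
#       out = run_model(model, batch, tag, out_dir)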

def parse_fasta(data):
    data = re.sub(r'>$', '', data, flags=re.M)
    lines = [
        l.replace('\n', '')
        for prot in data.split('>') for l in prot.strip().split('\n', 1)
    ][1:]
    tags, seqs = lines[::2], lines[1::2]
    tags = [re.split(r'\W| \|', t)[0] for t in tags]
    return tags, seqs
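
# Example: parse_fasta(">tagA description\nMKV\nLLS\n>tagB\nGGS\n")
# returns (["tagA", "tagB"], ["MKVLLS", "GGS"]): multi-line sequences are
# joined and only the first word of each header is kept as the tag.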

def update_timings(timing_dict, output_file=os.path.join(os.getcwd(), "timings.json")):
    """
    Write dictionary of one or more run step times to a file
    """
    if os.path.exists(output_file):
        with open(output_file, "r") as f:
            try:
                timings = json.load(f)
            except json.JSONDecodeError:
                logger.info(f"Overwriting non-standard JSON in {output_file}.")
                timings = {}
    else:
        timings = {}
    timings.update(timing_dict)
    with open(output_file, "w") as f:
        json.dump(timings, f)
    return output_file
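
# Example (hypothetical tag and path):
# update_timings({"T1": {"inference": 1.2}}, "out/timings.json") merges the
# new entry into any existing JSON on disk rather than overwriting it.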

def run_model(model, batch, tag, output_dir):
    with torch.no_grad():
        logger.info(f"Running inference for {tag}...")
        t = time.perf_counter()
        out = model(batch)
        inference_time = time.perf_counter() - t
        logger.info(f"Inference time: {inference_time}")
        update_timings({tag: {"inference": inference_time}}, os.path.join(output_dir, "timings.json"))
    return out
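
# Note: "batch" is assumed to be the featurized input dict that the
# AlphaFold forward pass expects; the per-tag wall-clock time is appended to
# <output_dir>/timings.json via update_timings above.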

def get_molecule_from_output(atoms_atype: List[int], atom_chiralities: List[int], atom_charges: List[int],
                             bonds: List[Tuple[int, int, int]], atom_positions: List[Tuple[float, float, float]]):
    mol = Chem.RWMol()
    assert len(atoms_atype) == len(atom_chiralities) == len(atom_charges) == len(atom_positions)
    for atype_idx, chirality_idx, charge_idx in zip(atoms_atype, atom_chiralities, atom_charges):
        new_atom = Chem.Atom(POSSIBLE_ATOM_TYPES[atype_idx])
        new_atom.SetChiralTag(POSSIBLE_CHIRALITIES[chirality_idx])
        new_atom.SetFormalCharge(POSSIBLE_CHARGES[charge_idx])
        mol.AddAtom(new_atom)

    # Add bonds
    for bond in bonds:
        atom1, atom2, bond_type_idx = bond
        bond_type = POSSIBLE_BOND_TYPES[bond_type_idx]
        mol.AddBond(int(atom1), int(atom2), bond_type)

    # Set atom positions; accept either a numpy array or a list of (x, y, z)
    conf = Chem.Conformer(len(atoms_atype))
    for i, pos in enumerate(numpy.asarray(atom_positions, dtype=float)):
        conf.SetAtomPosition(i, pos)
    mol.AddConformer(conf)
    return mol
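
# Usage sketch (the index lists are assumptions; real values index into the
# POSSIBLE_* vocabularies from dockformerpp.utils.consts):
#
#   mol = get_molecule_from_output(atypes, chiralities, charges, bonds, coords)
#   Chem.MolToMolFile(mol, "ligand_pred.sdf")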

def save_output_structure(aatype, residue_index, chain_index, plddt, final_atom_protein_positions, final_atom_mask,
                          output_path):
    # Use per-residue pLDDT as the B-factor column, broadcast to every atom
    plddt_b_factors = numpy.repeat(
        plddt[..., None], residue_constants.atom_type_num, axis=-1
    )

    unrelaxed_protein = protein.from_prediction(
        aatype=aatype,
        residue_index=residue_index,
        chain_index=chain_index,
        atom_mask=final_atom_mask,
        atom_positions=final_atom_protein_positions,
        b_factors=plddt_b_factors,
        remove_leading_feature_dimension=False,
    )

    with open(output_path, 'w') as fp:
        fp.write(protein.to_pdb(unrelaxed_protein))
    logger.info(f"Output written to {output_path}")
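
# Usage sketch (variable names are hypothetical model outputs):
#
#   save_output_structure(aatype, residue_index, chain_index, plddt,
#                         positions, atom_mask,
#                         os.path.join(out_dir, f"{tag}_unrelaxed.pdb"))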