DockFormerPP / dockformerpp/utils/script_utils.py
import json
import logging
import os
import re
import time
from typing import List, Tuple

import numpy
import torch
from rdkit import Chem

from dockformerpp.model.model import AlphaFold
from dockformerpp.utils import residue_constants, protein
from dockformerpp.utils.consts import POSSIBLE_ATOM_TYPES, POSSIBLE_BOND_TYPES, POSSIBLE_CHARGES, POSSIBLE_CHIRALITIES

logging.basicConfig()
logger = logging.getLogger(__file__)
logger.setLevel(level=logging.INFO)

def count_models_to_evaluate(model_checkpoint_path):
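    """Return the number of checkpoints listed in the comma-separated model_checkpoint_path (0 if empty)."""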
    model_count = 0
    if model_checkpoint_path:
        model_count += len(model_checkpoint_path.split(","))
    return model_count

def get_model_basename(model_path):
    return os.path.splitext(
        os.path.basename(
            os.path.normpath(model_path)
        )
    )[0]

def make_output_directory(output_dir, model_name, multiple_model_mode):
    if multiple_model_mode:
        prediction_dir = os.path.join(output_dir, "predictions", model_name)
    else:
        prediction_dir = os.path.join(output_dir, "predictions")
    os.makedirs(prediction_dir, exist_ok=True)
    return prediction_dir

# Function to get the latest checkpoint
def get_latest_checkpoint(checkpoint_dir):
    if not os.path.exists(checkpoint_dir):
        return None
    checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith('.ckpt')]
    if not checkpoints:
        return None
    latest_checkpoint = max(checkpoints, key=lambda x: os.path.getctime(os.path.join(checkpoint_dir, x)))
    return os.path.join(checkpoint_dir, latest_checkpoint)

def load_models_from_command_line(config, model_device, model_checkpoint_path, output_dir):
    # Create the output directory
    multiple_model_mode = count_models_to_evaluate(model_checkpoint_path) > 1
    if multiple_model_mode:
        logger.info(f"evaluating multiple models")

    if model_checkpoint_path:
        for path in model_checkpoint_path.split(","):
            model = AlphaFold(config)
            model = model.eval()
            checkpoint_basename = get_model_basename(path)

            assert os.path.isfile(path), f"Model checkpoint not found at {path}"
            ckpt_path = path
            d = torch.load(ckpt_path)

            if "ema" in d:
                # The public weights have had this done to them already
                d = d["ema"]["params"]
            model.load_state_dict(d)

            model = model.to(model_device)
            logger.info(
                f"Loaded Model parameters at {path}..."
            )

            output_directory = make_output_directory(output_dir, checkpoint_basename, multiple_model_mode)
            yield model, output_directory

    if not model_checkpoint_path:
        raise ValueError("model_checkpoint_path must be specified.")

def parse_fasta(data):
    data = re.sub('>$', '', data, flags=re.M)
    lines = [
        l.replace('\n', '')
        for prot in data.split('>') for l in prot.strip().split('\n', 1)
    ][1:]
    tags, seqs = lines[::2], lines[1::2]

    tags = [re.split(r'\W| \|', t)[0] for t in tags]

    return tags, seqs

def update_timings(timing_dict, output_file=os.path.join(os.getcwd(), "timings.json")):
"""
Write dictionary of one or more run step times to a file
"""
if os.path.exists(output_file):
with open(output_file, "r") as f:
try:
timings = json.load(f)
except json.JSONDecodeError:
logger.info(f"Overwriting non-standard JSON in {output_file}.")
timings = {}
else:
timings = {}
timings.update(timing_dict)
with open(output_file, "w") as f:
json.dump(timings, f)
return output_file
def run_model(model, batch, tag, output_dir):
    with torch.no_grad():
        logger.info(f"Running inference for {tag}...")
        t = time.perf_counter()
        out = model(batch)
        inference_time = time.perf_counter() - t
        logger.info(f"Inference time: {inference_time}")
        update_timings({tag: {"inference": inference_time}}, os.path.join(output_dir, "timings.json"))

    return out

def get_molecule_from_output(atoms_atype: List[int], atom_chiralities: List[int], atom_charges: List[int],
                             bonds: List[Tuple[int, int, int]], atom_positions: List[Tuple[float, float, float]]):
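    """Build an RDKit RWMol from predicted atom type/chirality/charge indices, bond triples, and 3D coordinates."""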
    mol = Chem.RWMol()
    assert len(atoms_atype) == len(atom_chiralities) == len(atom_charges) == len(atom_positions)
    for atype_idx, chirality_idx, charge_idx in zip(atoms_atype, atom_chiralities, atom_charges):
        new_atom = Chem.Atom(POSSIBLE_ATOM_TYPES[atype_idx])
        new_atom.SetChiralTag(POSSIBLE_CHIRALITIES[chirality_idx])
        new_atom.SetFormalCharge(POSSIBLE_CHARGES[charge_idx])
        mol.AddAtom(new_atom)

    # Add bonds
    for bond in bonds:
        atom1, atom2, bond_type_idx = bond
        bond_type = POSSIBLE_BOND_TYPES[bond_type_idx]
        mol.AddBond(int(atom1), int(atom2), bond_type)

    # Set atom positions
    conf = Chem.Conformer(len(atoms_atype))
    # Coerce to a float array so both the annotated List[Tuple[...]] and ndarray inputs work
    for i, pos in enumerate(numpy.asarray(atom_positions, dtype=float)):
        conf.SetAtomPosition(i, pos)
    mol.AddConformer(conf)

    return mol

def save_output_structure(aatype, residue_index, chain_index, plddt, final_atom_protein_positions, final_atom_mask,
                          output_path):
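    """Write the predicted protein structure to a PDB file at output_path, using per-residue pLDDT as B-factors."""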
    plddt_b_factors = numpy.repeat(
        plddt[..., None], residue_constants.atom_type_num, axis=-1
    )

    unrelaxed_protein = protein.from_prediction(
        aatype=aatype,
        residue_index=residue_index,
        chain_index=chain_index,
        atom_mask=final_atom_mask,
        atom_positions=final_atom_protein_positions,
        b_factors=plddt_b_factors,
        remove_leading_feature_dimension=False,
    )

    with open(output_path, 'w') as fp:
        fp.write(protein.to_pdb(unrelaxed_protein))

    print("Output written to", output_path)