# NOTE(review): the three lines previously here ("Spaces:" / "Sleeping" /
# "Sleeping") were scraped hosting-page residue, not Python — commented out
# so the module parses.
import json | |
import logging | |
import os | |
import re | |
import time | |
from typing import List, Tuple | |
import numpy | |
import torch | |
from rdkit import Chem | |
from dockformerpp.model.model import AlphaFold | |
from dockformerpp.utils import residue_constants, protein | |
from dockformerpp.utils.consts import POSSIBLE_ATOM_TYPES, POSSIBLE_BOND_TYPES, POSSIBLE_CHARGES, POSSIBLE_CHIRALITIES | |
# Module-level logging setup: install a default root handler and emit
# INFO-and-above messages from this module's logger.
logging.basicConfig()
# Use __name__ rather than __file__ so the logger is addressable by the
# conventional dotted module path (e.g. for per-module log configuration).
logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO)
def count_models_to_evaluate(model_checkpoint_path):
    """Return how many comma-separated checkpoint paths were given (0 if empty/None)."""
    if not model_checkpoint_path:
        return 0
    return len(model_checkpoint_path.split(","))
def get_model_basename(model_path):
    """Return *model_path*'s filename without directory or extension.

    Trailing path separators are ignored (the path is normalized first).
    """
    filename = os.path.basename(os.path.normpath(model_path))
    stem, _extension = os.path.splitext(filename)
    return stem
def make_output_directory(output_dir, model_name, multiple_model_mode):
    """Create (if missing) and return the prediction directory for a model.

    In multiple-model mode each model gets its own subdirectory under
    "predictions"; otherwise all outputs share a single "predictions" dir.
    """
    components = [output_dir, "predictions"]
    if multiple_model_mode:
        components.append(model_name)
    prediction_dir = os.path.join(*components)
    os.makedirs(prediction_dir, exist_ok=True)
    return prediction_dir
def get_latest_checkpoint(checkpoint_dir):
    """Return the most recently created '.ckpt' file in *checkpoint_dir*.

    Returns None when the directory does not exist or holds no checkpoints.
    "Most recent" is decided by filesystem creation time (getctime).
    """
    if not os.path.exists(checkpoint_dir):
        return None
    ckpt_names = [name for name in os.listdir(checkpoint_dir) if name.endswith('.ckpt')]
    if not ckpt_names:
        return None
    full_paths = (os.path.join(checkpoint_dir, name) for name in ckpt_names)
    return max(full_paths, key=os.path.getctime)
def load_models_from_command_line(config, model_device, model_checkpoint_path, output_dir):
    """Yield (model, output_directory) pairs, one per checkpoint path.

    Args:
        config: model configuration forwarded to the AlphaFold constructor.
        model_device: device the loaded model is moved to.
        model_checkpoint_path: comma-separated checkpoint file paths; must be
            non-empty.
        output_dir: root directory under which prediction dirs are created.

    Yields:
        Tuples of (model in eval mode on model_device, prediction directory).

    Raises:
        ValueError: if model_checkpoint_path is empty/None (raised when the
            generator is first consumed).
        AssertionError: if a checkpoint file does not exist.
    """
    # Validate up front instead of after the loop, so an empty argument fails
    # fast rather than silently yielding nothing first.
    if not model_checkpoint_path:
        raise ValueError("model_checkpoint_path must be specified.")

    multiple_model_mode = count_models_to_evaluate(model_checkpoint_path) > 1
    if multiple_model_mode:
        logger.info("evaluating multiple models")

    for path in model_checkpoint_path.split(","):
        model = AlphaFold(config)
        model = model.eval()
        checkpoint_basename = get_model_basename(path)
        assert os.path.isfile(path), f"Model checkpoint not found at {path}"
        # map_location avoids a hard failure when a GPU-saved checkpoint is
        # loaded on a CPU-only (or differently-numbered GPU) host; the model
        # is moved to model_device explicitly below.
        d = torch.load(path, map_location="cpu")
        if "ema" in d:
            # The public weights have had this done to them already
            d = d["ema"]["params"]
        model.load_state_dict(d)
        model = model.to(model_device)
        logger.info(
            f"Loaded Model parameters at {path}..."
        )
        output_directory = make_output_directory(output_dir, checkpoint_basename, multiple_model_mode)
        yield model, output_directory
def parse_fasta(data):
    """Parse FASTA-formatted text into parallel lists (tags, sequences).

    Each tag is truncated to its first non-word-character-delimited token;
    newlines within a sequence record are removed.
    """
    data = re.sub('>$', '', data, flags=re.M)
    entries = []
    for record in data.split('>'):
        for piece in record.strip().split('\n', 1):
            entries.append(piece.replace('\n', ''))
    # The first entry is the (empty) text before the first '>' marker.
    entries = entries[1:]
    tags = entries[0::2]
    seqs = entries[1::2]
    tags = [re.split('\W| \|', raw_tag)[0] for raw_tag in tags]
    return tags, seqs
def update_timings(timing_dict, output_file=None):
    """Merge *timing_dict* into a JSON timings file and return the file path.

    Args:
        timing_dict: mapping of tag -> step-timing dict to merge into the file.
        output_file: JSON file to update; defaults to "timings.json" in the
            current working directory, resolved at call time.

    Returns:
        The path of the file that was written.
    """
    if output_file is None:
        # Resolve the default lazily: the original default bound os.getcwd()
        # at import time, so a later chdir() was silently ignored.
        output_file = os.path.join(os.getcwd(), "timings.json")
    if os.path.exists(output_file):
        with open(output_file, "r") as f:
            try:
                timings = json.load(f)
            except json.JSONDecodeError:
                logger.info(f"Overwriting non-standard JSON in {output_file}.")
                timings = {}
    else:
        timings = {}
    timings.update(timing_dict)
    with open(output_file, "w") as f:
        json.dump(timings, f)
    return output_file
def run_model(model, batch, tag, output_dir):
    """Run a gradient-free forward pass of *model* on *batch*.

    Logs the wall-clock inference time and records it under *tag* in
    <output_dir>/timings.json via update_timings(). Returns the model output.
    """
    with torch.no_grad():
        logger.info(f"Running inference for {tag}...")
        start = time.perf_counter()
        prediction = model(batch)
        elapsed = time.perf_counter() - start
        logger.info(f"Inference time: {elapsed}")
        update_timings({tag: {"inference": elapsed}}, os.path.join(output_dir, "timings.json"))
        return prediction
def get_molecule_from_output(atoms_atype: List[int], atom_chiralities: List[int], atom_charges: List[int],
                             bonds: List[Tuple[int, int, int]], atom_positions: numpy.ndarray):
    """Build an editable RDKit molecule from per-atom predictions.

    Args:
        atoms_atype: per-atom indices into POSSIBLE_ATOM_TYPES.
        atom_chiralities: per-atom indices into POSSIBLE_CHIRALITIES.
        atom_charges: per-atom indices into POSSIBLE_CHARGES.
        bonds: (atom1, atom2, bond_type_idx) triples; bond_type_idx indexes
            POSSIBLE_BOND_TYPES.
        atom_positions: per-atom 3D coordinates; must support .astype(float),
            i.e. a numpy array — presumably shape (n_atoms, 3). NOTE(review):
            the original annotation said List[Tuple[float, float, float]],
            which contradicts the .astype call below.

    Returns:
        A Chem.RWMol with one conformer holding the given positions.
    """
    mol = Chem.RWMol()
    # All per-atom inputs must describe the same number of atoms.
    assert len(atoms_atype) == len(atom_chiralities) == len(atom_charges) == len(atom_positions)
    # Atoms must be added first so the bond indices below refer to them.
    for atype_idx, chirality_idx, charge_idx in zip(atoms_atype, atom_chiralities, atom_charges):
        new_atom = Chem.Atom(POSSIBLE_ATOM_TYPES[atype_idx])
        new_atom.SetChiralTag(POSSIBLE_CHIRALITIES[chirality_idx])
        new_atom.SetFormalCharge(POSSIBLE_CHARGES[charge_idx])
        mol.AddAtom(new_atom)
    # Add bonds
    for bond in bonds:
        atom1, atom2, bond_type_idx = bond
        bond_type = POSSIBLE_BOND_TYPES[bond_type_idx]
        # int() casts guard against numpy integer types from model output.
        mol.AddBond(int(atom1), int(atom2), bond_type)
    # Set atom positions
    conf = Chem.Conformer(len(atoms_atype))
    for i, pos in enumerate(atom_positions.astype(float)):
        conf.SetAtomPosition(i, pos)
    mol.AddConformer(conf)
    return mol
def save_output_structure(aatype, residue_index, chain_index, plddt, final_atom_protein_positions, final_atom_mask,
                          output_path):
    """Write an unrelaxed predicted protein structure to a PDB file.

    Args:
        aatype: per-residue amino-acid type indices.
        residue_index: per-residue sequence indices.
        chain_index: per-residue chain indices.
        plddt: per-residue pLDDT confidence values.
        final_atom_protein_positions: predicted atom coordinates.
        final_atom_mask: mask of which atoms exist per residue.
        output_path: destination PDB file path.
    """
    # Broadcast the per-residue pLDDT across every atom type so it can be
    # stored in the per-atom B-factor column of the PDB.
    plddt_b_factors = numpy.repeat(
        plddt[..., None], residue_constants.atom_type_num, axis=-1
    )
    unrelaxed_protein = protein.from_prediction(
        aatype=aatype,
        residue_index=residue_index,
        chain_index=chain_index,
        atom_mask=final_atom_mask,
        atom_positions=final_atom_protein_positions,
        b_factors=plddt_b_factors,
        remove_leading_feature_dimension=False,
    )
    with open(output_path, 'w') as fp:
        fp.write(protein.to_pdb(unrelaxed_protein))
    # Use the module logger instead of print() for consistency with the rest
    # of this module's logging.
    logger.info("Output written to %s", output_path)