# NOTE(review): the three lines below appear to be page-scrape artifacts
# (a "Spaces: Sleeping" status banner), not Python source; commented out
# so the module parses. Preserved verbatim rather than deleted.
# Spaces:
# Sleeping
# Sleeping
import multiprocessing | |
import os | |
from argparse import ArgumentParser | |
from pathlib import Path | |
from typing import Optional | |
import rootutils | |
import torch | |
from loguru import logger | |
from pinder.core import PinderSystem, get_index | |
from pinder.core.loader.geodata import PairedPDB, structure2tensor | |
from pinder.core.loader.structure import Structure | |
from tqdm.auto import tqdm | |
# setup root dir and pythonpath
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

# torch-cluster is an optional compiled dependency whose wheel must match the
# local torch/CUDA build; degrade gracefully (graphs without knn edges) when
# it is absent instead of failing at import time.
try:
    from torch_cluster import knn_graph

    torch_cluster_installed = True
except ImportError:
    # Fixed message: the original concatenated fragments without separating
    # spaces ("installed!Please install...installation.See https...").
    logger.warning(
        "torch-cluster is not installed! "
        "Please install the appropriate library for your pytorch installation. "
        "See https://github.com/rusty1s/pytorch_cluster/issues/185 for background."
    )
    torch_cluster_installed = False
def create_lr_files(system_id: str, apo_complex_path: str, save_path: str):
    """Split a native complex PDB into receptor (chain R) and ligand (chain L) files.

    The native file is looked up next to ``apo_complex_path`` as
    ``<stem><system_id>.pdb``. ATOM/HETATM records are routed by chain ID
    (column 22); records for any other chain are dropped, while non-coordinate
    records (HEADER, REMARK, ...) are copied to both outputs.

    Args:
        system_id: pinder system identifier used in the file names.
        apo_complex_path: path (str or Path) whose directory/stem locate the
            native complex PDB.
        save_path: directory receiving ``apo_r_<id>.pdb`` and ``apo_l_<id>.pdb``.

    Returns:
        Tuple ``(receptor_path, ligand_path)`` of the written files.
    """
    # Bug fix: the parameter is annotated ``str`` but ``.with_name``/``.stem``
    # below are Path-only, so a plain string argument raised AttributeError.
    apo_complex_path = Path(apo_complex_path)
    apo_r_path = os.path.join(save_path, f"apo_r_{system_id}.pdb")
    apo_l_path = os.path.join(save_path, f"apo_l_{system_id}.pdb")
    native_path = apo_complex_path.with_name(apo_complex_path.stem + f"{system_id}.pdb")
    with open(native_path) as infile, open(apo_r_path, "w") as output_r, open(
        apo_l_path, "w"
    ) as output_l:
        for line in infile:
            if line.startswith(("ATOM", "HETATM")):
                # PDB format: column 22 (0-based index 21) is the chain ID.
                # Guard against truncated records instead of raising IndexError.
                chain_id = line[21] if len(line) > 21 else ""
                if chain_id == "R":
                    output_r.write(line)
                elif chain_id == "L":
                    output_l.write(line)
            else:
                # Write other lines (e.g., HEADER, REMARK) to both files.
                output_r.write(line)
                output_l.write(line)
    return apo_r_path, apo_l_path
class CropPairedPDB(PairedPDB):
    """PairedPDB variant that builds graphs from masked pinder systems and
    caches them to disk as ``.pt`` files under ``<root>/processed/...``."""

    # Bug fix: @classmethod was missing. The module calls
    # ``CropPairedPDB.from_crop_system(system_id, ...)``, which previously
    # bound ``system_id`` to ``cls`` and raised TypeError.
    @classmethod
    def from_crop_system(
        cls,
        system_id: str,
        root: str = "./data/",
        k: int = 10,
        add_edges: bool = True,
        predicted_structures: bool = True,
        split: str = "train",
    ) -> Optional[PairedPDB]:
        """Build the paired graph for ``system_id`` and save it under ``root``.

        Raw PDBs are written to ``raw/<type>/<split>`` and the graph to
        ``processed/<type>/<split>/<system_id>.pt``.

        Args:
            system_id: pinder system identifier.
            root: dataset root directory.
            k: neighbour count for the knn edge graphs.
            add_edges: add knn edges (requires torch-cluster).
            predicted_structures: use predicted monomers instead of apo ones.
            split: dataset split name used in output paths.

        Returns:
            The saved graph, or ``None`` when PDB generation failed.
        """
        system = PinderSystem(system_id)
        # Create output directories up front so ``to_pdb`` cannot fail on a
        # missing folder.
        for subdir in ["apo", "holo", "predicted"]:
            os.makedirs(Path(root) / "raw" / subdir / split, exist_ok=True)
        try:
            holo_complex, apo_complex, pred_complex = (
                system.create_masked_bound_unbound_complexes(renumber_residues=True)
            )
            for complex_type, complex_obj in zip(
                ["apo", "holo", "predicted"], [apo_complex, holo_complex, pred_complex]
            ):
                complex_obj.to_pdb(
                    Path(root) / "raw" / complex_type / split / f"{system_id}_complex.pdb"
                )
        except Exception as e:
            # Best-effort pipeline: log and skip systems whose structures
            # cannot be produced/written rather than aborting the whole run.
            logger.error(f"Error in writing PDB files: {e}, {system_id}")
            return None
        if predicted_structures:
            apo_complex = pred_complex
            save_path = os.path.join(root, "processed", "predicted", split)
        else:
            save_path = os.path.join(root, "processed", "apo", split)
        # create the directory if it does not exist
        os.makedirs(save_path, exist_ok=True)
        graph = cls.from_structure_pair(
            holo_complex=holo_complex,
            apo_complex=apo_complex,
            add_edges=add_edges,
            k=k,
        )
        torch.save(graph, os.path.join(save_path, f"{system_id}.pt"))
        # Return the graph (originally None) so callers can reuse it without
        # reloading from disk; backward compatible with callers that ignore it.
        return graph

    # Bug fix: @classmethod was missing here too — ``from_crop_system`` invokes
    # this via ``cls.from_structure_pair(...)`` with keyword-only arguments.
    @classmethod
    def from_structure_pair(
        cls,
        holo_complex: Structure,
        apo_complex: Structure,
        add_edges: bool = True,
        k: int = 10,
    ) -> PairedPDB:
        """Assemble a heterograph pairing unbound inputs with bound targets.

        Node features/positions (``x``/``pos``) come from ``apo_complex``;
        target coordinates (``y``) come from ``holo_complex``.
        """

        def get_structure_props(structure: Structure, start: int, end: Optional[int]):
            # Per-residue features come from the C-alpha subset.
            calpha = structure.filter("atom_name", mask=["CA"])
            # NOTE(review): the calpha arrays are sliced with the same
            # *atom-level* start/end indices as the full structure, yet calpha
            # has one entry per residue — confirm this is intended.
            return structure2tensor(
                atom_coordinates=structure.coords[start:end],
                atom_types=structure.atom_array.atom_name[start:end],
                element_types=structure.atom_array.element[start:end],
                residue_coordinates=calpha.coords[start:end],
                residue_types=calpha.atom_array.res_name[start:end],
                residue_ids=calpha.atom_array.res_id[start:end],
            )

        graph = cls()
        # Atom count of chain "R" marks the receptor/ligand boundary; assumes
        # chains are ordered R then L in both structures.
        r_h = (holo_complex.dataframe["chain_id"] == "R").sum()
        r_a = (apo_complex.dataframe["chain_id"] == "R").sum()
        holo_r_props = get_structure_props(holo_complex, 0, r_h)
        holo_l_props = get_structure_props(holo_complex, r_h, None)
        apo_r_props = get_structure_props(apo_complex, 0, r_a)
        apo_l_props = get_structure_props(apo_complex, r_a, None)
        graph["ligand"].x = apo_l_props["atom_types"]
        graph["ligand"].pos = apo_l_props["atom_coordinates"]
        graph["receptor"].x = apo_r_props["atom_types"]
        graph["receptor"].pos = apo_r_props["atom_coordinates"]
        graph["ligand"].y = holo_l_props["atom_coordinates"]
        graph["receptor"].y = holo_r_props["atom_coordinates"]
        # knn edges are optional: skipped when torch-cluster is unavailable.
        if add_edges and torch_cluster_installed:
            graph["ligand", "ligand"].edge_index = knn_graph(graph["ligand"].pos, k=k)
            graph["receptor", "receptor"].edge_index = knn_graph(
                graph["receptor"].pos, k=k
            )
        return graph
def _process_system_id(system_id: str, k: int, split: str, predicted_structures: bool):
    """Pool worker: build and save the graph for one pinder system.

    Defined at module level so it is picklable under the ``spawn`` start
    method (the default on macOS/Windows); the original nested function only
    worked with ``fork``.
    """
    return CropPairedPDB.from_crop_system(
        system_id,
        predicted_structures=predicted_structures,
        k=k,
        split=split,
    )


if __name__ == "__main__":
    from functools import partial  # local import: only the CLI driver needs it

    parser = ArgumentParser()
    parser.add_argument("--n_jobs", type=int, default=20)
    parser.add_argument("--k", type=int, default=10)
    parser.add_argument("--predicted_structures", action="store_true")
    parser.add_argument("--split", type=str, default="train")
    args = parser.parse_args()
    predicted_structures = args.predicted_structures

    # get indices for train, validation, and test splits
    indices = get_index()
    # predicted run: systems with predicted monomers for both partners but no
    # apo structures; apo run: systems with apo receptor AND apo ligand.
    if predicted_structures:
        query = '(split == "{split}") and ((apo_R == False and apo_L == False) and (predicted_R==True and predicted_L==True))'
    else:
        query = '(split == "{split}") and (apo_R == True and apo_L == True)'
    system_idx = indices.query(query.format(split=args.split)).reset_index(drop=True)
    system_ids = system_idx.id.tolist()

    # Bind the fixed arguments once; partial of a module-level function
    # serializes cleanly to worker processes.
    worker = partial(
        _process_system_id,
        k=args.k,
        split=args.split,
        predicted_structures=predicted_structures,
    )
    with multiprocessing.Pool(args.n_jobs) as pool:
        results = list(tqdm(pool.imap(worker, system_ids), total=len(system_ids)))