from __future__ import annotations

import time
import json

import gradio as gr
from gradio_molecule3d import Molecule3D
import torch
from pinder.core import get_pinder_location

# Kept ahead of the remaining pinder imports to preserve the original ordering
# (presumably it resolves the PINDER data location first -- confirm).
get_pinder_location()

from pytorch_lightning import LightningModule
import lightning.pytorch as pl
import torch.nn.functional as F
import torch.nn as nn
import torchmetrics
from torch_geometric.nn import MessagePassing
from torch_geometric.nn import global_mean_pool
from torch.nn import Sequential, Linear, BatchNorm1d, ReLU
from torch_scatter import scatter
from torch.nn import Module
import pinder.core as pinder
from torch_geometric.loader import DataLoader
from pinder.core.loader.dataset import get_geo_loader
from pinder.core import download_dataset
from pinder.core import get_index
from pinder.core import get_metadata
from pathlib import Path
import pandas as pd
from pinder.core import PinderSystem
from pinder.core.loader.dataset import PPIDataset
from pinder.core.loader.geodata import NodeRepresentation
import pickle
from torch_geometric.data import HeteroData
import os
from enum import Enum
import numpy as np
from numpy.typing import NDArray
from pinder.core.index.system import PinderSystem
from pinder.core.loader.structure import Structure
from pinder.core.utils import constants as pc
from pinder.core.utils.log import setup_logger
from pinder.core.index.system import _align_monomers_with_mask
from torch_geometric.nn import global_max_pool
import copy
import inspect
import warnings
from typing import Optional, Tuple, Union
from torch import Tensor
from torch_geometric.data import Data, Dataset, HeteroData
from torch_geometric.data.feature_store import FeatureStore
from torch_geometric.data.graph_store import GraphStore
from torch_geometric.loader import (
    LinkLoader,
    LinkNeighborLoader,
    NeighborLoader,
    NodeLoader,
)
from torch_geometric.loader.dataloader import DataLoader
from torch_geometric.loader.utils import get_edge_label_index, get_input_nodes
from torch_geometric.sampler import BaseSampler, NeighborSampler
from torch_geometric.typing import InputEdges, InputNodes

try:
    from lightning.pytorch import LightningDataModule as PLLightningDataModule

    no_pytorch_lightning = False
except ImportError:  # ModuleNotFoundError is a subclass of ImportError
    PLLightningDataModule = object
    no_pytorch_lightning = True

from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from torch_geometric.data.lightning.datamodule import LightningDataset
from pytorch_lightning.loggers.wandb import WandbLogger
from Bio import PDB
from Bio.PDB.PDBIO import PDBIO
import torch_sparse
from torch_geometric.edge_index import to_sparse_tensor


def get_system(system_id: str) -> PinderSystem:
    """Return the ``PinderSystem`` for the PINDER system id ``system_id``."""
    return PinderSystem(system_id)


# Dataset construction: only the PINDER dataset is used, following the steps below.

log = setup_logger(__name__)

try:
    from torch_cluster import knn_graph

    torch_cluster_installed = True
except ImportError:
    # Without torch-cluster, PairedPDB graphs are built without k-NN edges.
    log.warning(
        "torch-cluster is not installed! "
        "Please install the appropriate library for your pytorch installation. "
        "See https://github.com/rusty1s/pytorch_cluster/issues/185 for background."
    )
    torch_cluster_installed = False


def _lookup_column(names, table: dict, default) -> np.ndarray:
    """Map each entry of ``names`` through ``table`` into an ``(n, 1)`` float64 array.

    Names missing from ``table`` are mapped to ``default``.
    """
    return np.fromiter(
        (table.get(name, default) for name in names),
        dtype=np.float64,
        count=len(names),
    ).reshape(-1, 1)


def structure2tensor(
    atom_coordinates: NDArray[np.double] | None = None,
    atom_types: NDArray[np.str_] | None = None,
    element_types: NDArray[np.str_] | None = None,
    residue_coordinates: NDArray[np.double] | None = None,
    residue_ids: NDArray[np.int_] | None = None,
    residue_types: NDArray[np.str_] | None = None,
    chain_ids: NDArray[np.str_] | None = None,
    dtype: torch.dtype = torch.float32,
) -> dict[str, torch.Tensor]:
    """Convert structure arrays into a dict of tensors of type ``dtype``.

    Only arguments that are not ``None`` produce an entry in the result.

    - ``atom_types`` / ``element_types`` / ``residue_types`` are mapped to
      integer codes via ``pc.ALL_ATOM_POSNS`` / ``pc.ELE2NUM`` /
      ``pc.AA_TO_INDEX`` and returned as ``(n, 1)`` tensors. Unknown atom and
      residue names get ``max(code) + 1``; unknown elements get
      ``pc.ELE2NUM["other"]``.
    - Coordinates and residue ids are converted directly.
    - ``chain_ids`` becomes a 0/1 vector, with 1 where the chain id is ``"L"``.
    """
    property_dict: dict[str, torch.Tensor] = {}
    if atom_types is not None:
        unknown_atom = max(pc.ALL_ATOM_POSNS.values()) + 1
        property_dict["atom_types"] = torch.tensor(
            _lookup_column(atom_types, pc.ALL_ATOM_POSNS, unknown_atom)
        ).type(dtype)
    if element_types is not None:
        property_dict["element_types"] = torch.tensor(
            _lookup_column(element_types, pc.ELE2NUM, pc.ELE2NUM["other"])
        ).type(dtype)
    if residue_types is not None:
        unknown_residue = max(pc.AA_TO_INDEX.values()) + 1
        property_dict["residue_types"] = torch.tensor(
            _lookup_column(residue_types, pc.AA_TO_INDEX, unknown_residue)
        ).type(dtype)
    if atom_coordinates is not None:
        property_dict["atom_coordinates"] = torch.tensor(atom_coordinates, dtype=dtype)
    if residue_coordinates is not None:
        property_dict["residue_coordinates"] = torch.tensor(
            residue_coordinates, dtype=dtype
        )
    if residue_ids is not None:
        property_dict["residue_ids"] = torch.tensor(residue_ids, dtype=dtype)
    if chain_ids is not None:
        property_dict["chain_ids"] = torch.zeros(len(chain_ids), dtype=dtype)
        property_dict["chain_ids"][chain_ids == "L"] = 1
    return property_dict


class NodeRepresentation(Enum):
    Surface = "surface"
    Atom = "atom"
    Residue = "residue"


def _split_receptor_ligand(
    structure: Structure,
) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
    """Featurize ``structure`` separately for receptor ("R") and non-receptor atoms.

    Returns ``(receptor_props, ligand_props)``, each produced by
    ``structure2tensor``. The all-atom view and the C-alpha view get separate
    masks. The C-alpha view has one row per residue, so it cannot reuse the
    all-atom receptor count.
    """
    calpha = structure.filter("atom_name", mask=["CA"])
    atom_is_receptor = (structure.dataframe["chain_id"] == "R").to_numpy()
    ca_is_receptor = (calpha.dataframe["chain_id"] == "R").to_numpy()

    def featurize(atom_mask: np.ndarray, ca_mask: np.ndarray) -> dict[str, torch.Tensor]:
        # Select one side of the complex in both the all-atom and C-alpha views.
        return structure2tensor(
            atom_coordinates=structure.coords[atom_mask],
            atom_types=structure.atom_array.atom_name[atom_mask],
            element_types=structure.atom_array.element[atom_mask],
            residue_coordinates=calpha.coords[ca_mask],
            residue_types=calpha.atom_array.res_name[ca_mask],
            residue_ids=calpha.atom_array.res_id[ca_mask],
        )

    return (
        featurize(atom_is_receptor, ca_is_receptor),
        featurize(~atom_is_receptor, ~ca_is_receptor),
    )


class PairedPDB(HeteroData):  # type: ignore
    """Heterogeneous graph pairing an apo complex with its holo counterpart.

    For each of the ``"ligand"`` and ``"receptor"`` node types:

    - ``x``: atom-type codes from the apo structure.
    - ``pos``: apo atom coordinates.
    - ``y``: holo atom coordinates.
    - ``edge_index``: a k-NN graph over ``pos``, present only when torch-cluster
      is installed and ``add_edges`` is true.
    """

    @classmethod
    def from_tuple_system(
        cls,
        tupal: tuple = (Structure, Structure, Structure),
        add_edges: bool = True,
        k: int = 10,
    ) -> PairedPDB:
        """Build a graph from a tuple whose first two items are (holo, apo).

        ``tupal[0]`` is used as holo and ``tupal[1]`` as apo; any further items
        are ignored. Presumably ``tupal`` comes from
        ``PinderSystem.create_masked_bound_unbound_complexes`` -- confirm.
        """
        # NOTE(review): the default is a tuple of *classes*, not structures,
        # so calling without ``tupal`` cannot work; kept for interface stability.
        return cls.from_structure_pair(
            holo=tupal[0],
            apo=tupal[1],
            add_edges=add_edges,
            k=k,
        )

    @classmethod
    def from_structure_pair(
        cls,
        holo: Structure,
        apo: Structure,
        add_edges: bool = True,
        k: int = 10,
    ) -> PairedPDB:
        """Build a graph with apo inputs (``x``, ``pos``) and holo targets (``y``).

        ``k`` is the neighbour count for the optional k-NN edges.
        """
        graph = cls()
        holo_r_props, holo_l_props = _split_receptor_ligand(holo)
        apo_r_props, apo_l_props = _split_receptor_ligand(apo)

        graph["ligand"].x = apo_l_props["atom_types"]
        graph["ligand"].pos = apo_l_props["atom_coordinates"]
        graph["receptor"].x = apo_r_props["atom_types"]
        graph["receptor"].pos = apo_r_props["atom_coordinates"]
        graph["ligand"].y = holo_l_props["atom_coordinates"]
        graph["receptor"].y = holo_r_props["atom_coordinates"]

        if add_edges and torch_cluster_installed:
            # Edges are built over the apo positions, i.e. the model inputs.
            graph["ligand"].edge_index = knn_graph(graph["ligand"].pos, k=k)
            graph["receptor"].edge_index = knn_graph(graph["receptor"].pos, k=k)
        return graph


def _with_apo_pair(split_df: pd.DataFrame) -> pd.DataFrame:
    """Keep only index rows where both ``apo_R`` and ``apo_L`` are ``True``."""
    return split_df[split_df["apo_R"].eq(True) & split_df["apo_L"].eq(True)].copy()


def _load_apo_complexes(split_df: pd.DataFrame, start: int, stop: int) -> list[tuple]:
    """Build masked bound/unbound complexes for rows ``start:stop`` of ``split_df``.

    Every split uses the same options (apo monomers, renumbered residues), so
    train, val and test are constructed consistently.
    """
    system_ids = split_df.id.iloc[start:stop]
    if len(system_ids) < stop - start:
        log.warning(
            "Requested rows %d:%d but only %d systems are available.",
            start,
            stop,
            len(system_ids),
        )
    return [
        get_system(system_id).create_masked_bound_unbound_complexes(
            monomer_types=["apo"], renumber_residues=True
        )
        for system_id in system_ids
    ]


def _to_sparse_hetero(geo: PairedPDB) -> HeteroData:
    """Copy node features/targets/positions from ``geo`` into a new ``HeteroData``.

    Each node type's ``edge_index`` (which must exist, i.e. torch-cluster was
    available when ``geo`` was built) is stored as a sparse adjacency on the
    ``(node_type, node_type)`` edge type.
    """
    data = HeteroData()
    for node_type in ("ligand", "receptor"):
        store = geo[node_type]
        data[node_type].x = store.x
        data[node_type].y = store.y
        data[node_type].pos = store.pos
        # NOTE(review): confirm `torch_geometric.edge_index.to_sparse_tensor`
        # exists in the installed PyG; torch_sparse's
        # `SparseTensor.from_edge_index(edge_index, sparse_sizes=...)` is the
        # documented equivalent.
        data[node_type, node_type].edge_index = to_sparse_tensor(
            store["edge_index"], sparse_sizes=(store.num_nodes,) * 2
        )
    return data


index = get_index()
train = index[index.split == "train"].copy()
val = index[index.split == "val"].copy()
test = index[index.split == "test"].copy()

train_filtered = _with_apo_pair(train)
val_filtered = _with_apo_pair(val)
test_filtered = _with_apo_pair(test)

train_apo = _load_apo_complexes(train_filtered, 0, 10000)
# Two further, disjoint slices of the train split top up val and test.
train_new_apo11 = _load_apo_complexes(train_filtered, 10000, 10908)
train_new_apo12 = _load_apo_complexes(train_filtered, 10908, 11816)
val_new_apo1 = _load_apo_complexes(val_filtered, 0, 342)
test_new_apo1 = _load_apo_complexes(test_filtered, 0, 342)
val_apo = val_new_apo1 + train_new_apo11
test_apo = test_new_apo1 + train_new_apo12

# Building complexes is slow; the lists above can be cached, e.g.
#     with open("train_apo.pkl", "wb") as file:
#         pickle.dump(train_apo, file)
# and reloaded with ``pickle.load``. Only unpickle files you created yourself.

train_geo = [PairedPDB.from_tuple_system(system) for system in train_apo]
val_geo = [PairedPDB.from_tuple_system(system) for system in val_apo]
test_geo = [PairedPDB.from_tuple_system(system) for system in test_apo]

Train1 = [_to_sparse_hetero(graph) for graph in train_geo]
Val1 = [_to_sparse_hetero(graph) for graph in val_geo]
Test1 = [_to_sparse_hetero(graph) for graph in test_geo]

# The previous per-split loops left ``data`` bound to the last test sample at
# module level; keep that binding in case later code relies on it.
if Test1:
    data = Test1[-1]