from __future__ import annotations

# Standard library
import copy
import inspect
import json
import os
import pickle
import time
import warnings
from enum import Enum
from pathlib import Path
from typing import Optional, Tuple, Union

# Numerics
import numpy as np
import pandas as pd
from numpy.typing import NDArray

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import BatchNorm1d, Linear, Module, ReLU, Sequential

# Lightning / metrics / logging
import lightning.pytorch as pl
import torchmetrics
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning import LightningModule
from pytorch_lightning.loggers.wandb import WandbLogger

# PyTorch Geometric
from torch_geometric.data import Data, Dataset, HeteroData
from torch_geometric.data.feature_store import FeatureStore
from torch_geometric.data.graph_store import GraphStore
from torch_geometric.data.lightning.datamodule import LightningDataset
from torch_geometric.loader import (
    LinkLoader,
    LinkNeighborLoader,
    NeighborLoader,
    NodeLoader,
)
from torch_geometric.loader.dataloader import DataLoader
from torch_geometric.loader.utils import get_edge_label_index, get_input_nodes
from torch_geometric.nn import MessagePassing, global_max_pool, global_mean_pool
from torch_geometric.sampler import BaseSampler, NeighborSampler
from torch_geometric.typing import InputEdges, InputNodes
from torch_scatter import scatter

# Gradio UI
import gradio as gr
from gradio_molecule3d import Molecule3D

# Biopython (PDB I/O)
from Bio import PDB
from Bio.PDB.PDBIO import PDBIO

# PINDER
import pinder.core as pinder
from pinder.core import (
    PinderSystem,
    download_dataset,
    get_index,
    get_metadata,
    get_pinder_location,
)
from pinder.core.index.system import _align_monomers_with_mask
from pinder.core.loader.dataset import PPIDataset, get_geo_loader
from pinder.core.loader.structure import Structure
from pinder.core.utils import constants as pc
from pinder.core.utils.log import setup_logger

get_pinder_location()  # resolve the local PINDER data directory
pinder.__version__

# pytorch-lightning's DataModule is optional at import time; fall back gracefully.
try:
    from lightning.pytorch import LightningDataModule as PLLightningDataModule
    no_pytorch_lightning = False
except (ImportError, ModuleNotFoundError):
    PLLightningDataModule = object
    no_pytorch_lightning = True
def get_system(system_id: str) -> PinderSystem:
    """Load a single PINDER system by its identifier."""
    return PinderSystem(system_id)
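# Hedged usage sketch: the system id below is hypothetical; any value from the
# `id` column of get_index() would work the same way.
# example_system = get_system("some_pinder_system_id")  # hypothetical id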
# The dataset was built from the PINDER dataset alone, using the steps below:
log = setup_logger(__name__)

# torch-cluster is optional; edge construction is skipped when it is missing.
try:
    from torch_cluster import knn_graph
    torch_cluster_installed = True
except ImportError:
    log.warning(
        "torch-cluster is not installed! "
        "Please install the appropriate library for your pytorch installation. "
        "See https://github.com/rusty1s/pytorch_cluster/issues/185 for background."
    )
    torch_cluster_installed = False
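# Illustrative toy call (assumes torch-cluster is installed): knn_graph returns
# a [2, num_edges] LongTensor linking each point to its k nearest neighbours.
# if torch_cluster_installed:
#     _toy_edges = knn_graph(torch.rand(5, 3), k=2)  # 5 random 3-D points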
def structure2tensor(
    atom_coordinates: NDArray[np.double] | None = None,
    atom_types: NDArray[np.str_] | None = None,
    element_types: NDArray[np.str_] | None = None,
    residue_coordinates: NDArray[np.double] | None = None,
    residue_ids: NDArray[np.int_] | None = None,
    residue_types: NDArray[np.str_] | None = None,
    chain_ids: NDArray[np.str_] | None = None,
    dtype: torch.dtype = torch.float32,
) -> dict[str, torch.Tensor]:
    """Encode structure-level arrays as tensors, mapping names to integer codes."""
    property_dict = {}
    if atom_types is not None:
        # Unknown atom names map to one index past the known vocabulary.
        unknown_name_idx = max(pc.ALL_ATOM_POSNS.values()) + 1
        types_array_at = np.zeros((len(atom_types), 1))
        for i, name in enumerate(atom_types):
            types_array_at[i] = pc.ALL_ATOM_POSNS.get(name, unknown_name_idx)
        property_dict["atom_types"] = torch.tensor(types_array_at).type(dtype)
    if element_types is not None:
        types_array_ele = np.zeros((len(element_types), 1))
        for i, name in enumerate(element_types):
            types_array_ele[i] = pc.ELE2NUM.get(name, pc.ELE2NUM["other"])
        property_dict["element_types"] = torch.tensor(types_array_ele).type(dtype)
    if residue_types is not None:
        unknown_name_idx = max(pc.AA_TO_INDEX.values()) + 1
        types_array_res = np.zeros((len(residue_types), 1))
        for i, name in enumerate(residue_types):
            types_array_res[i] = pc.AA_TO_INDEX.get(name, unknown_name_idx)
        property_dict["residue_types"] = torch.tensor(types_array_res).type(dtype)
    if atom_coordinates is not None:
        property_dict["atom_coordinates"] = torch.tensor(atom_coordinates, dtype=dtype)
    if residue_coordinates is not None:
        property_dict["residue_coordinates"] = torch.tensor(
            residue_coordinates, dtype=dtype
        )
    if residue_ids is not None:
        property_dict["residue_ids"] = torch.tensor(residue_ids, dtype=dtype)
    if chain_ids is not None:
        # Binary chain encoding: receptor ("R") -> 0, ligand ("L") -> 1.
        property_dict["chain_ids"] = torch.zeros(len(chain_ids), dtype=dtype)
        property_dict["chain_ids"][chain_ids == "L"] = 1
    return property_dict
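# Minimal sketch of the return value, using a toy two-atom input; only
# atom_coordinates is passed so no pinder vocabularies are consulted.
# _toy = structure2tensor(atom_coordinates=np.array([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]]))
# _toy["atom_coordinates"].shape  # -> torch.Size([2, 3])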
class NodeRepresentation(Enum):
    Surface = "surface"
    Atom = "atom"
    Residue = "residue"
class PairedPDB(HeteroData):  # type: ignore
    @classmethod
    def from_tuple_system(
        cls,
        tupal: tuple[Structure, Structure, Structure],
        add_edges: bool = True,
        k: int = 10,
    ) -> PairedPDB:
        # Expected ordering (from the construction below): (holo, apo, ...).
        return cls.from_structure_pair(
            holo=tupal[0],
            apo=tupal[1],
            add_edges=add_edges,
            k=k,
        )
    @classmethod
    def from_structure_pair(
        cls,
        holo: Structure,
        apo: Structure,
        add_edges: bool = True,
        k: int = 10,
    ) -> PairedPDB:
        graph = cls()
        holo_calpha = holo.filter("atom_name", mask=["CA"])
        apo_calpha = apo.filter("atom_name", mask=["CA"])
        # Number of receptor-chain ("R") atoms; atoms are assumed to be ordered
        # receptor first, then ligand.
        r_h = (holo.dataframe["chain_id"] == "R").sum()
        r_a = (apo.dataframe["chain_id"] == "R").sum()
        holo_r_props = structure2tensor(
            atom_coordinates=holo.coords[:r_h],
            atom_types=holo.atom_array.atom_name[:r_h],
            element_types=holo.atom_array.element[:r_h],
            residue_coordinates=holo_calpha.coords[:r_h],
            residue_types=holo_calpha.atom_array.res_name[:r_h],
            residue_ids=holo_calpha.atom_array.res_id[:r_h],
        )
        holo_l_props = structure2tensor(
            atom_coordinates=holo.coords[r_h:],
            atom_types=holo.atom_array.atom_name[r_h:],
            element_types=holo.atom_array.element[r_h:],
            residue_coordinates=holo_calpha.coords[r_h:],
            residue_types=holo_calpha.atom_array.res_name[r_h:],
            residue_ids=holo_calpha.atom_array.res_id[r_h:],
        )
        apo_r_props = structure2tensor(
            atom_coordinates=apo.coords[:r_a],
            atom_types=apo.atom_array.atom_name[:r_a],
            element_types=apo.atom_array.element[:r_a],
            residue_coordinates=apo_calpha.coords[:r_a],
            residue_types=apo_calpha.atom_array.res_name[:r_a],
            residue_ids=apo_calpha.atom_array.res_id[:r_a],
        )
        apo_l_props = structure2tensor(
            atom_coordinates=apo.coords[r_a:],
            atom_types=apo.atom_array.atom_name[r_a:],
            element_types=apo.atom_array.element[r_a:],
            residue_coordinates=apo_calpha.coords[r_a:],
            residue_types=apo_calpha.atom_array.res_name[r_a:],
            residue_ids=apo_calpha.atom_array.res_id[r_a:],
        )
        # Apo coordinates are the model input (pos); holo coordinates are the
        # regression target (y).
        graph["ligand"].x = apo_l_props["atom_types"]
        graph["ligand"].pos = apo_l_props["atom_coordinates"]
        graph["receptor"].x = apo_r_props["atom_types"]
        graph["receptor"].pos = apo_r_props["atom_coordinates"]
        graph["ligand"].y = holo_l_props["atom_coordinates"]
        graph["receptor"].y = holo_r_props["atom_coordinates"]
        if add_edges and torch_cluster_installed:
            # One k-NN graph per monomer, built over the apo positions.
            graph["ligand"].edge_index = knn_graph(graph["ligand"].pos, k=k)
            graph["receptor"].edge_index = knn_graph(graph["receptor"].pos, k=k)
        return graph
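# The resulting HeteroData holds two node types, "ligand" and "receptor", each
# with x (atom-type codes), pos (apo coordinates), y (holo coordinates), and,
# when torch-cluster is available, a k-NN edge_index.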
index = get_index()
train = index[index.split == "train"].copy()
val = index[index.split == "val"].copy()
test = index[index.split == "test"].copy()
# Keep only systems that have both an apo receptor and an apo ligand.
train_filtered = train[train["apo_R"] & train["apo_L"]].copy()
val_filtered = val[val["apo_R"] & val["apo_L"]].copy()
test_filtered = test[test["apo_R"] & test["apo_L"]].copy()
train_apo = [
    get_system(train_filtered.id.iloc[i]).create_masked_bound_unbound_complexes(
        monomer_types=["apo"], renumber_residues=True
    )
    for i in range(0, 10000)
]
train_new_apo11 = [
    get_system(train_filtered.id.iloc[i]).create_masked_bound_unbound_complexes(
        monomer_types=["apo"], renumber_residues=True
    )
    for i in range(10000, 10908)
]
train_new_apo12 = [
    get_system(train_filtered.id.iloc[i]).create_masked_bound_unbound_complexes(
        monomer_types=["apo"], renumber_residues=True
    )
    for i in range(10908, 11816)
]
val_new_apo1 = [
    get_system(val_filtered.id.iloc[i]).create_masked_bound_unbound_complexes(
        monomer_types=["apo"], renumber_residues=True
    )
    for i in range(0, 342)
]
test_new_apo1 = [
    get_system(test_filtered.id.iloc[i]).create_masked_bound_unbound_complexes(
        monomer_types=["apo"], renumber_residues=True
    )
    for i in range(0, 342)
]
val_apo = val_new_apo1 + train_new_apo11
test_apo = test_new_apo1 + train_new_apo12
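# Note: the validation and test lists are padded with additional systems drawn
# from the tail of the filtered train split, presumably to enlarge them.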
# Optional caching: dump the processed lists once, then reload on later runs.
# with open("train_apo.pkl", "wb") as file:
#     pickle.dump(train_apo, file)
# with open("val_apo.pkl", "wb") as file:
#     pickle.dump(val_apo, file)
# with open("test_apo.pkl", "wb") as file:
#     pickle.dump(test_apo, file)
# with open("train_apo.pkl", "rb") as file:
#     train_apo = pickle.load(file)
# with open("val_apo.pkl", "rb") as file:
#     val_apo = pickle.load(file)
# with open("test_apo.pkl", "rb") as file:
#     test_apo = pickle.load(file)
# %%
train_geo = [PairedPDB.from_tuple_system(t) for t in train_apo]
val_geo = [PairedPDB.from_tuple_system(t) for t in val_apo]
test_geo = [PairedPDB.from_tuple_system(t) for t in test_apo]
# %%
from torch_sparse import SparseTensor


def build_hetero_list(geo_list: list[PairedPDB]) -> list[HeteroData]:
    """Copy each PairedPDB into a plain HeteroData, converting the per-monomer
    k-NN edge indices into torch_sparse SparseTensor adjacencies."""
    converted = []
    for geo in geo_list:
        data = HeteroData()
        # Ligand node features, apo positions, and holo targets.
        data["ligand"].x = geo["ligand"].x
        data["ligand"].y = geo["ligand"].y
        data["ligand"].pos = geo["ligand"].pos
        n_lig = geo["ligand"].num_nodes
        data["ligand", "ligand"].edge_index = SparseTensor.from_edge_index(
            geo["ligand"]["edge_index"], sparse_sizes=(n_lig, n_lig)
        )
        # Receptor node features, apo positions, and holo targets.
        data["receptor"].x = geo["receptor"].x
        data["receptor"].y = geo["receptor"].y
        data["receptor"].pos = geo["receptor"].pos
        n_rec = geo["receptor"].num_nodes
        data["receptor", "receptor"].edge_index = SparseTensor.from_edge_index(
            geo["receptor"]["edge_index"], sparse_sizes=(n_rec, n_rec)
        )
        converted.append(data)
    return converted


Train1 = build_hetero_list(train_geo)
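# Quick illustrative sanity check (optional):
# print(Train1[0])  # shows ligand/receptor stores with x, y, pos and sparse adjacency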
# %%
Val1 = build_hetero_list(val_geo)
# %%
Test1 = build_hetero_list(test_geo)
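# Hedged sketch: these lists can be mini-batched with PyG's DataLoader
# (batch_size here is an arbitrary illustrative value):
# train_loader = DataLoader(Train1, batch_size=2, shuffle=True)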
# Optional caching of the converted graphs (file names kept consistent with
# the Train1/Val1/Test1 variables above):
# with open("Train1.pkl", "wb") as file:
#     pickle.dump(Train1, file)
# with open("Val1.pkl", "wb") as file:
#     pickle.dump(Val1, file)
# with open("Test1.pkl", "wb") as file:
#     pickle.dump(Test1, file)
# with open("Train1.pkl", "rb") as file:
#     Train1 = pickle.load(file)
# with open("Val1.pkl", "rb") as file:
#     Val1 = pickle.load(file)
# with open("Test1.pkl", "rb") as file:
#     Test1 = pickle.load(file)