Spaces:
Sleeping
Sleeping
from __future__ import annotations | |
import time | |
import json | |
import gradio as gr | |
from gradio_molecule3d import Molecule3D | |
import torch | |
from pinder.core import get_pinder_location | |
get_pinder_location() | |
from pytorch_lightning import LightningModule | |
import torch | |
import lightning.pytorch as pl | |
import torch.nn.functional as F | |
import torch.nn as nn | |
import torchmetrics | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch_geometric.nn import MessagePassing | |
from torch_geometric.nn import global_mean_pool | |
from torch.nn import Sequential, Linear, BatchNorm1d, ReLU | |
from torch_scatter import scatter | |
from torch.nn import Module | |
import pinder.core as pinder | |
pinder.__version__ | |
from torch_geometric.loader import DataLoader | |
from pinder.core.loader.dataset import get_geo_loader | |
from pinder.core import download_dataset | |
from pinder.core import get_index | |
from pinder.core import get_metadata | |
from pathlib import Path | |
import pandas as pd | |
from pinder.core import PinderSystem | |
import torch | |
from pinder.core.loader.dataset import PPIDataset | |
from pinder.core.loader.geodata import NodeRepresentation | |
import pickle | |
from pinder.core import get_index, PinderSystem | |
from torch_geometric.data import HeteroData | |
import os | |
from enum import Enum | |
import numpy as np | |
import torch | |
import lightning.pytorch as pl | |
from numpy.typing import NDArray | |
from torch_geometric.data import HeteroData | |
from pinder.core.index.system import PinderSystem | |
from pinder.core.loader.structure import Structure | |
from pinder.core.utils import constants as pc | |
from pinder.core.utils.log import setup_logger | |
from pinder.core.index.system import _align_monomers_with_mask | |
from pinder.core.loader.structure import Structure | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch_geometric.nn import MessagePassing | |
from torch_geometric.nn import global_mean_pool | |
from torch.nn import Sequential, Linear, BatchNorm1d, ReLU | |
from torch_scatter import scatter | |
from torch.nn import Module | |
import time | |
from torch_geometric.nn import global_max_pool | |
import copy | |
import inspect | |
import warnings | |
from typing import Optional, Tuple, Union | |
import torch | |
from torch import Tensor | |
from torch_geometric.data import Data, Dataset, HeteroData | |
from torch_geometric.data.feature_store import FeatureStore | |
from torch_geometric.data.graph_store import GraphStore | |
from torch_geometric.loader import ( | |
LinkLoader, | |
LinkNeighborLoader, | |
NeighborLoader, | |
NodeLoader, | |
) | |
from torch_geometric.loader.dataloader import DataLoader | |
from torch_geometric.loader.utils import get_edge_label_index, get_input_nodes | |
from torch_geometric.sampler import BaseSampler, NeighborSampler | |
from torch_geometric.typing import InputEdges, InputNodes | |
try: | |
from lightning.pytorch import LightningDataModule as PLLightningDataModule | |
no_pytorch_lightning = False | |
except (ImportError, ModuleNotFoundError): | |
PLLightningDataModule = object | |
no_pytorch_lightning = True | |
from lightning.pytorch.callbacks import ModelCheckpoint | |
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger | |
from lightning.pytorch.callbacks.early_stopping import EarlyStopping | |
from torch_geometric.data.lightning.datamodule import LightningDataset | |
from pytorch_lightning.loggers.wandb import WandbLogger | |
def get_system(system_id: str) -> PinderSystem:
    """Construct the ``PinderSystem`` for a given system identifier.

    Args:
        system_id: Identifier forwarded unchanged to the ``PinderSystem``
            constructor.

    Returns:
        The newly constructed ``PinderSystem`` instance.
    """
    system = PinderSystem(system_id)
    return system
from Bio import PDB | |
from Bio.PDB.PDBIO import PDBIO | |
from pinder.core.structure.atoms import atom_array_from_pdb_file | |
from pathlib import Path | |
from pinder.eval.dockq.biotite_dockq import BiotiteDockQ | |
def extract_coordinates_from_pdb(filename, atom_name="CA"):
    """Collect the coordinates of every atom in a PDB file.

    The structure is walked model -> chain -> residue -> atom and one
    ``[x, y, z]`` entry is emitted per atom, in file-traversal order.

    NOTE: ``atom_name`` is accepted for interface compatibility but is not
    used -- no filtering takes place, so all atoms are returned regardless
    of its value.

    Parameters:
        filename (str): Path to the PDB file.
        atom_name (str): Currently ignored (see note above).

    Returns:
        list of list: One ``[x, y, z]`` list per atom.
    """
    pdb_parser = PDB.PDBParser(QUIET=True)
    parsed = pdb_parser.get_structure("structure", filename)

    # Flatten the model/chain/residue hierarchy into a single atom stream.
    every_atom = (
        atom
        for model in parsed
        for chain in model
        for residue in chain
        for atom in residue
    )
    return [[atom.coord[0], atom.coord[1], atom.coord[2]] for atom in every_atom]
log = setup_logger(__name__)

# torch-cluster is an optional dependency: record whether ``knn_graph`` could
# be imported so that graph builders can decide whether to add k-NN edges.
try:
    from torch_cluster import knn_graph

    torch_cluster_installed = True
except ImportError:
    # Adjacent string literals are concatenated verbatim, so each sentence
    # carries its own trailing space to keep the message readable.
    log.warning(
        "torch-cluster is not installed! "
        "Please install the appropriate library for your pytorch installation. "
        "See https://github.com/rusty1s/pytorch_cluster/issues/185 for background."
    )
    torch_cluster_installed = False
def structure2tensor(
    atom_coordinates: NDArray[np.double] | None = None,
    atom_types: NDArray[np.str_] | None = None,
    element_types: NDArray[np.str_] | None = None,
    residue_coordinates: NDArray[np.double] | None = None,
    residue_ids: NDArray[np.int_] | None = None,
    residue_types: NDArray[np.str_] | None = None,
    chain_ids: NDArray[np.str_] | None = None,
    dtype: torch.dtype = torch.float32,
) -> dict[str, torch.Tensor]:
    """Convert per-atom / per-residue arrays into a dictionary of tensors.

    Only the arguments that are not ``None`` produce an entry; the key of
    each entry is the name of the argument it came from.

    Args:
        atom_coordinates: Atom coordinates, converted as-is.
        atom_types: Atom names, encoded through ``pc.ALL_ATOM_POSNS``; names
            missing from the table map to ``max(table values) + 1``.
        element_types: Element symbols, encoded through ``pc.ELE2NUM``;
            unknown symbols map to ``pc.ELE2NUM["other"]``.
        residue_coordinates: Residue coordinates, converted as-is.
        residue_ids: Residue identifiers, converted as-is.
        residue_types: Residue names, encoded through ``pc.AA_TO_INDEX``;
            names missing from the table map to ``max(table values) + 1``.
        chain_ids: Chain identifiers; encoded as 1 where the id equals
            ``"L"`` and 0 elsewhere.
        dtype: Torch dtype used for every returned tensor.

    Returns:
        Mapping from argument name to tensor. Encoded name arrays have shape
        ``(n, 1)``; ``chain_ids`` has shape ``(n,)``.
    """

    def _column(codes):
        # Stack the integer codes into an (n, 1) float64 array, then cast.
        stacked = np.asarray(codes, dtype=np.float64).reshape(len(codes), 1)
        return torch.tensor(stacked).type(dtype)

    tensors = {}

    if atom_types is not None:
        unknown_atom = max(pc.ALL_ATOM_POSNS.values()) + 1
        tensors["atom_types"] = _column(
            [pc.ALL_ATOM_POSNS.get(name, unknown_atom) for name in atom_types]
        )

    if element_types is not None:
        tensors["element_types"] = _column(
            [pc.ELE2NUM.get(name, pc.ELE2NUM["other"]) for name in element_types]
        )

    if residue_types is not None:
        unknown_residue = max(pc.AA_TO_INDEX.values()) + 1
        tensors["residue_types"] = _column(
            [pc.AA_TO_INDEX.get(name, unknown_residue) for name in residue_types]
        )

    # Plain numeric arrays are converted directly, in a fixed key order.
    passthrough = (
        ("atom_coordinates", atom_coordinates),
        ("residue_coordinates", residue_coordinates),
        ("residue_ids", residue_ids),
    )
    for key, values in passthrough:
        if values is not None:
            tensors[key] = torch.tensor(values, dtype=dtype)

    if chain_ids is not None:
        chain_flags = torch.zeros(len(chain_ids), dtype=dtype)
        chain_flags[chain_ids == "L"] = 1
        tensors["chain_ids"] = chain_flags

    return tensors
class NodeRepresentation(Enum):
    """Names of the node granularities a structure graph can be built at.

    Defined locally; this rebinds the ``NodeRepresentation`` name imported
    earlier in the file from ``pinder.core.loader.geodata``.
    """

    # Member values are the lowercase string form of each member name.
    Surface = "surface"
    Atom = "atom"
    Residue = "residue"
class PairedPDB(HeteroData):  # type: ignore
    """Heterogeneous graph pairing an apo structure with its holo target.

    Two node types are populated, ``"receptor"`` and ``"ligand"``:

    * ``x``   -- encoded atom types of the apo structure,
    * ``pos`` -- atom coordinates of the apo structure,
    * ``y``   -- atom coordinates of the holo structure (the target),
    * ``edge_index`` -- k-NN edges over ``pos`` (only when requested and
      torch-cluster is available).
    """

    @classmethod
    def from_tuple_system(
        cls,
        tupal: tuple = (Structure, Structure, Structure),
        add_edges: bool = True,
        k: int = 10,
    ) -> PairedPDB:
        """Build the graph from a tuple of structures.

        Args:
            tupal: Tuple whose element 0 is the holo structure and element 1
                the apo structure; any further elements are ignored.
            add_edges: Whether to add k-NN edges.
            k: Number of neighbours per node for the k-NN edges.

        Returns:
            The populated graph.
        """
        return cls.from_structure_pair(
            holo=tupal[0],
            apo=tupal[1],
            add_edges=add_edges,
            k=k,
        )

    @staticmethod
    def _chain_props(structure, calpha, span):
        # Tensorise one chain's slice of the all-atom and C-alpha arrays.
        return structure2tensor(
            atom_coordinates=structure.coords[span],
            atom_types=structure.atom_array.atom_name[span],
            element_types=structure.atom_array.element[span],
            residue_coordinates=calpha.coords[span],
            residue_types=calpha.atom_array.res_name[span],
            residue_ids=calpha.atom_array.res_id[span],
        )

    @classmethod
    def from_structure_pair(
        cls,
        holo: Structure,
        apo: Structure,
        add_edges: bool = True,
        k: int = 10,
    ) -> PairedPDB:
        """Build the graph from a holo/apo pair of two-chain structures.

        Args:
            holo: Structure providing the target coordinates (``y``).
            apo: Structure providing the input features (``x``, ``pos``).
            add_edges: Whether to add k-NN edges over ``pos``.
            k: Number of neighbours per node for the k-NN edges.

        Returns:
            The populated graph.
        """
        graph = cls()
        holo_calpha = holo.filter("atom_name", mask=["CA"])
        apo_calpha = apo.filter("atom_name", mask=["CA"])

        # Number of atoms on chain "R"; used as the receptor/ligand split
        # point.  Assumes all chain "R" atoms precede the others -- TODO confirm.
        r_h = (holo.dataframe['chain_id'] == 'R').sum()
        r_a = (apo.dataframe['chain_id'] == 'R').sum()

        # NOTE(review): the C-alpha arrays are sliced with the all-atom count
        # as well, which looks unintended; the residue-level tensors are not
        # used below, so the graph itself is unaffected.
        holo_r_props = cls._chain_props(holo, holo_calpha, slice(None, r_h))
        holo_l_props = cls._chain_props(holo, holo_calpha, slice(r_h, None))
        apo_r_props = cls._chain_props(apo, apo_calpha, slice(None, r_a))
        apo_l_props = cls._chain_props(apo, apo_calpha, slice(r_a, None))

        # Inputs come from the apo structure ...
        graph["ligand"].x = apo_l_props["atom_types"]
        graph["ligand"].pos = apo_l_props["atom_coordinates"]
        graph["receptor"].x = apo_r_props["atom_types"]
        graph["receptor"].pos = apo_r_props["atom_coordinates"]
        # ... and targets from the holo structure.
        graph["ligand"].y = holo_l_props["atom_coordinates"]
        graph["receptor"].y = holo_r_props["atom_coordinates"]

        if add_edges and torch_cluster_installed:
            graph["ligand"].edge_index = knn_graph(graph["ligand"].pos, k=k)
            graph["receptor"].edge_index = knn_graph(graph["receptor"].pos, k=k)

        return graph
# create_graph builds the inference graph from the two input PDB files
# (ligand and receptor); no ground-truth structure is involved.
def create_graph(pdb1, pdb2, k=5):
    r"""
    Create a heterogeneous graph from two PDB files, with the ligand and
    receptor as separate node types.

    For each node type the atom coordinates are stored both as node features
    (``x``) and as positions (``pos``), and a k-NN graph over the positions is
    stored under the node type as well as under the
    ``(node_type, node_type)`` edge type.

    Args:
        pdb1 (str): PDB file path for ligand.
        pdb2 (str): PDB file path for receptor.
        k (int): Number of nearest neighbors for constructing the knn graph.

    Returns:
        HeteroData: A PyG HeteroData object containing ligand and receptor data.

    Raises:
        ImportError: If torch-cluster (which provides ``knn_graph``) is not
            installed.
    """
    if not torch_cluster_installed:
        raise ImportError("torch-cluster is required to build the k-NN graph")

    # Extract coordinates from PDB files
    positions = {
        "ligand": torch.tensor(extract_coordinates_from_pdb(pdb1), dtype=torch.float),
        "receptor": torch.tensor(extract_coordinates_from_pdb(pdb2), dtype=torch.float),
    }

    data = HeteroData()
    for node_type, pos in positions.items():
        # ``x`` is an independent copy; clone() avoids re-wrapping an existing
        # tensor with torch.tensor(), which emits a copy-construct warning.
        data[node_type].x = pos.clone()
        data[node_type].pos = pos

        # The same k-NN edge index is exposed on the node store and on the
        # self-relation edge store.
        edge_index = knn_graph(pos, k=k)
        data[node_type].edge_index = edge_index
        data[node_type, node_type].edge_index = edge_index

    return data
def update_pdb_coordinates_from_tensor(input_filename, output_filename, coordinates_tensor):
    r"""
    Write a copy of a PDB file whose atom coordinates are replaced by the
    values in a tensor.

    Atoms are visited model -> chain -> residue -> atom and receive the
    coordinate rows in that order. Only the coordinates are touched; every
    other atom field is left as parsed.

    Parameters:
    - input_filename (str): Path to the original PDB file.
    - output_filename (str): File name of the new PDB file; it is created
      under ``/tmp/``.
    - coordinates_tensor (torch.Tensor): Tensor of shape (1, N, 3) or (N, 3)
      with the new coordinates.

    Returns:
    - str: Path of the written PDB file.

    Raises:
    - ValueError: If the number of coordinate rows differs from the number of
      atoms in the input file.
    """
    # One [x, y, z] row per atom, whether or not a batch dimension is present.
    new_coordinates = coordinates_tensor.reshape(-1, 3).tolist()

    parser = PDB.PDBParser(QUIET=True)
    structure = parser.get_structure("structure", input_filename)

    atoms = [
        atom
        for model in structure
        for chain in model
        for residue in chain
        for atom in residue
    ]

    # zip() would silently stop at the shorter sequence and leave a partially
    # transformed structure behind, so a mismatch is reported instead.
    if len(atoms) != len(new_coordinates):
        raise ValueError(
            f"Expected {len(atoms)} coordinate rows for {input_filename}, "
            f"got {len(new_coordinates)}"
        )

    for atom, xyz in zip(atoms, new_coordinates):
        atom.coord = np.array(xyz)

    output_filename = "/tmp/" + output_filename

    # Save the updated structure to a new PDB file
    io = PDBIO()
    io.set_structure(structure)
    io.save(output_filename)

    return output_filename
def merge_pdb_files(file1, file2, output_file):
    r"""
    Merges two PDB files by concatenating them.

    The terminating ``END`` record of the first file (and any blank lines
    around it) is dropped so that the merged file carries only the second
    file's terminator. All other records of both files are copied unchanged;
    in particular, a first file that does not end with ``END`` keeps its
    last record.

    Parameters:
    - file1 (str): Path to the first PDB file (e.g., receptor).
    - file2 (str): Path to the second PDB file (e.g., ligand).
    - output_file (str): File name of the merged structure; it is created
      under ``/tmp/``.

    Returns:
    - str: Path of the merged file.
    """
    output_file = "/tmp/" + output_file

    # Read both inputs before opening the output, so a missing input does not
    # leave a truncated output file behind.
    with open(file1, 'r') as f1:
        lines = f1.readlines()
    with open(file2, 'r') as f2:
        second = f2.read()

    # Strip only trailing END / blank lines, never a real record.
    while lines and lines[-1].strip() in ("", "END"):
        lines.pop()
    # Keep the second file from being glued onto an unterminated last record.
    if lines and not lines[-1].endswith("\n"):
        lines[-1] += "\n"

    with open(output_file, 'w') as outfile:
        outfile.writelines(lines)
        outfile.write(second)

    print(f"Merged PDB saved to {output_file}")
    return output_file
class MPNNLayer(MessagePassing):
    def __init__(self, emb_dim=64, edge_dim=4, aggr='add'):
        r"""Message Passing Neural Network Layer

        Args:
            emb_dim: (int) - hidden dimension d
            edge_dim: (int) - edge feature dimension d_e
            aggr: (str) - aggregation function \oplus (sum/mean/max)
        """
        # Set the aggregation function
        super().__init__(aggr=aggr)

        self.emb_dim = emb_dim
        self.edge_dim = edge_dim

        # MLP \psi for computing messages m_ij
        # Implemented as a stack of Linear->BN->ReLU->Linear->BN->ReLU
        # dims: (2d + d_e) -> d
        self.mlp_msg = Sequential(
            Linear(2*emb_dim + edge_dim, emb_dim), BatchNorm1d(emb_dim), ReLU(),
            Linear(emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU()
        )

        # MLP \phi for computing updated node features h_i^{l+1}
        # Implemented as a stack of Linear->BN->ReLU->Linear->BN->ReLU
        # dims: 2d -> d
        self.mlp_upd = Sequential(
            Linear(2*emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU(),
            Linear(emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU()
        )

    def forward(self, h, edge_index, edge_attr):
        r"""
        The forward pass updates node features h via one round of message
        passing.

        ``propagate()`` of the PyG ``MessagePassing`` parent class drives the
        procedure: message() -> aggregate() -> update().

        Args:
            h: (n, d) - initial node features
            edge_index: (2, e) - pairs of edges
            edge_attr: (e, d_e) - edge features

        Returns:
            out: (n, d) - updated node features
        """
        out = self.propagate(edge_index, h=h, edge_attr=edge_attr)
        return out

    def message(self, h_i, h_j, edge_attr):
        r"""Step (1) Message

        Builds one message per edge. ``message()`` can take any argument that
        was passed to ``propagate()``; appending ``_i`` / ``_j`` to a node
        variable's name selects the destination / source node of each edge
        (the indexing is done by PyG).

        Args:
            h_i: (e, d) - destination node features
            h_j: (e, d) - source node features
            edge_attr: (e, d_e) - edge features

        Returns:
            msg: (e, d) - messages m_ij passed through MLP \psi
        """
        msg = torch.cat([h_i, h_j, edge_attr], dim=-1)
        return self.mlp_msg(msg)

    def aggregate(self, inputs, index, dim_size=None):
        r"""Step (2) Aggregate

        Reduces the messages per receiving node with the chosen aggregation
        function.

        Args:
            inputs: (e, d) - messages m_ij
            index: (e,) - receiving node of each message
            dim_size: (int, optional) - number of nodes n, supplied by
                ``propagate()``. Passing it on keeps the output at n rows even
                when the highest-indexed nodes receive no message; without it
                the output would be shorter than ``h`` and ``update()`` would
                fail to concatenate.

        Returns:
            aggr_out: (n, d) - aggregated messages m_i
        """
        return scatter(
            inputs, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr
        )

    def update(self, aggr_out, h):
        r"""
        Step (3) Update

        Combines the aggregated messages with the initial node features.
        ``update()`` receives the result of ``aggregate()`` first, followed by
        any argument that was passed to ``propagate()`` (here ``h``).

        Args:
            aggr_out: (n, d) - aggregated messages m_i
            h: (n, d) - initial node features

        Returns:
            upd_out: (n, d) - updated node features passed through MLP \phi
        """
        upd_out = torch.cat([h, aggr_out], dim=-1)
        return self.mlp_upd(upd_out)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})')
class MPNNModel(Module):
    def __init__(self, num_layers=4, emb_dim=64, in_dim=11, edge_dim=4, out_dim=1):
        r"""Graph-level prediction model built from stacked MPNN layers.

        Args:
            num_layers: (int) - number of message passing layers L
            emb_dim: (int) - hidden dimension d
            in_dim: (int) - initial node feature dimension d_n
            edge_dim: (int) - edge feature dimension d_e
            out_dim: (int) - output dimension
        """
        super().__init__()

        # Input projection, d_n -> d
        self.lin_in = Linear(in_dim, emb_dim)

        # L message-passing layers, all with sum aggregation
        self.convs = torch.nn.ModuleList(
            MPNNLayer(emb_dim, edge_dim, aggr='add') for _ in range(num_layers)
        )

        # Readout: mean over the nodes of each graph
        self.pool = global_mean_pool

        # Prediction head, d -> out_dim
        self.lin_pred = Linear(emb_dim, out_dim)

    def forward(self, data):
        r"""
        Args:
            data: (PyG.Data) - batch of PyG graphs

        Returns:
            out: flattened predictions, one value per graph when out_dim is 1
        """
        hidden = self.lin_in(data.x)  # (n, d_n) -> (n, d)

        for mp_layer in self.convs:
            # Residual connection around every message-passing layer
            message_update = mp_layer(hidden, data.edge_index, data.edge_attr)
            hidden = hidden + message_update

        graph_embedding = self.pool(hidden, data.batch)  # (n, d) -> (batch_size, d)
        prediction = self.lin_pred(graph_embedding)  # (batch_size, d) -> (batch_size, out_dim)
        return prediction.view(-1)
class EquivariantMPNNLayer(MessagePassing):
    def __init__(self, emb_dim=64, aggr='add'):
        r"""Message Passing Neural Network Layer

        Updates node features and node coordinates together. Messages depend
        on the coordinates only through pairwise distances, and coordinates
        are moved along relative position vectors.

        Args:
            emb_dim: (int) - hidden dimension d
            aggr: (str) - aggregation function \oplus (sum/mean/max) used for
                the feature messages
        """
        # Set the aggregation function
        super().__init__(aggr=aggr)

        self.emb_dim = emb_dim

        # MLP for messages m_ij, dims: (2d + 1) -> d (the +1 is the distance)
        self.mlp_msg = Sequential(
            Linear(2 * emb_dim + 1, emb_dim),
            BatchNorm1d(emb_dim),
            ReLU(),
            Linear(emb_dim, emb_dim),
            BatchNorm1d(emb_dim),
            ReLU()
        )

        # MLP producing one scalar weight per message, dims: d -> 1
        self.mlp_pos = Sequential(
            Linear(emb_dim, emb_dim),
            BatchNorm1d(emb_dim),
            ReLU(),
            Linear(emb_dim, 1)
        )

        # MLP for the node feature update, dims: 2d -> d
        self.mlp_upd = Sequential(
            Linear(2*emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU(),
            Linear(emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU()
        )

    def forward(self, h, pos, edge_index):
        r"""
        One round of message passing over features and coordinates.

        Args:
            h: (n, d) - initial node features
            pos: (n, 3) - initial node coordinates
            edge_index: (2, e) - pairs of edges

        Returns:
            out: [(n, d), (n, 3)] - updated node features and coordinates
        """
        out = self.propagate(edge_index=edge_index, h=h, pos=pos)
        return out

    def message(self, h_i, h_j, pos_i, pos_j):
        r"""Build the feature message and the weighted displacement per edge.

        Args:
            h_i: (e, d) - destination node features
            h_j: (e, d) - source node features
            pos_i: (e, 3) - destination node coordinates
            pos_j: (e, 3) - source node coordinates

        Returns:
            (msg, pos_diff): (e, d) feature messages and (e, 3) relative
            position vectors scaled by a learned per-edge scalar.
        """
        pos_diff = pos_i - pos_j
        dists = torch.norm(pos_diff, dim=-1).unsqueeze(1)  # (e, 1)

        # Concatenate both endpoints' features with the pairwise distance
        msg = torch.cat([h_i, h_j, dists], dim=-1)  # (e, 2d + 1)
        msg = self.mlp_msg(msg)  # (e, d)

        pos_diff = pos_diff * self.mlp_pos(msg)  # (e, 3)
        return msg, pos_diff

    def aggregate(self, inputs, index, dim_size=None):
        r"""Reduce feature messages and displacements per receiving node.

        Feature messages use the layer's aggregation function; displacements
        are always averaged.

        Args:
            inputs: tuple of (e, d) messages and (e, 3) displacements
            index: (e,) - receiving node of each message
            dim_size: (int, optional) - number of nodes n, supplied by
                ``propagate()``. Passing it on keeps both outputs at n rows
                even when the highest-indexed nodes receive no message;
                without it ``update()`` would hit a shape mismatch.

        Returns:
            (msg_aggr, pos_aggr): (n, d) and (n, 3) aggregated values
        """
        msgs, pos_diffs = inputs
        msg_aggr = scatter(
            msgs, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr
        )
        pos_aggr = scatter(
            pos_diffs, index, dim=self.node_dim, dim_size=dim_size, reduce="mean"
        )
        return msg_aggr, pos_aggr

    def update(self, aggr_out, h, pos):
        r"""Apply the aggregated messages to features and coordinates.

        Args:
            aggr_out: tuple returned by ``aggregate()``
            h: (n, d) - initial node features
            pos: (n, 3) - initial node coordinates

        Returns:
            (upd_out, upd_pos): (n, d) updated features and (n, 3) updated
            coordinates
        """
        msg_aggr, pos_aggr = aggr_out
        upd_out = self.mlp_upd(torch.cat((h, msg_aggr), dim=-1))
        upd_pos = pos + pos_aggr
        return upd_out, upd_pos

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})')
class FinalMPNNModel(MPNNModel):
    def __init__(self, num_layers=4, emb_dim=64, in_dim=3, num_heads = 2):
        r"""Rigid-transform predictor built on equivariant MPNN layers.

        The parent constructor runs with its defaults first; ``lin_in`` and
        ``convs`` are then replaced by the modules created here.

        Args:
            num_layers: (int) - number of message passing layers L
            emb_dim: (int) - hidden dimension d
            in_dim: (int) - initial node feature dimension d_n
            num_heads: (int) - number of heads of the cross-attention module
        """
        super().__init__()

        # Input projection, d_n -> d
        self.lin_in = Linear(in_dim, emb_dim)

        # Single extra layer; not used by naive_single()/forward() below
        self.equiv_layer = EquivariantMPNNLayer(emb_dim=emb_dim)

        # L equivariant message-passing layers
        self.convs = torch.nn.ModuleList(
            EquivariantMPNNLayer(emb_dim, aggr='add') for _ in range(num_layers)
        )

        # Receptor embeddings attend over ligand embeddings
        self.cross_attention = nn.MultiheadAttention(emb_dim, num_heads, batch_first=True)

        # Heads for the 3x3 matrix (9 values) and the 3-vector
        self.fc_rotation = nn.Linear(emb_dim, 9)
        self.fc_translation = nn.Linear(emb_dim, 3)

    def naive_single(self, receptor, ligand , receptor_edge_index , ligand_edge_index):
        r"""
        Processes a single receptor-ligand pair.

        Args:
            receptor: receptor coordinates, used both as input features and as
                initial positions
            ligand: ligand coordinates, used both as input features and as
                initial positions
            receptor_edge_index: edge index of the receptor graph
            ligand_edge_index: edge index of the ligand graph

        Returns:
            rotation_matrix: output of the 9-value head viewed as (-1, 3, 3)
            translation_vector: output of the 3-value head
        """
        receptor_feats = self.lin_in(receptor)
        ligand_feats = self.lin_in(ligand)
        receptor_pos = receptor
        ligand_pos = ligand

        # The same layer weights process the receptor and the ligand graph.
        for mp_layer in self.convs:
            receptor_feats, receptor_pos = mp_layer(
                receptor_feats, receptor_pos, receptor_edge_index
            )
            ligand_feats, ligand_pos = mp_layer(
                ligand_feats, ligand_pos, ligand_edge_index
            )

        # Query = receptor embeddings, key/value = ligand embeddings
        attended, _ = self.cross_attention(receptor_feats, ligand_feats, ligand_feats)

        # Average over the first dimension before the two linear heads
        pooled = attended.mean(dim=0)
        rotation_matrix = self.fc_rotation(pooled).view(-1, 3, 3)
        translation_vector = self.fc_translation(pooled)
        return rotation_matrix, translation_vector

    def forward(self, data):
        r"""
        The main forward pass of the model.

        Args:
            data: mapping with ``'receptor'`` and ``'ligand'`` entries, each
                providing ``'pos'`` and ``'edge_index'``

        Returns:
            The ligand positions multiplied by the predicted matrix and
            shifted by the predicted vector.
        """
        receptor_store = data['receptor']
        ligand_store = data['ligand']

        rotation_matrix, translation_vector = self.naive_single(
            receptor_store['pos'],
            ligand_store['pos'],
            receptor_store['edge_index'],
            ligand_store['edge_index'],
        )

        ligands = ligand_store['pos'] @ rotation_matrix + translation_vector
        return ligands
class FinalMPNNModelight(pl.LightningModule):
    """Lightning module predicting a 3x3 matrix and a 3-vector for the ligand.

    ``forward`` returns the pair ``(rotation_matrix, translation_vector)``;
    the training/validation/test steps apply that transform to the ligand
    positions and compare the result with ``batch['ligand']['y']``.
    """

    def __init__(self, num_layers=4, emb_dim=32, in_dim=3, num_heads=1, lr=1e-4):
        """
        Args:
            num_layers: Number of equivariant message-passing layers.
            emb_dim: Hidden dimension.
            in_dim: Input node feature dimension.
            num_heads: Number of heads of the cross-attention module.
            lr: Learning rate for Adam.
        """
        super().__init__()
        self.lin_in = nn.Linear(in_dim, emb_dim)
        self.convs = nn.ModuleList([EquivariantMPNNLayer(emb_dim, aggr='add') for _ in range(num_layers)])
        self.cross_attention = nn.MultiheadAttention(emb_dim, num_heads, batch_first=True)
        self.fc_rotation = nn.Linear(emb_dim, 9)
        self.fc_translation = nn.Linear(emb_dim, 3)
        self.lr = lr

    def naive_single(self, receptor, ligand, receptor_edge_index, ligand_edge_index):
        """Embed both graphs and predict the transform for one pair.

        Args:
            receptor: Receptor coordinates (features and initial positions).
            ligand: Ligand coordinates (features and initial positions).
            receptor_edge_index: Edge index of the receptor graph.
            ligand_edge_index: Edge index of the ligand graph.

        Returns:
            Tuple of the 9-value head output viewed as (-1, 3, 3) and the
            3-value head output.
        """
        h_receptor = self.lin_in(receptor)
        h_ligand = self.lin_in(ligand)
        pos_receptor, pos_ligand = receptor, ligand

        for layer in self.convs:
            h_receptor, pos_receptor = layer(h_receptor, pos_receptor, receptor_edge_index)
            h_ligand, pos_ligand = layer(h_ligand, pos_ligand, ligand_edge_index)

        # Receptor embeddings attend over ligand embeddings.
        attn_output, _ = self.cross_attention(h_receptor, h_ligand, h_ligand)
        pooled = attn_output.mean(dim=0)
        rotation_matrix = self.fc_rotation(pooled).view(-1, 3, 3)
        translation_vector = self.fc_translation(pooled)
        return rotation_matrix, translation_vector

    def forward(self, data):
        """Predict the transform for the graph(s) in ``data``.

        Args:
            data: Heterogeneous graph providing ``pos`` for ``'receptor'`` and
                ``'ligand'`` and ``edge_index`` for the
                ``('receptor', 'receptor')`` and ``('ligand', 'ligand')``
                edge types.

        Returns:
            Tuple ``(rotation_matrix, translation_vector)``.
        """
        # Move inputs to wherever this module's parameters live; picking
        # cuda:0 whenever CUDA exists breaks a module that was left on CPU.
        device = self.device
        receptor = data['receptor']['pos'].to(device)
        ligand = data['ligand']['pos'].to(device)
        receptor_edge_index = data['receptor', 'receptor']['edge_index'].to(device)
        ligand_edge_index = data['ligand', 'ligand']['edge_index'].to(device)

        rotation_matrix, translation_vector = self.naive_single(receptor, ligand, receptor_edge_index, ligand_edge_index)
        return rotation_matrix, translation_vector

    def _transformed_ligand(self, batch):
        # Apply the predicted transform to the batch's ligand positions.
        rotation_matrix, translation_vector = self(batch)
        ligand = batch['ligand']['pos'].to(rotation_matrix.device)
        return torch.matmul(ligand, rotation_matrix) + translation_vector

    def training_step(self, batch, batch_idx):
        """MSE between transformed ligand positions and the target positions."""
        # forward() returns a (matrix, vector) tuple, so the transform has to
        # be applied before a loss against coordinates can be computed.
        ligand_pred = self._transformed_ligand(batch)
        ligand_true = batch['ligand']['y']
        loss = F.mse_loss(ligand_pred.squeeze(0), ligand_true)
        self.log('train_loss', loss, batch_size=8)
        return loss

    def validation_step(self, batch, batch_idx):
        """L1 distance between transformed ligand positions and the target."""
        ligand_pred = self._transformed_ligand(batch)
        ligand_true = batch['ligand']['y']
        loss = F.l1_loss(ligand_pred.squeeze(0), ligand_true)
        self.log('val_loss', loss, prog_bar=True, batch_size=8)
        return loss

    def test_step(self, batch, batch_idx):
        """L1 distance between transformed ligand positions and the target."""
        ligand_pred = self._transformed_ligand(batch)
        ligand_true = batch['ligand']['y']
        loss = F.l1_loss(ligand_pred.squeeze(0), ligand_true)
        self.log('test_loss', loss, prog_bar=True, batch_size=8)
        return loss

    def configure_optimizers(self):
        """Adam with a plateau scheduler driven by ``val_loss``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.1, patience=5
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",  # Monitor validation loss to adjust the learning rate
            },
        }
# Restore the trained weights from the checkpoint shipped next to this file.
model_path = "./EquiMPNN-epoch=413-val_loss=9.25-val_acc=0.00.ckpt"
model = FinalMPNNModelight.load_from_checkpoint(model_path)

# Single-device trainer: GPU when CUDA is available, CPU otherwise.
trainer = pl.Trainer(
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
    precision="bf16-mixed",
    fast_dev_run=False,
)

# Inference only: switch the restored model to evaluation mode.
model.eval()
def predict(input_seq_1, input_msa_1, input_protein_1, input_seq_2, input_msa_2, input_protein_2):
    """Run the model on two uploaded structures and write the merged result.

    The first structure is treated as the ligand and the second as the
    receptor: the predicted transform is applied to the first structure's
    atoms, which are then merged with the unchanged second structure.

    Args:
        input_seq_1: Sequence of protein 1 (not used).
        input_msa_1: MSA of protein 1 (not used).
        input_protein_1: PDB file of protein 1 (ligand).
        input_seq_2: Sequence of protein 2 (not used).
        input_msa_2: MSA of protein 2 (not used).
        input_protein_2: PDB file of protein 2 (receptor).

    Returns:
        Tuple of the merged PDB path, a JSON string of metrics and the run
        time in seconds.

    Raises:
        gr.Error: If either structure file is missing.
    """
    start_time = time.time()

    if input_protein_1 is None or input_protein_2 is None:
        raise gr.Error("Please provide a PDB structure for both proteins.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    data = create_graph(input_protein_1, input_protein_2, k=10)

    with torch.no_grad():
        mat, vect = model(data)
        mat = mat.to(device)
        vect = vect.to(device)
        # The graph already holds the ligand coordinates read from
        # input_protein_1, so the file does not need to be parsed again.
        ligand1 = data["ligand"].pos.to(device)
        transformed_ligand = torch.matmul(ligand1, mat) + vect

    file1 = update_pdb_coordinates_from_tensor(input_protein_1, "holo_ligand.pdb", transformed_ligand)
    out_pdb = merge_pdb_files(file1, input_protein_2, "output.pdb")

    # return an output pdb file with the protein and two chains A and B.
    # also return a JSON with any metrics you want to report
    # NOTE: these values are fixed placeholders, not computed from the model.
    metrics = {"mean_plddt": 80, "binding_affinity": 2}

    end_time = time.time()
    run_time = end_time - start_time
    return out_pdb, json.dumps(metrics), run_time
# Gradio UI: two input columns (one per protein), a run button, an example
# row and the three outputs produced by predict().
with gr.Blocks() as app:
    gr.Markdown("# Template for inference")
    gr.Markdown("EquiMPNN MOdel")
    with gr.Row():
        with gr.Column():
            input_seq_1 = gr.Textbox(lines=3, label="Input Protein 1 sequence (FASTA)")
            input_msa_1 = gr.File(label="Input MSA Protein 1 (A3M)")
            # This is the protein-1 input; it was previously labelled "Protein 2".
            input_protein_1 = gr.File(label="Input Protein 1 monomer (PDB)")
        with gr.Column():
            input_seq_2 = gr.Textbox(lines=3, label="Input Protein 2 sequence (FASTA)")
            input_msa_2 = gr.File(label="Input MSA Protein 2 (A3M)")
            input_protein_2 = gr.File(label="Input Protein 2 structure (PDB)")

    # define any options here
    # for automated inference the default options are used
    # slider_option = gr.Slider(0,10, label="Slider Option")
    # checkbox_option = gr.Checkbox(label="Checkbox Option")
    # dropdown_option = gr.Dropdown(["Option 1", "Option 2", "Option 3"], label="Radio Option")

    btn = gr.Button("Run Inference")

    # Each example row supplies values for the four components listed below,
    # in the same order.
    gr.Examples(
        [
            [
                "GSGSPLAQQIKNIHSFIHQAKAAGRMDEVRTLQENLHQLMHEYFQQSD",
                "3v1c_A.pdb",
                "GSGSPLAQQIKNIHSFIHQAKAAGRMDEVRTLQENLHQLMHEYFQQSD",
                "3v1c_B.pdb",
            ],
        ],
        [input_seq_1, input_protein_1, input_seq_2, input_protein_2],
    )

    # Viewer representations: cartoon plus side-chain sticks for chains A and B.
    reps = [
        {
            "model": 0,
            "style": "cartoon",
            "chain": "A",
            "color": "whiteCarbon",
        },
        {
            "model": 0,
            "style": "cartoon",
            "chain": "B",
            "color": "greenCarbon",
        },
        {
            "model": 0,
            "chain": "A",
            "style": "stick",
            "sidechain": True,
            "color": "whiteCarbon",
        },
        {
            "model": 0,
            "chain": "B",
            "style": "stick",
            "sidechain": True,
            "color": "greenCarbon",
        },
    ]

    # outputs
    out = Molecule3D(reps=reps)
    metrics = gr.JSON(label="Metrics")
    run_time = gr.Textbox(label="Runtime")

    btn.click(
        predict,
        inputs=[input_seq_1, input_msa_1, input_protein_1, input_seq_2, input_msa_2, input_protein_2],
        outputs=[out, metrics, run_time],
    )

app.launch()