Spaces:
Sleeping
Sleeping
fix inference_app.py
Browse files- inference_app.py +17 -17
inference_app.py
CHANGED
@@ -106,7 +106,7 @@ from Bio import PDB
|
|
106 |
from Bio.PDB.PDBIO import PDBIO
|
107 |
|
108 |
def extract_coordinates_from_pdb(filename):
|
109 |
-
"""
|
110 |
Extracts atom coordinates from a PDB file and returns them as a list of tuples.
|
111 |
Each tuple contains (x, y, z) coordinates of an atom.
|
112 |
"""
|
@@ -288,7 +288,7 @@ class PairedPDB(HeteroData): # type: ignore
|
|
288 |
|
289 |
#create_graph takes inputs apo_ligand, apo_residue and paired holo as pdb3(ground truth).
|
290 |
def create_graph(pdb1, pdb2, pdb3='/home/sukanya/iitm_bisect_pinder_submission/test_out.pdb', k=5):
|
291 |
-
"""
|
292 |
Create a heterogeneous graph from two PDB files, with the ligand and receptor
|
293 |
as separate nodes, and their respective features and edges.
|
294 |
|
@@ -336,7 +336,7 @@ def create_graph(pdb1, pdb2, pdb3='/home/sukanya/iitm_bisect_pinder_submission/t
|
|
336 |
|
337 |
|
338 |
def update_pdb_coordinates_from_tensor(input_filename, output_filename, coordinates_tensor):
|
339 |
-
"""
|
340 |
Updates atom coordinates in a PDB file with new transformed coordinates provided in a tensor.
|
341 |
|
342 |
Parameters:
|
@@ -400,7 +400,7 @@ def update_pdb_coordinates_from_tensor(input_filename, output_filename, coordina
|
|
400 |
return output_filename
|
401 |
|
402 |
def merge_pdb_files(file1, file2, output_file):
|
403 |
-
"""
|
404 |
Merges two PDB files by concatenating them without altering their contents.
|
405 |
|
406 |
Parameters:
|
@@ -421,7 +421,7 @@ def merge_pdb_files(file1, file2, output_file):
|
|
421 |
|
422 |
class MPNNLayer(MessagePassing):
|
423 |
def __init__(self, emb_dim=64, edge_dim=4, aggr='add'):
|
424 |
-
"""Message Passing Neural Network Layer
|
425 |
|
426 |
Args:
|
427 |
emb_dim: (int) - hidden dimension d
|
@@ -451,7 +451,7 @@ class MPNNLayer(MessagePassing):
|
|
451 |
)
|
452 |
|
453 |
def forward(self, h, edge_index, edge_attr):
|
454 |
-
"""
|
455 |
The forward pass updates node features h via one round of message passing.
|
456 |
|
457 |
As our MPNNLayer class inherits from the PyG MessagePassing parent class,
|
@@ -474,7 +474,7 @@ class MPNNLayer(MessagePassing):
|
|
474 |
return out
|
475 |
|
476 |
def message(self, h_i, h_j, edge_attr):
|
477 |
-
"""Step (1) Message
|
478 |
|
479 |
The message() function constructs messages from source nodes j
|
480 |
to destination nodes i for each edge (i, j) in edge_index.
|
@@ -502,7 +502,7 @@ class MPNNLayer(MessagePassing):
|
|
502 |
return self.mlp_msg(msg)
|
503 |
|
504 |
def aggregate(self, inputs, index):
|
505 |
-
"""Step (2) Aggregate
|
506 |
|
507 |
The aggregate function aggregates the messages from neighboring nodes,
|
508 |
according to the chosen aggregation function ('sum' by default).
|
@@ -517,7 +517,7 @@ class MPNNLayer(MessagePassing):
|
|
517 |
return scatter(inputs, index, dim=self.node_dim, reduce=self.aggr)
|
518 |
|
519 |
def update(self, aggr_out, h):
|
520 |
-
"""
|
521 |
Step (3) Update
|
522 |
|
523 |
The update() function computes the final node features by combining the
|
@@ -541,7 +541,7 @@ class MPNNLayer(MessagePassing):
|
|
541 |
return (f'{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})')
|
542 |
class MPNNModel(Module):
|
543 |
def __init__(self, num_layers=4, emb_dim=64, in_dim=11, edge_dim=4, out_dim=1):
|
544 |
-
"""Message Passing Neural Network model for graph property prediction
|
545 |
|
546 |
Args:
|
547 |
num_layers: (int) - number of message passing layers L
|
@@ -570,7 +570,7 @@ class MPNNModel(Module):
|
|
570 |
self.lin_pred = Linear(emb_dim, out_dim)
|
571 |
|
572 |
def forward(self, data):
|
573 |
-
"""
|
574 |
Args:
|
575 |
data: (PyG.Data) - batch of PyG graphs
|
576 |
|
@@ -592,7 +592,7 @@ class MPNNModel(Module):
|
|
592 |
|
593 |
class EquivariantMPNNLayer(MessagePassing):
|
594 |
def __init__(self, emb_dim=64, aggr='add'):
|
595 |
-
"""Message Passing Neural Network Layer
|
596 |
|
597 |
This layer is equivariant to 3D rotations and translations.
|
598 |
|
@@ -630,7 +630,7 @@ class EquivariantMPNNLayer(MessagePassing):
|
|
630 |
# ===========================================
|
631 |
|
632 |
def forward(self, h, pos, edge_index):
|
633 |
-
"""
|
634 |
The forward pass updates node features h via one round of message passing.
|
635 |
|
636 |
Args:
|
@@ -667,7 +667,7 @@ class EquivariantMPNNLayer(MessagePassing):
|
|
667 |
# ...
|
668 |
#
|
669 |
def aggregate(self, inputs, index):
|
670 |
-
"""The aggregate function aggregates the messages from neighboring nodes,
|
671 |
according to the chosen aggregation function ('sum' by default).
|
672 |
|
673 |
Args:
|
@@ -701,7 +701,7 @@ class EquivariantMPNNLayer(MessagePassing):
|
|
701 |
|
702 |
class FinalMPNNModel(MPNNModel):
|
703 |
def __init__(self, num_layers=4, emb_dim=64, in_dim=3, num_heads = 2):
|
704 |
-
"""Message Passing Neural Network model for graph property prediction
|
705 |
|
706 |
This model uses both node features and coordinates as inputs, and
|
707 |
is invariant to 3D rotations and translations (the constituent MPNN layers
|
@@ -734,7 +734,7 @@ class FinalMPNNModel(MPNNModel):
|
|
734 |
# self.pool = global_mean_pool
|
735 |
|
736 |
def naive_single(self, receptor, ligand , receptor_edge_index , ligand_edge_index):
|
737 |
-
"""
|
738 |
Processes a single receptor-ligand pair.
|
739 |
|
740 |
Args:
|
@@ -775,7 +775,7 @@ class FinalMPNNModel(MPNNModel):
|
|
775 |
|
776 |
|
777 |
def forward(self, data):
|
778 |
-
"""
|
779 |
The main forward pass of the model.
|
780 |
|
781 |
Args:
|
|
|
106 |
from Bio.PDB.PDBIO import PDBIO
|
107 |
|
108 |
def extract_coordinates_from_pdb(filename):
|
109 |
+
r"""
|
110 |
Extracts atom coordinates from a PDB file and returns them as a list of tuples.
|
111 |
Each tuple contains (x, y, z) coordinates of an atom.
|
112 |
"""
|
|
|
288 |
|
289 |
#create_graph takes inputs apo_ligand, apo_residue and paired holo as pdb3(ground truth).
|
290 |
def create_graph(pdb1, pdb2, pdb3='/home/sukanya/iitm_bisect_pinder_submission/test_out.pdb', k=5):
|
291 |
+
r"""
|
292 |
Create a heterogeneous graph from two PDB files, with the ligand and receptor
|
293 |
as separate nodes, and their respective features and edges.
|
294 |
|
|
|
336 |
|
337 |
|
338 |
def update_pdb_coordinates_from_tensor(input_filename, output_filename, coordinates_tensor):
|
339 |
+
r"""
|
340 |
Updates atom coordinates in a PDB file with new transformed coordinates provided in a tensor.
|
341 |
|
342 |
Parameters:
|
|
|
400 |
return output_filename
|
401 |
|
402 |
def merge_pdb_files(file1, file2, output_file):
|
403 |
+
r"""
|
404 |
Merges two PDB files by concatenating them without altering their contents.
|
405 |
|
406 |
Parameters:
|
|
|
421 |
|
422 |
class MPNNLayer(MessagePassing):
|
423 |
def __init__(self, emb_dim=64, edge_dim=4, aggr='add'):
|
424 |
+
r"""Message Passing Neural Network Layer
|
425 |
|
426 |
Args:
|
427 |
emb_dim: (int) - hidden dimension d
|
|
|
451 |
)
|
452 |
|
453 |
def forward(self, h, edge_index, edge_attr):
|
454 |
+
r"""
|
455 |
The forward pass updates node features h via one round of message passing.
|
456 |
|
457 |
As our MPNNLayer class inherits from the PyG MessagePassing parent class,
|
|
|
474 |
return out
|
475 |
|
476 |
def message(self, h_i, h_j, edge_attr):
|
477 |
+
r"""Step (1) Message
|
478 |
|
479 |
The message() function constructs messages from source nodes j
|
480 |
to destination nodes i for each edge (i, j) in edge_index.
|
|
|
502 |
return self.mlp_msg(msg)
|
503 |
|
504 |
def aggregate(self, inputs, index):
|
505 |
+
r"""Step (2) Aggregate
|
506 |
|
507 |
The aggregate function aggregates the messages from neighboring nodes,
|
508 |
according to the chosen aggregation function ('sum' by default).
|
|
|
517 |
return scatter(inputs, index, dim=self.node_dim, reduce=self.aggr)
|
518 |
|
519 |
def update(self, aggr_out, h):
|
520 |
+
r"""
|
521 |
Step (3) Update
|
522 |
|
523 |
The update() function computes the final node features by combining the
|
|
|
541 |
return (f'{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})')
|
542 |
class MPNNModel(Module):
|
543 |
def __init__(self, num_layers=4, emb_dim=64, in_dim=11, edge_dim=4, out_dim=1):
|
544 |
+
r"""Message Passing Neural Network model for graph property prediction
|
545 |
|
546 |
Args:
|
547 |
num_layers: (int) - number of message passing layers L
|
|
|
570 |
self.lin_pred = Linear(emb_dim, out_dim)
|
571 |
|
572 |
def forward(self, data):
|
573 |
+
r"""
|
574 |
Args:
|
575 |
data: (PyG.Data) - batch of PyG graphs
|
576 |
|
|
|
592 |
|
593 |
class EquivariantMPNNLayer(MessagePassing):
|
594 |
def __init__(self, emb_dim=64, aggr='add'):
|
595 |
+
r"""Message Passing Neural Network Layer
|
596 |
|
597 |
This layer is equivariant to 3D rotations and translations.
|
598 |
|
|
|
630 |
# ===========================================
|
631 |
|
632 |
def forward(self, h, pos, edge_index):
|
633 |
+
r"""
|
634 |
The forward pass updates node features h via one round of message passing.
|
635 |
|
636 |
Args:
|
|
|
667 |
# ...
|
668 |
#
|
669 |
def aggregate(self, inputs, index):
|
670 |
+
r"""The aggregate function aggregates the messages from neighboring nodes,
|
671 |
according to the chosen aggregation function ('sum' by default).
|
672 |
|
673 |
Args:
|
|
|
701 |
|
702 |
class FinalMPNNModel(MPNNModel):
|
703 |
def __init__(self, num_layers=4, emb_dim=64, in_dim=3, num_heads = 2):
|
704 |
+
r"""Message Passing Neural Network model for graph property prediction
|
705 |
|
706 |
This model uses both node features and coordinates as inputs, and
|
707 |
is invariant to 3D rotations and translations (the constituent MPNN layers
|
|
|
734 |
# self.pool = global_mean_pool
|
735 |
|
736 |
def naive_single(self, receptor, ligand , receptor_edge_index , ligand_edge_index):
|
737 |
+
r"""
|
738 |
Processes a single receptor-ligand pair.
|
739 |
|
740 |
Args:
|
|
|
775 |
|
776 |
|
777 |
def forward(self, data):
|
778 |
+
r"""
|
779 |
The main forward pass of the model.
|
780 |
|
781 |
Args:
|