Spaces:
Sleeping
Sleeping
fix inference_app.py
Browse files- inference_app.py +21 -24
inference_app.py
CHANGED
@@ -104,28 +104,11 @@ def get_system(system_id: str) -> PinderSystem:
|
|
104 |
return PinderSystem(system_id)
|
105 |
from Bio import PDB
|
106 |
from Bio.PDB.PDBIO import PDBIO
|
|
|
|
|
|
|
107 |
|
108 |
-
def extract_coordinates_from_pdb(filename):
    r"""
    Read a PDB file and collect the position of every atom it contains.

    The file is parsed with a quiet ``PDB.PDBParser`` and the resulting
    structure is walked model -> chain -> residue -> atom, in that order.

    Args:
        filename: Path (or handle) of the PDB file, passed straight to
            ``PDBParser.get_structure``.

    Returns:
        A list with one ``(x, y, z)`` tuple per atom visited; the three
        values are the first three entries of that atom's ``coord`` array.
    """
    pdb_parser = PDB.PDBParser(QUIET=True)
    parsed = pdb_parser.get_structure("structure", filename)

    # Flatten the model/chain/residue/atom hierarchy in a single pass,
    # keeping the same visiting order as explicit nested loops would.
    return [
        (position[0], position[1], position[2])
        for model in parsed
        for chain in model
        for residue in chain
        for position in (atom.coord for atom in residue)
    ]
|
129 |
log = setup_logger(__name__)
|
130 |
|
131 |
try:
|
@@ -302,8 +285,8 @@ def create_graph(pdb1, pdb2, k=5):
|
|
302 |
HeteroData: A PyG HeteroData object containing ligand and receptor data.
|
303 |
"""
|
304 |
# Extract coordinates from PDB files
|
305 |
-
coords1 = torch.tensor(
|
306 |
-
coords2 = torch.tensor(
|
307 |
# coords3 = torch.tensor(extract_coordinates_from_pdb(pdb3),dtype=torch.float)
|
308 |
# Create the HeteroData object
|
309 |
data = HeteroData()
|
@@ -422,6 +405,7 @@ def merge_pdb_files(file1, file2, output_file):
|
|
422 |
|
423 |
print(f"Merged PDB saved to {output_file}")
|
424 |
return output_file
|
|
|
425 |
class MPNNLayer(MessagePassing):
|
426 |
def __init__(self, emb_dim=64, edge_dim=4, aggr='add'):
|
427 |
r"""Message Passing Neural Network Layer
|
@@ -892,21 +876,34 @@ def predict (input_seq_1, input_msa_1, input_protein_1, input_seq_2,input_msa_2,
|
|
892 |
start_time = time.time()
|
893 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
894 |
data = create_graph(input_protein_1, input_protein_2, k=10)
|
895 |
-
|
896 |
with torch.no_grad():
|
897 |
mat, vect = model(data)
|
898 |
mat = mat.to(device)
|
899 |
vect = vect.to(device)
|
900 |
-
ligand1 = torch.tensor(
|
901 |
# receptor1 = torch.tensor(extract_coordinates_from_pdb(input_protein_2),dtype=torch.float).to(device)
|
902 |
transformed_ligand = torch.matmul(ligand1, mat) + vect
|
903 |
# transformed_receptor = torch.matmul(receptor1, mat) + vect
|
904 |
file1 = update_pdb_coordinates_from_tensor(input_protein_1, "holo_ligand.pdb", transformed_ligand)
|
905 |
# file2 = update_pdb_coordinates_from_tensor(input_protein_2, "holo_receptor.pdb", transformed_receptor)
|
906 |
out_pdb = merge_pdb_files(file1,input_protein_2,"output.pdb")
|
|
|
907 |
# return an output pdb file with the protein and two chains A and B.
|
908 |
# also return a JSON with any metrics you want to report
|
909 |
metrics = {"mean_plddt": 80, "binding_affinity": 2}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
910 |
end_time = time.time()
|
911 |
run_time = end_time - start_time
|
912 |
|
|
|
104 |
return PinderSystem(system_id)
|
105 |
from Bio import PDB
|
106 |
from Bio.PDB.PDBIO import PDBIO
|
107 |
+
from pinder.core.structure.atoms import atom_array_from_pdb_file
|
108 |
+
from pathlib import Path
|
109 |
+
from pinder.eval.dockq.biotite_dockq import BiotiteDockQ
|
110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
|
|
|
112 |
log = setup_logger(__name__)
|
113 |
|
114 |
try:
|
|
|
285 |
HeteroData: A PyG HeteroData object containing ligand and receptor data.
|
286 |
"""
|
287 |
# Extract coordinates from PDB files
|
288 |
+
coords1 = torch.tensor(atom_array_from_pdb_file(pdb1),dtype=torch.float)
|
289 |
+
coords2 = torch.tensor(atom_array_from_pdb_file(pdb2),dtype=torch.float)
|
290 |
# coords3 = torch.tensor(extract_coordinates_from_pdb(pdb3),dtype=torch.float)
|
291 |
# Create the HeteroData object
|
292 |
data = HeteroData()
|
|
|
405 |
|
406 |
print(f"Merged PDB saved to {output_file}")
|
407 |
return output_file
|
408 |
+
|
409 |
class MPNNLayer(MessagePassing):
|
410 |
def __init__(self, emb_dim=64, edge_dim=4, aggr='add'):
|
411 |
r"""Message Passing Neural Network Layer
|
|
|
876 |
start_time = time.time()
|
877 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
878 |
data = create_graph(input_protein_1, input_protein_2, k=10)
|
879 |
+
R_chain, L_chain = ["R"], ["L"]
|
880 |
with torch.no_grad():
|
881 |
mat, vect = model(data)
|
882 |
mat = mat.to(device)
|
883 |
vect = vect.to(device)
|
884 |
+
ligand1 = torch.tensor(atom_array_from_pdb_file(input_protein_1),dtype=torch.float).to(device)
|
885 |
# receptor1 = torch.tensor(extract_coordinates_from_pdb(input_protein_2),dtype=torch.float).to(device)
|
886 |
transformed_ligand = torch.matmul(ligand1, mat) + vect
|
887 |
# transformed_receptor = torch.matmul(receptor1, mat) + vect
|
888 |
file1 = update_pdb_coordinates_from_tensor(input_protein_1, "holo_ligand.pdb", transformed_ligand)
|
889 |
# file2 = update_pdb_coordinates_from_tensor(input_protein_2, "holo_receptor.pdb", transformed_receptor)
|
890 |
out_pdb = merge_pdb_files(file1,input_protein_2,"output.pdb")
|
891 |
+
|
892 |
# return an output pdb file with the protein and two chains A and B.
|
893 |
# also return a JSON with any metrics you want to report
|
894 |
metrics = {"mean_plddt": 80, "binding_affinity": 2}
|
895 |
+
native = './test_out (1)'
|
896 |
+
decoys = out_pdb
|
897 |
+
bdq = BiotiteDockQ(
|
898 |
+
native=native, decoys=decoys,
|
899 |
+
# These are optional and if not specified will be assigned based on number of atoms (receptor > ligand)
|
900 |
+
native_receptor_chain=R_chain,
|
901 |
+
native_ligand_chain=L_chain,
|
902 |
+
decoy_receptor_chain=R_chain,
|
903 |
+
decoy_ligand_chain=L_chain,
|
904 |
+
)
|
905 |
+
dockq = bdq.calculate()
|
906 |
+
metrics['DockQ'] = dockq
|
907 |
end_time = time.time()
|
908 |
run_time = end_time - start_time
|
909 |
|