simonduerr commited on
Commit
2596438
·
verified ·
1 Parent(s): 77c7f7f

Create inference_app.py

Browse files
Files changed (1) hide show
  1. inference_app.py +220 -0
inference_app.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from pathlib import Path
3
+ import time
4
+ from biotite.application.autodock import VinaApp
5
+
6
+ import gradio as gr
7
+
8
+ from gradio_molecule3d import Molecule3D
9
+ from gradio_molecule2d import molecule2d
10
+ import numpy as np
11
+ from rdkit import Chem
12
+ from rdkit.Chem import AllChem
13
+ import pandas as pd
14
+ from biotite.structure import centroid, from_template
15
+ from biotite.structure.io import load_structure
16
+ from biotite.structure.io.mol import MOLFile, SDFile
17
+ from biotite.structure.io.pdb import PDBFile
18
+
19
+ from plinder.eval.docking.write_scores import evaluate
20
+
21
+
22
+ EVAL_METRICS = ["system", "LDDT-PLI", "LDDT-LP", "BISY-RMSD"]
23
+
24
+
25
+ def vina(
26
+ ligand, receptor, pocket_center, output_folder: Path, size=10, max_num_poses=5
27
+ ):
28
+ app = VinaApp(
29
+ ligand,
30
+ receptor,
31
+ center=pocket_center,
32
+ size=[size, size, size],
33
+ )
34
+ app.set_max_number_of_models(max_num_poses)
35
+ app.start()
36
+ app.join()
37
+ docked_ligand = from_template(ligand, app.get_ligand_coord())
38
+ docked_ligand = docked_ligand[..., ~np.isnan(docked_ligand.coord[0]).any(axis=-1)]
39
+ output_files = []
40
+ for i in range(max_num_poses):
41
+ sdf_file = MOLFile()
42
+ sdf_file.set_structure(docked_ligand[i])
43
+ output_file = output_folder / f"docked_ligand_{i}.sdf"
44
+ sdf_file.write(output_file)
45
+ output_files.append(output_file)
46
+ return output_files
47
+
48
+
49
+ def predict(
50
+ input_sequence: str,
51
+ input_ligand: str,
52
+ input_msa: gr.File | None = None,
53
+ input_protein: gr.File | None = None,
54
+ max_num_poses: int = 1,
55
+ ):
56
+ """
57
+ Main prediction function that calls ligsite and smina
58
+ Parameters
59
+ ----------
60
+ input_sequence: str
61
+ monomer sequence
62
+ input_ligand: str
63
+ ligand as SMILES string
64
+ input_msa: gradio.File | None
65
+ Gradio file object to MSA a3m file
66
+ input_protein: gradio.File | None
67
+ Gradio file object to monomer protein structure as CIF file
68
+ max_num_poses: int
69
+ Number of poses to generate
70
+ Returns
71
+ -------
72
+ output_structures: tuple
73
+ (output_protein, output_ligand_sdf)
74
+ run_time: float
75
+ run time of the program
76
+ """
77
+ start_time = time.time()
78
+
79
+ if input_protein is None:
80
+ raise gr.Error("need input_protein")
81
+ print(input_protein)
82
+ ligand_file = Path(input_protein).parent / "ligand.sdf"
83
+ print(ligand_file)
84
+ conformer = Chem.AddHs(Chem.MolFromSmiles(input_ligand))
85
+ AllChem.EmbedMolecule(conformer)
86
+ AllChem.MMFFOptimizeMolecule(conformer)
87
+ Chem.SDWriter(ligand_file).write(conformer)
88
+ ligand = SDFile.read(ligand_file).record.get_structure()
89
+ receptor = load_structure(input_protein, include_bonds=True)
90
+ docking_poses = vina(
91
+ ligand,
92
+ receptor,
93
+ centroid(receptor),
94
+ Path(input_protein).parent,
95
+ max_num_poses=max_num_poses,
96
+ )
97
+ end_time = time.time()
98
+ run_time = end_time - start_time
99
+ pdb_file = PDBFile()
100
+ pdb_file.set_structure(receptor)
101
+ output_pdb = Path(input_protein).parent / "receptor.pdb"
102
+ pdb_file.write(output_pdb)
103
+ return [str(output_pdb), str(docking_poses[0])], run_time
104
+
105
+
106
+ def get_metrics(
107
+ system_id: str,
108
+ receptor_file: Path,
109
+ ligand_file: Path,
110
+ flexible: bool = True,
111
+ posebusters: bool = True,
112
+ ) -> tuple[pd.DataFrame, float]:
113
+ start_time = time.time()
114
+ metrics = pd.DataFrame(
115
+ [
116
+ evaluate(
117
+ model_system_id=system_id,
118
+ reference_system_id=system_id,
119
+ receptor_file=receptor_file,
120
+ ligand_file_list=[Path(ligand_file)],
121
+ flexible=flexible,
122
+ posebusters=posebusters,
123
+ posebusters_full=False,
124
+ ).get("LIG_0", {})
125
+ ]
126
+ )
127
+ if posebusters:
128
+ metrics["posebusters"] = metrics[
129
+ [col for col in metrics.columns if col.startswith("posebusters_")]
130
+ ].sum(axis=1)
131
+ metrics["posebusters_valid"] = metrics[
132
+ [col for col in metrics.columns if col.startswith("posebusters_")]
133
+ ].sum(axis=1) == 20
134
+ columns = ["reference", "lddt_pli_ave", "lddt_lp_ave", "bisy_rmsd_ave"]
135
+ if flexible:
136
+ columns.extend(["lddt", "bb_lddt"])
137
+ if posebusters:
138
+ columns.extend([col for col in metrics.columns if col.startswith("posebusters")])
139
+
140
+ metrics = metrics[columns].copy()
141
+ mapping = {
142
+ "lddt_pli_ave": "LDDT-PLI",
143
+ "lddt_lp_ave": "LDDT-LP",
144
+ "bisy_rmsd_ave": "BISY-RMSD",
145
+ "reference": "system",
146
+ }
147
+ if flexible:
148
+ mapping["lddt"] = "LDDT"
149
+ mapping["bb_lddt"] = "Backbone LDDT"
150
+ if posebusters:
151
+ mapping["posebusters"] = "PoseBusters #checks"
152
+ mapping["posebusters_valid"] = "PoseBusters valid"
153
+ metrics.rename(
154
+ columns=mapping,
155
+ inplace=True,
156
+ )
157
+ end_time = time.time()
158
+ run_time = end_time - start_time
159
+ return metrics, run_time
160
+
161
+
162
+ with gr.Blocks() as app:
163
+ with gr.Tab("🧬 PINDER evaluation template"):
164
+ with gr.Row():
165
+ with gr.Column():
166
+ input_system_id_pinder = gr.Textbox(label="PINDER system ID")
167
+ input_receptor_file_pinder = gr.File(label="Receptor file")
168
+ input_ligand_file_pinder = gr.File(label="Ligand file")
169
+ methodname_pinder = gr.Textbox(label="Name of your method in the format mlsb/spacename")
170
+ store_pinder = gr.Checkbox(label="Store on huggingface for leaderboard", value=False)
171
+ eval_btn_pinder = gr.Button("Run Evaluation")
172
+
173
+
174
+
175
+
176
+ with gr.Tab("⚖️ PLINDER evaluation template"):
177
+ with gr.Row():
178
+ with gr.Column():
179
+ input_system_id = gr.Textbox(label="PLINDER system ID")
180
+ input_receptor_file = gr.File(label="Receptor file (CIF)")
181
+ input_ligand_file = gr.File(label="Ligand file (SDF)")
182
+ flexible = gr.Checkbox(label="Flexible docking", value=True)
183
+ posebusters = gr.Checkbox(label="PoseBusters", value=True)
184
+ methodname = gr.Textbox(label="Name of your method in the format mlsb/spacename")
185
+ store = gr.Checkbox(label="Store on huggingface for leaderboard", value=False)
186
+
187
+ eval_btn = gr.Button("Run Evaluation")
188
+ gr.Examples(
189
+ [
190
+ [
191
+ "4neh__1__1.B__1.H",
192
+ "input_protein_test.cif",
193
+ "input_ligand_test.sdf",
194
+ True,
195
+ True,
196
+ ],
197
+ ],
198
+ [input_system_id, input_receptor_file, input_ligand_file, flexible, posebusters, methodname, store],
199
+ )
200
+ eval_run_time = gr.Textbox(label="Evaluation runtime")
201
+ metric_table = gr.DataFrame(
202
+ pd.DataFrame([], columns=EVAL_METRICS), label="Evaluation metrics"
203
+ )
204
+
205
+ metric_table_pinder = gr.DataFrame(
206
+ pd.DataFrame([], columns=EVAL_METRICS_PINDER), label="Evaluation metrics"
207
+ )
208
+
209
+ eval_btn.click(
210
+ get_metrics,
211
+ inputs=[input_system_id, input_receptor_file, input_ligand_file, flexible, posebusters],
212
+ outputs=[metric_table, eval_run_time],
213
+ )
214
+ eval_btn_pinder.click(
215
+ get_metrics_pinder,
216
+ inputs=[input_system_id_pinder, input_receptor_file_pinder, input_ligand_file_pinder, methodname_pinder, store_pinder],
217
+ outputs=[metric_table_pinder, eval_run_time],
218
+ )
219
+
220
+ app.launch()