Rocco Meli commited on
Commit
7b26682
·
1 Parent(s): f16c560

add inference with gnina-torch

Browse files
Files changed (5) hide show
  1. app.py +88 -18
  2. html/ligand.html +1 -1
  3. html/pl.html +1 -1
  4. html/wrapper.html +10 -0
  5. requirements.txt +2 -1
app.py CHANGED
@@ -2,10 +2,12 @@ import gradio as gr
2
 
3
  import os
4
 
 
5
  def load_html(html_file: str):
6
  with open(os.path.join("html", html_file), "r") as f:
7
  return f.read()
8
 
 
9
  def load_protein_from_file(protein_file) -> str:
10
  """
11
  Parameters
@@ -21,21 +23,22 @@ def load_protein_from_file(protein_file) -> str:
21
  with open(protein_file.name, "r") as f:
22
  return f.read()
23
 
 
24
  def load_ligand_from_file(ligand_file):
25
  with open(ligand_file.name, "r") as f:
26
  return f.read()
27
-
 
28
  def protein_html_from_file(protein_file):
29
  protein = load_protein_from_file(protein_file)
30
  protein_html = load_html("protein.html")
31
 
32
  html = protein_html.replace("%%%PDB%%%", protein)
33
 
34
- return f"""<iframe style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera;
35
- display-capture; encrypted-media;" sandbox="allow-modals allow-forms
36
- allow-scripts allow-same-origin allow-popups
37
- allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
38
- allowpaymentrequest="" frameborder="0" srcdoc='{html}'></iframe>"""
39
 
40
  def ligand_html_from_file(ligand_file):
41
  ligand = load_ligand_from_file(ligand_file)
@@ -43,11 +46,10 @@ def ligand_html_from_file(ligand_file):
43
 
44
  html = ligand_html.replace("%%%SDF%%%", ligand)
45
 
46
- return f"""<iframe style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera;
47
- display-capture; encrypted-media;" sandbox="allow-modals allow-forms
48
- allow-scripts allow-same-origin allow-popups
49
- allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
50
- allowpaymentrequest="" frameborder="0" srcdoc='{html}'></iframe>"""
51
 
52
  def protein_ligand_html_from_file(protein_file, ligand_file):
53
  protein = load_protein_from_file(protein_file)
@@ -57,11 +59,70 @@ def protein_ligand_html_from_file(protein_file, ligand_file):
57
  html = protein_ligand_html.replace("%%%PDB%%%", protein)
58
  html = html.replace("%%%SDF%%%", ligand)
59
 
60
- return f"""<iframe style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera;
61
- display-capture; encrypted-media;" sandbox="allow-modals allow-forms
62
- allow-scripts allow-same-origin allow-popups
63
- allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
64
- allowpaymentrequest="" frameborder="0" srcdoc='{html}'></iframe>"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  demo = gr.Blocks()
67
 
@@ -82,12 +143,21 @@ with demo:
82
  ligand = gr.HTML()
83
  lbtn.click(fn=ligand_html_from_file, inputs=[lfile], outputs=ligand)
84
 
 
85
  with gr.Row():
86
- gr.Markdown("# Protein-Ligand Complex")
87
  plcomplex = gr.HTML()
88
 
89
  # TODO: Automatically display complex when both files are uploaded
90
  plbtn = gr.Button("View")
91
- plbtn.click(fn=protein_ligand_html_from_file, inputs=[pfile, lfile], outputs=plcomplex)
 
 
 
 
 
 
 
 
 
92
 
93
  demo.launch()
 
2
 
3
  import os
4
 
5
+
6
  def load_html(html_file: str):
7
  with open(os.path.join("html", html_file), "r") as f:
8
  return f.read()
9
 
10
+
11
  def load_protein_from_file(protein_file) -> str:
12
  """
13
  Parameters
 
23
  with open(protein_file.name, "r") as f:
24
  return f.read()
25
 
26
+
27
  def load_ligand_from_file(ligand_file):
28
  with open(ligand_file.name, "r") as f:
29
  return f.read()
30
+
31
+
32
  def protein_html_from_file(protein_file):
33
  protein = load_protein_from_file(protein_file)
34
  protein_html = load_html("protein.html")
35
 
36
  html = protein_html.replace("%%%PDB%%%", protein)
37
 
38
+ wrapper = load_html("wrapper.html")
39
+
40
+ return wrapper.replace("%%%HTML%%%", html)
41
+
 
42
 
43
  def ligand_html_from_file(ligand_file):
44
  ligand = load_ligand_from_file(ligand_file)
 
46
 
47
  html = ligand_html.replace("%%%SDF%%%", ligand)
48
 
49
+ wrapper = load_html("wrapper.html")
50
+
51
+ return wrapper.replace("%%%HTML%%%", html)
52
+
 
53
 
54
  def protein_ligand_html_from_file(protein_file, ligand_file):
55
  protein = load_protein_from_file(protein_file)
 
59
  html = protein_ligand_html.replace("%%%PDB%%%", protein)
60
  html = html.replace("%%%SDF%%%", ligand)
61
 
62
+ wrapper = load_html("wrapper.html")
63
+
64
+ return wrapper.replace("%%%HTML%%%", html)
65
+
66
+
67
+ def predict(protein_file, ligand_file, cnn="default"):
68
+ import molgrid
69
+ from gninatorch import gnina, dataloaders
70
+ import torch
71
+ import pandas as pd
72
+
73
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
74
+ print(device)
75
+
76
+ model, ensemble = gnina.setup_gnina_model(cnn, 23.5, 0.5)
77
+ model.eval()
78
+ model.to(device)
79
+
80
+ example_provider = molgrid.ExampleProvider(
81
+ data_root="",
82
+ balanced=False,
83
+ shuffle=False,
84
+ default_batch_size=1,
85
+ iteration_scheme=molgrid.IterationScheme.SmallEpoch,
86
+ )
87
+
88
+ with open("data.in", "w") as f:
89
+ f.write(protein_file.name)
90
+ f.write(" ")
91
+ f.write(ligand_file.name)
92
+
93
+ print("Populating example provider... ", end="")
94
+ example_provider.populate("data.in")
95
+ print("done")
96
+
97
+ grid_maker = molgrid.GridMaker(resolution=0.5, dimension=23.5)
98
+
99
+ # TODO: Allow average over different rotations
100
+ loader = dataloaders.GriddedExamplesLoader(
101
+ example_provider=example_provider,
102
+ grid_maker=grid_maker,
103
+ random_translation=0.0, # No random translations for inference
104
+ random_rotation=False, # No random rotations for inference
105
+ grids_only=True,
106
+ device=device,
107
+ )
108
+
109
+ print("Loading and gridding data... ", end="")
110
+ batch = next(loader)
111
+ print("done")
112
+
113
+ print("Predicting... ", end="")
114
+ with torch.no_grad():
115
+ log_pose, affinity, affinity_var = model(batch)
116
+ print("done")
117
+
118
+ return pd.DataFrame(
119
+ {
120
+ "CNNscore": [torch.exp(log_pose[:, -1]).item()],
121
+ "CNNaffinity": [affinity.item()],
122
+ "CNNvariance": [affinity_var.item()],
123
+ }
124
+ )
125
+
126
 
127
  demo = gr.Blocks()
128
 
 
143
  ligand = gr.HTML()
144
  lbtn.click(fn=ligand_html_from_file, inputs=[lfile], outputs=ligand)
145
 
146
+ gr.Markdown("# Protein-Ligand Complex")
147
  with gr.Row():
 
148
  plcomplex = gr.HTML()
149
 
150
  # TODO: Automatically display complex when both files are uploaded
151
  plbtn = gr.Button("View")
152
+ plbtn.click(
153
+ fn=protein_ligand_html_from_file, inputs=[pfile, lfile], outputs=plcomplex
154
+ )
155
+
156
+ gr.Markdown("# Gnina-Torch")
157
+ with gr.Row():
158
+ df = gr.Dataframe()
159
+ btn = gr.Button("Score!")
160
+ btn.click(fn=predict, inputs=[pfile, lfile], outputs=df)
161
+
162
 
163
  demo.launch()
html/ligand.html CHANGED
@@ -26,7 +26,7 @@
26
  let config = { backgroundColor: "white" };
27
  let viewer = $3Dmol.createViewer(element, config);
28
  viewer.addModel(sdf, "sdf");
29
- viewer.getModel(0).setStyle({}, { "stick": { "colorscheme": "lightgreyCarbon" } });
30
  viewer.zoomTo();
31
  viewer.render();
32
  viewer.zoom(0.8, 2000);
 
26
  let config = { backgroundColor: "white" };
27
  let viewer = $3Dmol.createViewer(element, config);
28
  viewer.addModel(sdf, "sdf");
29
+ viewer.getModel(0).setStyle({}, { "stick": { "colorscheme": "purpleCarbon" } });
30
  viewer.zoomTo();
31
  viewer.render();
32
  viewer.zoom(0.8, 2000);
html/pl.html CHANGED
@@ -29,7 +29,7 @@
29
  viewer.addModel(pdb, "pdb");
30
  viewer.addModel(sdf, "sdf");
31
  viewer.getModel(0).setStyle({}, { cartoon: { colorscheme: "whiteCarbon" } });
32
- viewer.getModel(1).setStyle({}, { stick: { colorscheme: "lightgreyCarbon" } });
33
  viewer.zoomTo();
34
  viewer.render();
35
  viewer.zoom(0.8, 2000);
 
29
  viewer.addModel(pdb, "pdb");
30
  viewer.addModel(sdf, "sdf");
31
  viewer.getModel(0).setStyle({}, { cartoon: { colorscheme: "whiteCarbon" } });
32
+ viewer.getModel(1).setStyle({}, { stick: { colorscheme: "purpleCarbon" } });
33
  viewer.zoomTo();
34
  viewer.render();
35
  viewer.zoom(0.8, 2000);
html/wrapper.html ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ <iframe style="width: 100%; height: 600px"
2
+ name="result"
3
+ allow="midi; geolocation; microphone; camera; display-capture; encrypted-media;"
4
+ sandbox="allow-modals allow-forms allow-scripts allow-same-origin allow-popups allow-top-navigation-by-user-activation allow-downloads"
5
+ allowfullscreen=""
6
+ allowpaymentrequest=""
7
+ frameborder="0"
8
+ srcdoc='%%%HTML%%%'
9
+ >
10
+ </iframe>
requirements.txt CHANGED
@@ -1 +1,2 @@
1
- gninatorch==0.0.2
 
 
1
+ gninatorch==0.0.2
2
+ pandas>=1.0