mgyigit commited on
Commit
e3d7930
·
1 Parent(s): 25a5f8a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -40
app.py CHANGED
@@ -71,7 +71,7 @@ model_configs = {
71
  def function(model_name: str, num_molecules: int, seed_num: int):
72
  '''
73
  Returns:
74
- image, score_df, file_path, and individual metrics
75
  '''
76
  if model_name == "DrugGEN-NoTarget":
77
  model_name = "NoTarget"
@@ -90,10 +90,9 @@ def function(model_name: str, num_molecules: int, seed_num: int):
90
  except ValueError:
91
  raise gr.Error("The seed must be an integer value!")
92
 
93
-
94
  inferer = Inference(config)
95
  start_time = time.time()
96
- scores = inferer.inference() # create scores_df out of this
97
  et = time.time() - start_time
98
 
99
  score_df = pd.DataFrame({
@@ -166,7 +165,7 @@ def function(model_name: str, num_molecules: int, seed_num: int):
166
  highlightBondLists=None,
167
  )
168
 
169
- return molecule_image, score_df, new_path, basic_metrics, advanced_metrics
170
 
171
 
172
 
@@ -283,12 +282,14 @@ For more details, see our [paper on arXiv](https://arxiv.org/abs/2302.07868).
283
  with gr.Row():
284
  with gr.Column():
285
  basic_metrics_df = gr.Dataframe(
 
286
  headers=["Validity", "Uniqueness", "Novelty (Train)", "Novelty (Test)", "Drug Novelty", "Runtime (s)"],
287
  elem_id="basic-metrics"
288
  )
289
 
290
  with gr.Column():
291
  advanced_metrics_df = gr.Dataframe(
 
292
  headers=["QED", "SA Score", "Internal Diversity", "SNN ChEMBL", "SNN Drug", "Max Length"],
293
  elem_id="advanced-metrics"
294
  )
@@ -301,26 +302,7 @@ For more details, see our [paper on arXiv](https://arxiv.org/abs/2302.07868).
301
  file_download = gr.File(
302
  label="Download All Generated Molecules (SMILES format)",
303
  )
304
-
305
- with gr.Group(elem_id="metrics-container"):
306
- gr.Markdown("### Performance Metrics")
307
-
308
- with gr.Row():
309
- with gr.Column():
310
- validity = gr.Number(label="Validity", precision=3)
311
- uniqueness = gr.Number(label="Uniqueness", precision=3)
312
- novelty_train = gr.Number(label="Novelty (Train)", precision=3)
313
- novelty_test = gr.Number(label="Novelty (Test)", precision=3)
314
- drug_novelty = gr.Number(label="Drug Novelty", precision=3)
315
- runtime = gr.Number(label="Runtime (seconds)", precision=2)
316
-
317
- with gr.Column():
318
- qed = gr.Number(label="QED Score", precision=3, info="Higher is more drug-like (0-1)")
319
- sa = gr.Number(label="SA Score", precision=3, info="Lower is easier to synthesize (1-10)")
320
- int_div = gr.Number(label="Internal Diversity", precision=3)
321
- snn_chembl = gr.Number(label="SNN ChEMBL", precision=3)
322
- snn_drug = gr.Number(label="SNN Drug", precision=3)
323
- max_len = gr.Number(label="Max Length", precision=3)
324
 
325
  gr.Markdown("### Created by the HUBioDataLab | [GitHub](https://github.com/HUBioDataLab/DrugGEN) | [Paper](https://arxiv.org/abs/2302.07868)")
326
 
@@ -329,24 +311,12 @@ For more details, see our [paper on arXiv](https://arxiv.org/abs/2302.07868).
329
  inputs=[model_name, num_molecules, seed_num],
330
  outputs=[
331
  image_output,
332
- scores_df,
333
  file_download,
334
- validity,
335
- uniqueness,
336
- novelty_train,
337
- novelty_test,
338
- drug_novelty,
339
- runtime,
340
- qed,
341
- sa,
342
- int_div,
343
- snn_chembl,
344
- snn_drug,
345
- max_len
346
  ],
347
  api_name="inference"
348
  )
349
-
350
  demo.queue()
351
- demo.launch()
352
-
 
71
  def function(model_name: str, num_molecules: int, seed_num: int):
72
  '''
73
  Returns:
74
+ image, metrics_df, file_path, basic_metrics, advanced_metrics
75
  '''
76
  if model_name == "DrugGEN-NoTarget":
77
  model_name = "NoTarget"
 
90
  except ValueError:
91
  raise gr.Error("The seed must be an integer value!")
92
 
 
93
  inferer = Inference(config)
94
  start_time = time.time()
95
+ scores = inferer.inference() # This returns a DataFrame with specific columns
96
  et = time.time() - start_time
97
 
98
  score_df = pd.DataFrame({
 
165
  highlightBondLists=None,
166
  )
167
 
168
+ return molecule_image, new_path, basic_metrics, advanced_metrics
169
 
170
 
171
 
 
282
  with gr.Row():
283
  with gr.Column():
284
  basic_metrics_df = gr.Dataframe(
285
+ label="Basic Metrics",
286
  headers=["Validity", "Uniqueness", "Novelty (Train)", "Novelty (Test)", "Drug Novelty", "Runtime (s)"],
287
  elem_id="basic-metrics"
288
  )
289
 
290
  with gr.Column():
291
  advanced_metrics_df = gr.Dataframe(
292
+ label="Advanced Metrics",
293
  headers=["QED", "SA Score", "Internal Diversity", "SNN ChEMBL", "SNN Drug", "Max Length"],
294
  elem_id="advanced-metrics"
295
  )
 
302
  file_download = gr.File(
303
  label="Download All Generated Molecules (SMILES format)",
304
  )
305
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
306
 
307
  gr.Markdown("### Created by the HUBioDataLab | [GitHub](https://github.com/HUBioDataLab/DrugGEN) | [Paper](https://arxiv.org/abs/2302.07868)")
308
 
 
311
  inputs=[model_name, num_molecules, seed_num],
312
  outputs=[
313
  image_output,
 
314
  file_download,
315
+ basic_metrics_df,
316
+ advanced_metrics_df
 
 
 
 
 
 
 
 
 
 
317
  ],
318
  api_name="inference"
319
  )
320
+ #demo.queue(concurrency_count=1)
321
  demo.queue()
322
+ demo.launch()