Eachan Johnson committed on
Commit
a11f59c
·
1 Parent(s): 7dbc2af

Update examples and add second plot

Browse files
app.py CHANGED
@@ -23,6 +23,7 @@ from schemist.tables import converter
23
  import torch
24
 
25
  CACHE = "./cache"
 
26
  HEADER_FILE = os.path.join("sources", "header.md")
27
  MODEL_REPOS = {
28
  "Klebsiella pneumoniae": "hf://scbirlab/spark-dv-fp-2503-kpn",
@@ -78,6 +79,10 @@ def convert_one(
78
  input_representation: str = 'smiles',
79
  output_representation: Union[Iterable[str], str] = 'smiles'
80
  ):
 
 
 
 
81
 
82
  df = pd.DataFrame({
83
  input_representation: _clean_split_input(strings),
@@ -91,23 +96,17 @@ def convert_one(
91
  )
92
 
93
 
94
- def predict_one(
95
- strings: str,
96
- input_representation: str = 'smiles',
97
  predict: Union[Iterable[str], str] = 'smiles',
98
  extra_metrics: Optional[Union[Iterable[str], str]] = None
99
- ):
 
 
100
  if extra_metrics is None:
101
  extra_metrics = []
102
  else:
103
  extra_metrics = cast(extra_metrics, to=list)
104
- prediction_df = convert_one(
105
- strings=strings,
106
- input_representation=input_representation,
107
- output_representation=['id', 'pubchem_name', 'pubchem_id', 'smiles', 'inchikey', "mwt", "clogp"],
108
- )
109
- species_to_predict = cast(predict, to=list)
110
- prediction_cols = []
111
  for species in species_to_predict:
112
  message = f"Predicting for species: {species}"
113
  print_err(message)
@@ -116,7 +115,7 @@ def predict_one(
116
  this_features = this_modelbox._input_cols
117
  this_labels = this_modelbox._label_cols
118
  this_prediction_input = (
119
- prediction_df
120
  .rename(columns={
121
  "smiles": this_features[0],
122
  })
@@ -132,10 +131,10 @@ def predict_one(
132
  ).with_format("numpy")["__prediction__"].flatten()
133
  print(prediction)
134
  this_col = f"{species}: predicted MIC (µM)"
135
- prediction_df[this_col] = np.power(10., -prediction) * 1e6
136
  prediction_cols.append(this_col)
137
  this_col = f"{species}: predicted MIC (µg / mL)"
138
- prediction_df[this_col] = np.power(10., -prediction) * 1e3 * prediction_df["mwt"]
139
  prediction_cols.append(this_col)
140
 
141
  for extra_metric in extra_metrics:
@@ -155,10 +154,33 @@ def predict_one(
155
  )
156
  .with_format("numpy")
157
  )
158
- prediction_df[this_col] = this_extra[this_extra.column_names[-1]]
159
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  return gr.DataFrame(
161
- prediction_df[['id', 'pubchem_name', 'pubchem_id'] + prediction_cols + ['smiles', 'inchikey', "mwt", "clogp"]],
 
 
 
 
162
  visible=True
163
  )
164
 
@@ -209,70 +231,38 @@ def predict_file(
209
  else:
210
  extra_metrics = cast(extra_metrics, to=list)
211
 
 
 
 
 
 
 
212
  prediction_df = convert_file(
213
  df,
214
  column=column,
215
  input_representation=input_representation,
216
  output_representation=["id", "smiles", "inchikey", "mwt", "clogp"],
217
  )
218
- if prediction_df.shape[0] > 1000:
219
- message = f"Truncating input to 1000 rows"
220
- print_err(message)
221
- gr.Info(message, duration=15)
222
- prediction_df = prediction_df.iloc[:1000]
223
- species_to_predict = cast(predict, to=list)
224
- prediction_cols = []
225
- for species in species_to_predict:
226
- message = f"Predicting for species: {species}"
227
- print_err(message)
228
- gr.Info(message, duration=3)
229
- this_modelbox = MODELBOXES[species]
230
- this_features = this_modelbox._input_cols
231
- this_labels = this_modelbox._label_cols
232
- this_prediction_input = (
233
- prediction_df
234
- .rename(columns={
235
- "smiles": this_features[0],
236
- })
237
- .assign(**{label: np.nan for label in this_labels})
238
- )
239
- print(this_prediction_input)
240
- prediction = this_modelbox.predict(
241
- data=this_prediction_input,
242
- features=this_features,
243
- labels=this_labels,
244
- aggregator="mean",
245
- cache=CACHE,
246
- ).with_format("numpy")["__prediction__"].flatten()
247
- print(prediction)
248
- this_col = f"{species}: predicted MIC (µM)"
249
- prediction_df[this_col] = np.power(10., -prediction) * 1e6
250
- prediction_cols.append(this_col)
251
- this_col = f"{species}: predicted MIC (µg / mL)"
252
- prediction_df[this_col] = np.power(10., -prediction) * 1e3 * prediction_df["mwt"]
253
- prediction_cols.append(this_col)
254
-
255
- for extra_metric in extra_metrics:
256
- message = f"Calculating {extra_metric} for species: {species}"
257
- print_err(message)
258
- gr.Info(message, duration=10)
259
- # this_modelbox._input_training_data = this_modelbox._input_training_data.remove_columns([this_modelbox._in_key])
260
- this_col = f"{species}: {extra_metric}"
261
- prediction_cols.append(this_col)
262
- print(">>>", this_modelbox._input_training_data)
263
- print(">>>", this_modelbox._input_training_data.format)
264
- print(">>>", this_modelbox._in_key, this_modelbox._out_key)
265
- this_extra = (
266
- EXTRA_METRICS[extra_metric](
267
- this_modelbox,
268
- this_prediction_input,
269
- )
270
- .with_format("numpy")
271
- )
272
- prediction_df[this_col] = this_extra[this_extra.column_names[-1]]
273
- other_cols = [col for col in prediction_df if col not in ['id', 'inchikey', 'smiles', "mwt", "clogp"] + [column] + prediction_cols]
274
-
275
- return prediction_df[['id', 'inchikey'] + [column] + prediction_cols + other_cols + ['smiles', "mwt", "clogp"]]
276
 
277
  def draw_one(
278
  strings: Union[Iterable[str], str],
@@ -293,31 +283,35 @@ def draw_one(
293
  legends=["\n".join(items) for items in zip(*_ids.values())],
294
  )
295
 
 
 
 
 
 
 
 
296
 
297
- def plot_pred_vs_observed(
 
298
  df,
299
- species: str,
300
- observed: str,
301
  color: Optional[str] = None,
302
  ):
303
  print_err(df.head())
304
- xcol = f"{species}: predicted MIC (µM)"
305
- ycol = observed
306
- y_title = f"Observed ({ycol})"
307
- cols = ["id", "inchikey", "smiles", "mwt", "clogp", xcol, ycol]
308
- color_title = color
309
  if color is not None and color not in cols:
310
  cols.append(color)
311
  cols = list(set(cols))
312
- print_err(df[cols].columns)
313
- if np.all(df[xcol] > 0):
314
- df[xcol] = np.log10(df[xcol])
315
- x_title = f"Predicted log10[MIC(µM)]"
316
 
317
  return gr.ScatterPlot(
318
  value=df[cols],
319
- x=xcol,
320
- y=ycol,
321
  color=color,
322
  x_title=x_title,
323
  y_title=y_title,
@@ -327,14 +321,32 @@ def plot_pred_vs_observed(
327
  )
328
 
329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
330
  def download_table(
331
  df: pd.DataFrame
332
  ) -> str:
333
  df_hash = nm.hash(pd.util.hash_pandas_object(df).values)
334
- filename = f"converted-{df_hash}.csv"
335
  df.to_csv(filename, index=False)
336
  return gr.DownloadButton(value=filename, visible=True)
337
 
 
338
  with gr.Blocks() as demo:
339
 
340
  with open(HEADER_FILE, 'r') as f:
@@ -379,7 +391,7 @@ with gr.Blocks() as demo:
379
  ]),
380
  list(MODEL_REPOS)[0],
381
  list(EXTRA_METRICS)[:2],
382
- ], # cipro, ceftriaxone, cefiderocol, linezolid, gepotidacin
383
  [
384
  '\n'.join([
385
  "C[C@H]1[C@H]([C@H](C[C@@H](O1)O[C@H]2C[C@@](CC3=C2C(=C4C(=C3O)C(=O)C5=C(C4=O)C(=CC=C5)OC)O)(C(=O)CO)O)N)O",
@@ -399,6 +411,7 @@ with gr.Blocks() as demo:
399
  "COC1=CC(=CC(=C1OC)OC)CC2=CN=C(N=C2N)N",
400
  "CC1=CC(=NO1)NS(=O)(=O)C2=CC=C(C=C2)N",
401
  "C1[C@@H]([C@H]([C@@H]([C@H]([C@@H]1NC(=O)[C@H](CCN)O)O[C@@H]2[C@@H]([C@H]([C@@H]([C@H](O2)CO)O)N)O)O)O[C@@H]3[C@@H]([C@H]([C@@H]([C@H](O3)CN)O)O)O)N\nC1=CN=CC=C1C(=O)NN",
 
402
  ]),
403
  list(MODEL_REPOS)[0],
404
  list(EXTRA_METRICS)[:2],
@@ -420,10 +433,37 @@ with gr.Blocks() as demo:
420
  "CC1=C(OC2=CC=CC=C12)CN(C)C(=O)/C=C/C3=CC4=C(NC(=O)CC4)N=C3",
421
  "CC1=C(OC2=CC=CC=C12)CN(C)C(=O)/C=C/C3=CC4=C(NC(=O)[C@@H](C4)N)N=C3",
422
  "CC1=C(OC2=CC=CC=C12)CN(C)C(=O)/C=C/C3=CC4=C(NC(=O)[C@H](CC4)[NH3+])N=C3.[Cl-]",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
423
  ]),
424
  list(MODEL_REPOS)[0],
425
  list(EXTRA_METRICS)[:2],
426
- ], # Debio1452, Debio-1452-NH3, Fabimycin,
427
 
428
  ],
429
  example_labels=[
@@ -431,8 +471,9 @@ with gr.Blocks() as demo:
431
  "Doxorubicin, Ampicillin, Amoxicillin, Meropenem, Tetracycline, Anhydrotetracycline",
432
  "Halicin, Abaucin, Trimethoprim, Sulfamethoxazole, Amikacin, Isoniazid",
433
  "Murepavadin, Vancomycin, Zosurabalpin, Plazomicin, Gentamicin, Rifampicin",
434
- "Debio-1452, Debio-1452-NH3, Fabimycin",
435
-
 
436
  ],
437
  inputs=[input_line, output_species_single, extra_metric],
438
  cache_mode="eager",
@@ -476,7 +517,7 @@ with gr.Blocks() as demo:
476
  outputs=download_single
477
  )
478
 
479
- with gr.Tab("Predict on structures from a file (max. 1000 rows, single species)"):
480
  input_file = gr.File(
481
  label="Upload a table of chemical compounds here",
482
  file_types=[".xlsx", ".csv", ".tsv", ".txt"],
@@ -524,14 +565,36 @@ with gr.Blocks() as demo:
524
  )
525
  with gr.Row():
526
  observed_col = gr.Dropdown(
527
- label="Observed column (y-axis) for comparison plot",
528
  choices=[],
529
  value=None,
530
  interactive=True,
531
  visible=False,
532
  )
533
  color_col = gr.Dropdown(
534
- label="Color for comparison plot",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
535
  choices=[],
536
  value=None,
537
  interactive=True,
@@ -544,38 +607,65 @@ with gr.Blocks() as demo:
544
  file_examples = gr.Examples(
545
  examples=[
546
  [
547
- "example-data/stokes2020-eco-1000.csv",
548
  "SMILES",
549
  "Klebsiella pneumoniae",
550
  "Mean_Inhibition",
551
  "Klebsiella pneumoniae: Doubtscore",
552
- list(EXTRA_METRICS)[:3]],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
553
  ],
554
  example_labels=[
555
- "Stokes J. et al., Cell, 2020",
 
 
556
  ],
557
  inputs=[input_file, input_column, output_species, observed_col, color_col, extra_metric_file],
558
  cache_mode="eager",
559
  )
560
- pred_vs_observed = gr.ScatterPlot(
561
- label="Prediction vs observed",
562
- x_title="Predicted MIC (µM)",
563
- y_title="Observed",
564
- visible=False,
565
- height=600,
566
- )
 
 
 
 
 
 
 
 
 
 
 
 
567
 
568
  file_examples.load_input_event.then(
569
- load_input_data,
570
- inputs=[input_file],
571
- outputs=[input_data, input_column],
572
  )
573
  input_file.upload(
574
- load_input_data,
575
- inputs=[input_file],
576
- outputs=[input_data, input_column]
577
  )
578
- go_button2.click(
579
  predict_file,
580
  inputs=[
581
  input_data,
@@ -591,18 +681,17 @@ with gr.Blocks() as demo:
591
  download_table,
592
  inputs=input_data,
593
  outputs=download
594
- ).then(
595
- partial(get_dropdown_options, _type="number"),
596
- inputs=[input_data],
597
- outputs=[observed_col],
598
- ).then(
599
- partial(get_dropdown_options, _type="number"),
600
- inputs=[input_data],
601
- outputs=[color_col],
602
  ).then(
603
  lambda: gr.Button(visible=True),
604
- outputs=[plot_button],
605
  )
 
 
 
 
 
 
 
606
 
607
  plot_button.click(
608
  plot_pred_vs_observed,
@@ -612,7 +701,16 @@ with gr.Blocks() as demo:
612
  observed_col,
613
  color_col,
614
  ],
615
- outputs=pred_vs_observed,
 
 
 
 
 
 
 
 
 
616
  )
617
 
618
  if __name__ == "__main__":
 
23
  import torch
24
 
25
  CACHE = "./cache"
26
+ MAX_ROWS = 4000
27
  HEADER_FILE = os.path.join("sources", "header.md")
28
  MODEL_REPOS = {
29
  "Klebsiella pneumoniae": "hf://scbirlab/spark-dv-fp-2503-kpn",
 
79
  input_representation: str = 'smiles',
80
  output_representation: Union[Iterable[str], str] = 'smiles'
81
  ):
82
+ output_representation = cast(output_representation, to=list)
83
+ for rep in output_representation:
84
+ message = f"Converting from {input_representation} to {rep}..."
85
+ gr.Info(message, duration=10)
86
 
87
  df = pd.DataFrame({
88
  input_representation: _clean_split_input(strings),
 
96
  )
97
 
98
 
99
+ def _prediction_loop(
100
+ df: pd.DataFrame,
 
101
  predict: Union[Iterable[str], str] = 'smiles',
102
  extra_metrics: Optional[Union[Iterable[str], str]] = None
103
+ ) -> pd.DataFrame:
104
+ species_to_predict = cast(predict, to=list)
105
+ prediction_cols = []
106
  if extra_metrics is None:
107
  extra_metrics = []
108
  else:
109
  extra_metrics = cast(extra_metrics, to=list)
 
 
 
 
 
 
 
110
  for species in species_to_predict:
111
  message = f"Predicting for species: {species}"
112
  print_err(message)
 
115
  this_features = this_modelbox._input_cols
116
  this_labels = this_modelbox._label_cols
117
  this_prediction_input = (
118
+ df
119
  .rename(columns={
120
  "smiles": this_features[0],
121
  })
 
131
  ).with_format("numpy")["__prediction__"].flatten()
132
  print(prediction)
133
  this_col = f"{species}: predicted MIC (µM)"
134
+ df[this_col] = np.power(10., -prediction) * 1e6
135
  prediction_cols.append(this_col)
136
  this_col = f"{species}: predicted MIC (µg / mL)"
137
+ df[this_col] = np.power(10., -prediction) * 1e3 * df["mwt"]
138
  prediction_cols.append(this_col)
139
 
140
  for extra_metric in extra_metrics:
 
154
  )
155
  .with_format("numpy")
156
  )
157
+ df[this_col] = this_extra[this_extra.column_names[-1]]
158
+
159
+ return prediction_cols, df
160
+
161
+
162
def predict_one(
    strings: str,
    input_representation: str = 'smiles',
    predict: Union[Iterable[str], str] = 'smiles',
    extra_metrics: Optional[Union[Iterable[str], str]] = None
):
    """Convert pasted structures, run the predictions, and return a results table.

    Args:
        strings: Raw text input, forwarded unchanged to ``convert_one``.
        input_representation: Name of the representation the input is
            written in; forwarded to ``convert_one``.
        predict: Species (one or several) forwarded to ``_prediction_loop``.
        extra_metrics: Optional extra metric names forwarded to
            ``_prediction_loop``.

    Returns:
        A visible ``gr.DataFrame`` whose columns are the identifier columns,
        then the prediction columns reported by ``_prediction_loop``, then
        the structure and property columns.
    """
    identifier_cols = ['id', 'pubchem_name', 'pubchem_id']
    structure_cols = ['smiles', 'inchikey', "mwt", "clogp"]

    # Request every column that will be displayed, in one conversion call.
    converted = convert_one(
        strings=strings,
        input_representation=input_representation,
        output_representation=identifier_cols + structure_cols,
    )
    predicted_cols, results = _prediction_loop(
        converted,
        predict=predict,
        extra_metrics=extra_metrics,
    )

    # Identifiers first, predictions in the middle, structure details last.
    display_order = identifier_cols + predicted_cols + structure_cols
    return gr.DataFrame(results[display_order], visible=True)
186
 
 
231
  else:
232
  extra_metrics = cast(extra_metrics, to=list)
233
 
234
+ if df.shape[0] > MAX_ROWS:
235
+ message = f"Truncating input to {MAX_ROWS} rows"
236
+ print_err(message)
237
+ gr.Info(message, duration=15)
238
+ df = df.iloc[:MAX_ROWS]
239
+
240
  prediction_df = convert_file(
241
  df,
242
  column=column,
243
  input_representation=input_representation,
244
  output_representation=["id", "smiles", "inchikey", "mwt", "clogp"],
245
  )
246
+ prediction_cols, prediction_df = _prediction_loop(
247
+ prediction_df,
248
+ predict=predict,
249
+ extra_metrics=extra_metrics,
250
+ )
251
+ main_cols = set(
252
+ ['id', 'inchikey', 'smiles', "mwt", "clogp"]
253
+ + [column]
254
+ + prediction_cols
255
+ )
256
+ other_cols = [
257
+ col for col in prediction_df
258
+ if col not in main_cols
259
+ ]
260
+ return prediction_df[
261
+ ['id', 'inchikey']
262
+ + [column]
263
+ + prediction_cols + other_cols
264
+ + ['smiles', "mwt", "clogp"]
265
+ ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
 
267
  def draw_one(
268
  strings: Union[Iterable[str], str],
 
283
  legends=["\n".join(items) for items in zip(*_ids.values())],
284
  )
285
 
286
def log10_if_all_positive(df, col):
    """Log10-transform a column when every value in it is strictly positive.

    Args:
        df: Table holding the column (read with ``df[col]``).
        col: Column name, or ``None`` when the caller has no column selected
            (for example an optional color column left unset).

    Returns:
        A ``(title, df)`` tuple. ``title`` is ``"log10[ <col> ]"`` when the
        transform was applied, and ``col`` itself otherwise. ``df`` is a copy
        of the input with the transformed column when the transform was
        applied, and the input table itself otherwise.
    """
    # An unset optional column has nothing to look up; pass it through
    # instead of failing on ``df[None]``.
    if col is None:
        return col, df

    if np.all(df[col] > 0.):
        # Work on a copy so the caller's table is never modified in place;
        # otherwise a second call on the same table would take the log of
        # already log-transformed values.
        transformed = df.copy()
        transformed[col] = np.log10(df[col])
        return f"log10[ {col} ]", transformed

    return col, df
293
 
294
+
295
+ def plot_x_vs_y(
296
  df,
297
+ x: str,
298
+ y: str,
299
  color: Optional[str] = None,
300
  ):
301
  print_err(df.head())
302
+ y_title = y
303
+ cols = ["id", "inchikey", "smiles", "mwt", "clogp", x, y]
 
 
 
304
  if color is not None and color not in cols:
305
  cols.append(color)
306
  cols = list(set(cols))
307
+ x_title, df = log10_if_all_positive(df, x)
308
+ y_title, df = log10_if_all_positive(df, y)
309
+ color_title, df = log10_if_all_positive(df, color)
 
310
 
311
  return gr.ScatterPlot(
312
  value=df[cols],
313
+ x=x,
314
+ y=y,
315
  color=color,
316
  x_title=x_title,
317
  y_title=y_title,
 
321
  )
322
 
323
 
324
def plot_pred_vs_observed(
    df,
    species: str,
    observed: str,
    color: Optional[str] = None,
):
    """Plot a species' predicted MIC (µM) column against an observed column.

    Args:
        df: Table containing the prediction column and the observed column.
        species: Species name used to build the x-axis column name
            ``"<species>: predicted MIC (µM)"``.
        observed: Name of the column plotted on the y-axis.
        color: Optional column name forwarded as the color column.

    Returns:
        Whatever ``plot_x_vs_y`` returns for these columns.
    """
    print_err(df.head())
    predicted_col = f"{species}: predicted MIC (µM)"
    return plot_x_vs_y(df, x=predicted_col, y=observed, color=color)
339
+
340
+
341
def download_table(
    df: pd.DataFrame
) -> "gr.DownloadButton":
    """Write the table to a CSV file and return a download button for it.

    The file is named ``predicted-<hash>.csv``, where ``<hash>`` is the value
    ``nm.hash`` returns for the per-row hashes of ``df``. The file is written
    to the current working directory without the index column.

    Args:
        df: Table to export.

    Returns:
        A visible ``gr.DownloadButton`` whose value is the CSV file name.
        (The previous ``-> str`` annotation did not match this return value.)
    """
    # NOTE(review): presumably identical tables hash to the same file name,
    # so the file is overwritten rather than duplicated; confirm against
    # the behaviour of nm.hash.
    df_hash = nm.hash(pd.util.hash_pandas_object(df).values)
    filename = f"predicted-{df_hash}.csv"
    df.to_csv(filename, index=False)
    return gr.DownloadButton(value=filename, visible=True)
348
 
349
+
350
  with gr.Blocks() as demo:
351
 
352
  with open(HEADER_FILE, 'r') as f:
 
391
  ]),
392
  list(MODEL_REPOS)[0],
393
  list(EXTRA_METRICS)[:2],
394
+ ], # cipro, ceftriaxone, cefiderocol, linezolid, gepotidacin
395
  [
396
  '\n'.join([
397
  "C[C@H]1[C@H]([C@H](C[C@@H](O1)O[C@H]2C[C@@](CC3=C2C(=C4C(=C3O)C(=O)C5=C(C4=O)C(=CC=C5)OC)O)(C(=O)CO)O)N)O",
 
411
  "COC1=CC(=CC(=C1OC)OC)CC2=CN=C(N=C2N)N",
412
  "CC1=CC(=NO1)NS(=O)(=O)C2=CC=C(C=C2)N",
413
  "C1[C@@H]([C@H]([C@@H]([C@H]([C@@H]1NC(=O)[C@H](CCN)O)O[C@@H]2[C@@H]([C@H]([C@@H]([C@H](O2)CO)O)N)O)O)O[C@@H]3[C@@H]([C@H]([C@@H]([C@H](O3)CN)O)O)O)N\nC1=CN=CC=C1C(=O)NN",
414
+ "C1=CN=CC=C1C(=O)NN ",
415
  ]),
416
  list(MODEL_REPOS)[0],
417
  list(EXTRA_METRICS)[:2],
 
433
  "CC1=C(OC2=CC=CC=C12)CN(C)C(=O)/C=C/C3=CC4=C(NC(=O)CC4)N=C3",
434
  "CC1=C(OC2=CC=CC=C12)CN(C)C(=O)/C=C/C3=CC4=C(NC(=O)[C@@H](C4)N)N=C3",
435
  "CC1=C(OC2=CC=CC=C12)CN(C)C(=O)/C=C/C3=CC4=C(NC(=O)[C@H](CC4)[NH3+])N=C3.[Cl-]",
436
+ "C1=C(C(=O)NC(=O)N1)F",
437
+ "CCCCCCNC(=O)N1C=C(C(=O)NC1=O)F",
438
+ "C[C@@H]1OC[C@@H]2[C@@H](O1)[C@@H]([C@H]([C@@H](O2)O[C@H]3[C@H]4COC(=O)[C@@H]4[C@@H](C5=CC6=C(C=C35)OCO6)C7=CC(=C(C(=C7)OC)O)OC)O)O",
439
+ ]),
440
+ list(MODEL_REPOS)[0],
441
+ list(EXTRA_METRICS)[:2],
442
+ ], # Debio1452, Debio-1452-NH3, Fabimycin, 5-FU, Carmofur, Etoposide
443
+ [
444
+ '\n'.join([
445
+ "COC1=CC(=CC(=C1OC)OC)CC2=CN=C(N=C2N)N",
446
+ "CC(C)C1=CC=C(C=C1)CN2C=CC3=C2C=CC4=C3C(=NC(=N4)NC5CC5)N",
447
+ "C1=CC(=CC=C1CCC2=CNC3=C2C(=O)NC(=N3)N)C(=O)N[C@@H](CCC(=O)O)C(=O)O",
448
+ "CC1=C(C2=C(C=C1)N=C(NC2=O)N)SC3=CC=NC=C3",
449
+ "CN(CC1=CN=C2C(=N1)C(=NC(=N2)N)N)C3=CC=C(C=C3)C(=O)N[C@@H](CCC(=O)O)C(=O)O",
450
+ "CC1=NC2=C(C=C(C=C2)CN(C)C3=CC=C(S3)C(=O)N[C@@H](CCC(=O)O)C(=O)O)C(=O)N1",
451
+ ]),
452
+ list(MODEL_REPOS)[0],
453
+ list(EXTRA_METRICS)[:2],
454
+ ], # Trimethoprim, SCH79797, Pemetrexed, Nolatrexed, Methotrexate, Raltitrexed
455
+ [
456
+ '\n'.join([
457
+ "C[C@H]([C@@H](C(=O)NO)NC(=O)C1=CC=C(C=C1)C#CC2=CC=C(C=C2)CN3CCOCC3)O",
458
+ "CC(C)C1=CC=C(C=C1)CN2C=CC3=C2C=CC4=C3C(=NC(=N4)NC5CC5)N",
459
+ "C1=CC=C(C=C1)CNC2=NC(=NC3=CC=CC=C32)NCC4=CC=CC=C4",
460
+ "CC(C)(C)C1=CC=C(C=C1)C(=O)NC(=S)NC2=CC=C(C=C2)NC(=O)CCCCN(C)C",
461
+ "CCC1=C(C(=NC(=N1)N)N)C2=CC=C(C=C2)Cl",
462
+ "C1=CC(=CC=C1C(=O)N[C@@H](CCC(=O)O)C(=O)O)NCC2=CN=C3C(=N2)C(=NC(=N3)N)N",
463
  ]),
464
  list(MODEL_REPOS)[0],
465
  list(EXTRA_METRICS)[:2],
466
+ ], # CHIR-090, SCH79797, DBeQ, Tenovin-6, Pyrimethamine, Aminopterin
467
 
468
  ],
469
  example_labels=[
 
471
  "Doxorubicin, Ampicillin, Amoxicillin, Meropenem, Tetracycline, Anhydrotetracycline",
472
  "Halicin, Abaucin, Trimethoprim, Sulfamethoxazole, Amikacin, Isoniazid",
473
  "Murepavadin, Vancomycin, Zosurabalpin, Plazomicin, Gentamicin, Rifampicin",
474
+ "Debio-1452, Debio-1452-NH3, Fabimycin, 5-FU, Carmofur, Etoposide",
475
+ "Trimethoprim, Pemetrexed, Nolatrexed, Methotrexate, Raltitrexed",
476
+ "CHIR-090, SCH79797, DBeQ, Tenovin-6, Pyrimethamine, Aminopterin"
477
  ],
478
  inputs=[input_line, output_species_single, extra_metric],
479
  cache_mode="eager",
 
517
  outputs=download_single
518
  )
519
 
520
+ with gr.Tab(f"Predict on structures from a file (max. {MAX_ROWS} rows, single species)"):
521
  input_file = gr.File(
522
  label="Upload a table of chemical compounds here",
523
  file_types=[".xlsx", ".csv", ".tsv", ".txt"],
 
565
  )
566
  with gr.Row():
567
  observed_col = gr.Dropdown(
568
+ label="Observed column (y-axis) for left plot",
569
  choices=[],
570
  value=None,
571
  interactive=True,
572
  visible=False,
573
  )
574
  color_col = gr.Dropdown(
575
+ label="Color for left plot",
576
+ choices=[],
577
+ value=None,
578
+ interactive=True,
579
+ visible=False,
580
+ )
581
+
582
+ any_x_col = gr.Dropdown(
583
+ label="x-axis for right plot",
584
+ choices=[],
585
+ value=None,
586
+ interactive=True,
587
+ visible=False,
588
+ )
589
+ any_y_col = gr.Dropdown(
590
+ label="y-axis for right plot",
591
+ choices=[],
592
+ value=None,
593
+ interactive=True,
594
+ visible=False,
595
+ )
596
+ any_color_col = gr.Dropdown(
597
+ label="Color for right plot",
598
  choices=[],
599
  value=None,
600
  interactive=True,
 
607
  file_examples = gr.Examples(
608
  examples=[
609
  [
610
+ "example-data/stokes2020-eco.csv",
611
  "SMILES",
612
  "Klebsiella pneumoniae",
613
  "Mean_Inhibition",
614
  "Klebsiella pneumoniae: Doubtscore",
615
+ list(EXTRA_METRICS)[:3],
616
+ ],
617
+ [
618
+ "example-data/liu23-abau.csv",
619
+ "SMILES",
620
+ "Klebsiella pneumoniae",
621
+ "Mean",
622
+ "Klebsiella pneumoniae: Doubtscore",
623
+ list(EXTRA_METRICS)[:3],
624
+ ],
625
+ [
626
+ "example-data/wong24-sau-tox-5000.csv",
627
+ "SMILES",
628
+ "Klebsiella pneumoniae",
629
+ "Mean",
630
+ "Klebsiella pneumoniae: Doubtscore",
631
+ list(EXTRA_METRICS)[:3],
632
+ ],
633
  ],
634
  example_labels=[
635
+ "E. coli training data from Stokes J. et al., Cell, 2020",
636
+ "A. baumannii training data from Liu, 2023",
637
+ "S. aureus and toxicity training data from Wong, 2024",
638
  ],
639
  inputs=[input_file, input_column, output_species, observed_col, color_col, extra_metric_file],
640
  cache_mode="eager",
641
  )
642
+ with gr.Row():
643
+ pred_vs_observed = gr.ScatterPlot(
644
+ label="Prediction vs observed",
645
+ x_title="Predicted MIC (µM)",
646
+ y_title="Observed",
647
+ visible=False,
648
+ height=600,
649
+ )
650
+ plot_any_vs_any = gr.ScatterPlot(
651
+ label="Any vs any",
652
+ visible=False,
653
+ height=600,
654
+ )
655
+
656
+ load_data_action = {
657
+ "fn": load_input_data,
658
+ "inputs": [input_file],
659
+ "outputs": [input_data, input_column]
660
+ }
661
 
662
  file_examples.load_input_event.then(
663
+ **load_data_action,
 
 
664
  )
665
  input_file.upload(
666
+ **load_data_action,
 
 
667
  )
668
+ go2_click_event = go_button2.click(
669
  predict_file,
670
  inputs=[
671
  input_data,
 
681
  download_table,
682
  inputs=input_data,
683
  outputs=download
 
 
 
 
 
 
 
 
684
  ).then(
685
  lambda: gr.Button(visible=True),
686
+ outputs=[plot_button]
687
  )
688
+
689
+ for dropdown in [observed_col, color_col, any_color_col, any_x_col, any_y_col]:
690
+ go2_click_event.then(
691
+ partial(get_dropdown_options, _type="number"),
692
+ inputs=[input_data],
693
+ outputs=[dropdown],
694
+ )
695
 
696
  plot_button.click(
697
  plot_pred_vs_observed,
 
701
  observed_col,
702
  color_col,
703
  ],
704
+ outputs=[pred_vs_observed],
705
+ ).then(
706
+ plot_x_vs_y,
707
+ inputs=[
708
+ input_data,
709
+ any_x_col,
710
+ any_y_col,
711
+ any_color_col,
712
+ ],
713
+ outputs=[plot_any_vs_any],
714
  )
715
 
716
  if __name__ == "__main__":
example-data/liu23-abau.csv ADDED
The diff for this file is too large to render. See raw diff
 
example-data/{stokes2020-eco-1000.csv → stokes2020-eco.csv} RENAMED
The diff for this file is too large to render. See raw diff
 
example-data/wong24-sau-tox-5000.csv ADDED
The diff for this file is too large to render. See raw diff