Spaces:
Running
Running
Eachan Johnson
commited on
Commit
·
a11f59c
1
Parent(s):
7dbc2af
Update examples and add second plot
Browse files
app.py
CHANGED
@@ -23,6 +23,7 @@ from schemist.tables import converter
|
|
23 |
import torch
|
24 |
|
25 |
CACHE = "./cache"
|
|
|
26 |
HEADER_FILE = os.path.join("sources", "header.md")
|
27 |
MODEL_REPOS = {
|
28 |
"Klebsiella pneumoniae": "hf://scbirlab/spark-dv-fp-2503-kpn",
|
@@ -78,6 +79,10 @@ def convert_one(
|
|
78 |
input_representation: str = 'smiles',
|
79 |
output_representation: Union[Iterable[str], str] = 'smiles'
|
80 |
):
|
|
|
|
|
|
|
|
|
81 |
|
82 |
df = pd.DataFrame({
|
83 |
input_representation: _clean_split_input(strings),
|
@@ -91,23 +96,17 @@ def convert_one(
|
|
91 |
)
|
92 |
|
93 |
|
94 |
-
def
|
95 |
-
|
96 |
-
input_representation: str = 'smiles',
|
97 |
predict: Union[Iterable[str], str] = 'smiles',
|
98 |
extra_metrics: Optional[Union[Iterable[str], str]] = None
|
99 |
-
):
|
|
|
|
|
100 |
if extra_metrics is None:
|
101 |
extra_metrics = []
|
102 |
else:
|
103 |
extra_metrics = cast(extra_metrics, to=list)
|
104 |
-
prediction_df = convert_one(
|
105 |
-
strings=strings,
|
106 |
-
input_representation=input_representation,
|
107 |
-
output_representation=['id', 'pubchem_name', 'pubchem_id', 'smiles', 'inchikey', "mwt", "clogp"],
|
108 |
-
)
|
109 |
-
species_to_predict = cast(predict, to=list)
|
110 |
-
prediction_cols = []
|
111 |
for species in species_to_predict:
|
112 |
message = f"Predicting for species: {species}"
|
113 |
print_err(message)
|
@@ -116,7 +115,7 @@ def predict_one(
|
|
116 |
this_features = this_modelbox._input_cols
|
117 |
this_labels = this_modelbox._label_cols
|
118 |
this_prediction_input = (
|
119 |
-
|
120 |
.rename(columns={
|
121 |
"smiles": this_features[0],
|
122 |
})
|
@@ -132,10 +131,10 @@ def predict_one(
|
|
132 |
).with_format("numpy")["__prediction__"].flatten()
|
133 |
print(prediction)
|
134 |
this_col = f"{species}: predicted MIC (µM)"
|
135 |
-
|
136 |
prediction_cols.append(this_col)
|
137 |
this_col = f"{species}: predicted MIC (µg / mL)"
|
138 |
-
|
139 |
prediction_cols.append(this_col)
|
140 |
|
141 |
for extra_metric in extra_metrics:
|
@@ -155,10 +154,33 @@ def predict_one(
|
|
155 |
)
|
156 |
.with_format("numpy")
|
157 |
)
|
158 |
-
|
159 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
160 |
return gr.DataFrame(
|
161 |
-
prediction_df[
|
|
|
|
|
|
|
|
|
162 |
visible=True
|
163 |
)
|
164 |
|
@@ -209,70 +231,38 @@ def predict_file(
|
|
209 |
else:
|
210 |
extra_metrics = cast(extra_metrics, to=list)
|
211 |
|
|
|
|
|
|
|
|
|
|
|
|
|
212 |
prediction_df = convert_file(
|
213 |
df,
|
214 |
column=column,
|
215 |
input_representation=input_representation,
|
216 |
output_representation=["id", "smiles", "inchikey", "mwt", "clogp"],
|
217 |
)
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
)
|
239 |
-
print(this_prediction_input)
|
240 |
-
prediction = this_modelbox.predict(
|
241 |
-
data=this_prediction_input,
|
242 |
-
features=this_features,
|
243 |
-
labels=this_labels,
|
244 |
-
aggregator="mean",
|
245 |
-
cache=CACHE,
|
246 |
-
).with_format("numpy")["__prediction__"].flatten()
|
247 |
-
print(prediction)
|
248 |
-
this_col = f"{species}: predicted MIC (µM)"
|
249 |
-
prediction_df[this_col] = np.power(10., -prediction) * 1e6
|
250 |
-
prediction_cols.append(this_col)
|
251 |
-
this_col = f"{species}: predicted MIC (µg / mL)"
|
252 |
-
prediction_df[this_col] = np.power(10., -prediction) * 1e3 * prediction_df["mwt"]
|
253 |
-
prediction_cols.append(this_col)
|
254 |
-
|
255 |
-
for extra_metric in extra_metrics:
|
256 |
-
message = f"Calculating {extra_metric} for species: {species}"
|
257 |
-
print_err(message)
|
258 |
-
gr.Info(message, duration=10)
|
259 |
-
# this_modelbox._input_training_data = this_modelbox._input_training_data.remove_columns([this_modelbox._in_key])
|
260 |
-
this_col = f"{species}: {extra_metric}"
|
261 |
-
prediction_cols.append(this_col)
|
262 |
-
print(">>>", this_modelbox._input_training_data)
|
263 |
-
print(">>>", this_modelbox._input_training_data.format)
|
264 |
-
print(">>>", this_modelbox._in_key, this_modelbox._out_key)
|
265 |
-
this_extra = (
|
266 |
-
EXTRA_METRICS[extra_metric](
|
267 |
-
this_modelbox,
|
268 |
-
this_prediction_input,
|
269 |
-
)
|
270 |
-
.with_format("numpy")
|
271 |
-
)
|
272 |
-
prediction_df[this_col] = this_extra[this_extra.column_names[-1]]
|
273 |
-
other_cols = [col for col in prediction_df if col not in ['id', 'inchikey', 'smiles', "mwt", "clogp"] + [column] + prediction_cols]
|
274 |
-
|
275 |
-
return prediction_df[['id', 'inchikey'] + [column] + prediction_cols + other_cols + ['smiles', "mwt", "clogp"]]
|
276 |
|
277 |
def draw_one(
|
278 |
strings: Union[Iterable[str], str],
|
@@ -293,31 +283,35 @@ def draw_one(
|
|
293 |
legends=["\n".join(items) for items in zip(*_ids.values())],
|
294 |
)
|
295 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
296 |
|
297 |
-
|
|
|
298 |
df,
|
299 |
-
|
300 |
-
|
301 |
color: Optional[str] = None,
|
302 |
):
|
303 |
print_err(df.head())
|
304 |
-
|
305 |
-
|
306 |
-
y_title = f"Observed ({ycol})"
|
307 |
-
cols = ["id", "inchikey", "smiles", "mwt", "clogp", xcol, ycol]
|
308 |
-
color_title = color
|
309 |
if color is not None and color not in cols:
|
310 |
cols.append(color)
|
311 |
cols = list(set(cols))
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
x_title = f"Predicted log10[MIC(µM)]"
|
316 |
|
317 |
return gr.ScatterPlot(
|
318 |
value=df[cols],
|
319 |
-
x=
|
320 |
-
y=
|
321 |
color=color,
|
322 |
x_title=x_title,
|
323 |
y_title=y_title,
|
@@ -327,14 +321,32 @@ def plot_pred_vs_observed(
|
|
327 |
)
|
328 |
|
329 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
330 |
def download_table(
|
331 |
df: pd.DataFrame
|
332 |
) -> str:
|
333 |
df_hash = nm.hash(pd.util.hash_pandas_object(df).values)
|
334 |
-
filename = f"
|
335 |
df.to_csv(filename, index=False)
|
336 |
return gr.DownloadButton(value=filename, visible=True)
|
337 |
|
|
|
338 |
with gr.Blocks() as demo:
|
339 |
|
340 |
with open(HEADER_FILE, 'r') as f:
|
@@ -379,7 +391,7 @@ with gr.Blocks() as demo:
|
|
379 |
]),
|
380 |
list(MODEL_REPOS)[0],
|
381 |
list(EXTRA_METRICS)[:2],
|
382 |
-
|
383 |
[
|
384 |
'\n'.join([
|
385 |
"C[C@H]1[C@H]([C@H](C[C@@H](O1)O[C@H]2C[C@@](CC3=C2C(=C4C(=C3O)C(=O)C5=C(C4=O)C(=CC=C5)OC)O)(C(=O)CO)O)N)O",
|
@@ -399,6 +411,7 @@ with gr.Blocks() as demo:
|
|
399 |
"COC1=CC(=CC(=C1OC)OC)CC2=CN=C(N=C2N)N",
|
400 |
"CC1=CC(=NO1)NS(=O)(=O)C2=CC=C(C=C2)N",
|
401 |
"C1[C@@H]([C@H]([C@@H]([C@H]([C@@H]1NC(=O)[C@H](CCN)O)O[C@@H]2[C@@H]([C@H]([C@@H]([C@H](O2)CO)O)N)O)O)O[C@@H]3[C@@H]([C@H]([C@@H]([C@H](O3)CN)O)O)O)N\nC1=CN=CC=C1C(=O)NN",
|
|
|
402 |
]),
|
403 |
list(MODEL_REPOS)[0],
|
404 |
list(EXTRA_METRICS)[:2],
|
@@ -420,10 +433,37 @@ with gr.Blocks() as demo:
|
|
420 |
"CC1=C(OC2=CC=CC=C12)CN(C)C(=O)/C=C/C3=CC4=C(NC(=O)CC4)N=C3",
|
421 |
"CC1=C(OC2=CC=CC=C12)CN(C)C(=O)/C=C/C3=CC4=C(NC(=O)[C@@H](C4)N)N=C3",
|
422 |
"CC1=C(OC2=CC=CC=C12)CN(C)C(=O)/C=C/C3=CC4=C(NC(=O)[C@H](CC4)[NH3+])N=C3.[Cl-]",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
423 |
]),
|
424 |
list(MODEL_REPOS)[0],
|
425 |
list(EXTRA_METRICS)[:2],
|
426 |
-
], #
|
427 |
|
428 |
],
|
429 |
example_labels=[
|
@@ -431,8 +471,9 @@ with gr.Blocks() as demo:
|
|
431 |
"Doxorubicin, Ampicillin, Amoxicillin, Meropenem, Tetracycline, Anhydrotetracycline",
|
432 |
"Halicin, Abaucin, Trimethoprim, Sulfamethoxazole, Amikacin, Isoniazid",
|
433 |
"Murepavadin, Vancomycin, Zosurabalpin, Plazomicin, Gentamicin, Rifampicin",
|
434 |
-
"Debio-1452, Debio-1452-NH3, Fabimycin",
|
435 |
-
|
|
|
436 |
],
|
437 |
inputs=[input_line, output_species_single, extra_metric],
|
438 |
cache_mode="eager",
|
@@ -476,7 +517,7 @@ with gr.Blocks() as demo:
|
|
476 |
outputs=download_single
|
477 |
)
|
478 |
|
479 |
-
with gr.Tab("Predict on structures from a file (max.
|
480 |
input_file = gr.File(
|
481 |
label="Upload a table of chemical compounds here",
|
482 |
file_types=[".xlsx", ".csv", ".tsv", ".txt"],
|
@@ -524,14 +565,36 @@ with gr.Blocks() as demo:
|
|
524 |
)
|
525 |
with gr.Row():
|
526 |
observed_col = gr.Dropdown(
|
527 |
-
label="Observed column (y-axis) for
|
528 |
choices=[],
|
529 |
value=None,
|
530 |
interactive=True,
|
531 |
visible=False,
|
532 |
)
|
533 |
color_col = gr.Dropdown(
|
534 |
-
label="Color for
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
535 |
choices=[],
|
536 |
value=None,
|
537 |
interactive=True,
|
@@ -544,38 +607,65 @@ with gr.Blocks() as demo:
|
|
544 |
file_examples = gr.Examples(
|
545 |
examples=[
|
546 |
[
|
547 |
-
"example-data/stokes2020-eco
|
548 |
"SMILES",
|
549 |
"Klebsiella pneumoniae",
|
550 |
"Mean_Inhibition",
|
551 |
"Klebsiella pneumoniae: Doubtscore",
|
552 |
-
list(EXTRA_METRICS)[:3]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
553 |
],
|
554 |
example_labels=[
|
555 |
-
"Stokes J. et al., Cell, 2020",
|
|
|
|
|
556 |
],
|
557 |
inputs=[input_file, input_column, output_species, observed_col, color_col, extra_metric_file],
|
558 |
cache_mode="eager",
|
559 |
)
|
560 |
-
|
561 |
-
|
562 |
-
|
563 |
-
|
564 |
-
|
565 |
-
|
566 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
567 |
|
568 |
file_examples.load_input_event.then(
|
569 |
-
|
570 |
-
inputs=[input_file],
|
571 |
-
outputs=[input_data, input_column],
|
572 |
)
|
573 |
input_file.upload(
|
574 |
-
|
575 |
-
inputs=[input_file],
|
576 |
-
outputs=[input_data, input_column]
|
577 |
)
|
578 |
-
go_button2.click(
|
579 |
predict_file,
|
580 |
inputs=[
|
581 |
input_data,
|
@@ -591,18 +681,17 @@ with gr.Blocks() as demo:
|
|
591 |
download_table,
|
592 |
inputs=input_data,
|
593 |
outputs=download
|
594 |
-
).then(
|
595 |
-
partial(get_dropdown_options, _type="number"),
|
596 |
-
inputs=[input_data],
|
597 |
-
outputs=[observed_col],
|
598 |
-
).then(
|
599 |
-
partial(get_dropdown_options, _type="number"),
|
600 |
-
inputs=[input_data],
|
601 |
-
outputs=[color_col],
|
602 |
).then(
|
603 |
lambda: gr.Button(visible=True),
|
604 |
-
outputs=[plot_button]
|
605 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
606 |
|
607 |
plot_button.click(
|
608 |
plot_pred_vs_observed,
|
@@ -612,7 +701,16 @@ with gr.Blocks() as demo:
|
|
612 |
observed_col,
|
613 |
color_col,
|
614 |
],
|
615 |
-
outputs=pred_vs_observed,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
616 |
)
|
617 |
|
618 |
if __name__ == "__main__":
|
|
|
23 |
import torch
|
24 |
|
25 |
CACHE = "./cache"
|
26 |
+
MAX_ROWS = 4000
|
27 |
HEADER_FILE = os.path.join("sources", "header.md")
|
28 |
MODEL_REPOS = {
|
29 |
"Klebsiella pneumoniae": "hf://scbirlab/spark-dv-fp-2503-kpn",
|
|
|
79 |
input_representation: str = 'smiles',
|
80 |
output_representation: Union[Iterable[str], str] = 'smiles'
|
81 |
):
|
82 |
+
output_representation = cast(output_representation, to=list)
|
83 |
+
for rep in output_representation:
|
84 |
+
message = f"Converting from {input_representation} to {rep}..."
|
85 |
+
gr.Info(message, duration=10)
|
86 |
|
87 |
df = pd.DataFrame({
|
88 |
input_representation: _clean_split_input(strings),
|
|
|
96 |
)
|
97 |
|
98 |
|
99 |
+
def _prediction_loop(
|
100 |
+
df: pd.DataFrame,
|
|
|
101 |
predict: Union[Iterable[str], str] = 'smiles',
|
102 |
extra_metrics: Optional[Union[Iterable[str], str]] = None
|
103 |
+
) -> pd.DataFrame:
|
104 |
+
species_to_predict = cast(predict, to=list)
|
105 |
+
prediction_cols = []
|
106 |
if extra_metrics is None:
|
107 |
extra_metrics = []
|
108 |
else:
|
109 |
extra_metrics = cast(extra_metrics, to=list)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
for species in species_to_predict:
|
111 |
message = f"Predicting for species: {species}"
|
112 |
print_err(message)
|
|
|
115 |
this_features = this_modelbox._input_cols
|
116 |
this_labels = this_modelbox._label_cols
|
117 |
this_prediction_input = (
|
118 |
+
df
|
119 |
.rename(columns={
|
120 |
"smiles": this_features[0],
|
121 |
})
|
|
|
131 |
).with_format("numpy")["__prediction__"].flatten()
|
132 |
print(prediction)
|
133 |
this_col = f"{species}: predicted MIC (µM)"
|
134 |
+
df[this_col] = np.power(10., -prediction) * 1e6
|
135 |
prediction_cols.append(this_col)
|
136 |
this_col = f"{species}: predicted MIC (µg / mL)"
|
137 |
+
df[this_col] = np.power(10., -prediction) * 1e3 * df["mwt"]
|
138 |
prediction_cols.append(this_col)
|
139 |
|
140 |
for extra_metric in extra_metrics:
|
|
|
154 |
)
|
155 |
.with_format("numpy")
|
156 |
)
|
157 |
+
df[this_col] = this_extra[this_extra.column_names[-1]]
|
158 |
+
|
159 |
+
return prediction_cols, df
|
160 |
+
|
161 |
+
|
162 |
+
def predict_one(
|
163 |
+
strings: str,
|
164 |
+
input_representation: str = 'smiles',
|
165 |
+
predict: Union[Iterable[str], str] = 'smiles',
|
166 |
+
extra_metrics: Optional[Union[Iterable[str], str]] = None
|
167 |
+
):
|
168 |
+
prediction_df = convert_one(
|
169 |
+
strings=strings,
|
170 |
+
input_representation=input_representation,
|
171 |
+
output_representation=['id', 'pubchem_name', 'pubchem_id', 'smiles', 'inchikey', "mwt", "clogp"],
|
172 |
+
)
|
173 |
+
prediction_cols, prediction_df = _prediction_loop(
|
174 |
+
prediction_df,
|
175 |
+
predict=predict,
|
176 |
+
extra_metrics=extra_metrics,
|
177 |
+
)
|
178 |
return gr.DataFrame(
|
179 |
+
prediction_df[
|
180 |
+
['id', 'pubchem_name', 'pubchem_id']
|
181 |
+
+ prediction_cols
|
182 |
+
+ ['smiles', 'inchikey', "mwt", "clogp"]
|
183 |
+
],
|
184 |
visible=True
|
185 |
)
|
186 |
|
|
|
231 |
else:
|
232 |
extra_metrics = cast(extra_metrics, to=list)
|
233 |
|
234 |
+
if df.shape[0] > MAX_ROWS:
|
235 |
+
message = f"Truncating input to {MAX_ROWS} rows"
|
236 |
+
print_err(message)
|
237 |
+
gr.Info(message, duration=15)
|
238 |
+
df = df.iloc[:MAX_ROWS]
|
239 |
+
|
240 |
prediction_df = convert_file(
|
241 |
df,
|
242 |
column=column,
|
243 |
input_representation=input_representation,
|
244 |
output_representation=["id", "smiles", "inchikey", "mwt", "clogp"],
|
245 |
)
|
246 |
+
prediction_cols, prediction_df = _prediction_loop(
|
247 |
+
prediction_df,
|
248 |
+
predict=predict,
|
249 |
+
extra_metrics=extra_metrics,
|
250 |
+
)
|
251 |
+
main_cols = set(
|
252 |
+
['id', 'inchikey', 'smiles', "mwt", "clogp"]
|
253 |
+
+ [column]
|
254 |
+
+ prediction_cols
|
255 |
+
)
|
256 |
+
other_cols = [
|
257 |
+
col for col in prediction_df
|
258 |
+
if col not in main_cols
|
259 |
+
]
|
260 |
+
return prediction_df[
|
261 |
+
['id', 'inchikey']
|
262 |
+
+ [column]
|
263 |
+
+ prediction_cols + other_cols
|
264 |
+
+ ['smiles', "mwt", "clogp"]
|
265 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
266 |
|
267 |
def draw_one(
|
268 |
strings: Union[Iterable[str], str],
|
|
|
283 |
legends=["\n".join(items) for items in zip(*_ids.values())],
|
284 |
)
|
285 |
|
286 |
+
def log10_if_all_positive(df, col):
|
287 |
+
if np.all(df[col] > 0.):
|
288 |
+
df[col] = np.log10(df[col])
|
289 |
+
title = f"log10[ {col} ]"
|
290 |
+
else:
|
291 |
+
title = col
|
292 |
+
return title, df
|
293 |
|
294 |
+
|
295 |
+
def plot_x_vs_y(
|
296 |
df,
|
297 |
+
x: str,
|
298 |
+
y: str,
|
299 |
color: Optional[str] = None,
|
300 |
):
|
301 |
print_err(df.head())
|
302 |
+
y_title = y
|
303 |
+
cols = ["id", "inchikey", "smiles", "mwt", "clogp", x, y]
|
|
|
|
|
|
|
304 |
if color is not None and color not in cols:
|
305 |
cols.append(color)
|
306 |
cols = list(set(cols))
|
307 |
+
x_title, df = log10_if_all_positive(df, x)
|
308 |
+
y_title, df = log10_if_all_positive(df, y)
|
309 |
+
color_title, df = log10_if_all_positive(df, color)
|
|
|
310 |
|
311 |
return gr.ScatterPlot(
|
312 |
value=df[cols],
|
313 |
+
x=x,
|
314 |
+
y=y,
|
315 |
color=color,
|
316 |
x_title=x_title,
|
317 |
y_title=y_title,
|
|
|
321 |
)
|
322 |
|
323 |
|
324 |
+
def plot_pred_vs_observed(
|
325 |
+
df,
|
326 |
+
species: str,
|
327 |
+
observed: str,
|
328 |
+
color: Optional[str] = None,
|
329 |
+
):
|
330 |
+
print_err(df.head())
|
331 |
+
xcol = f"{species}: predicted MIC (µM)"
|
332 |
+
ycol = observed
|
333 |
+
return plot_x_vs_y(
|
334 |
+
df,
|
335 |
+
x=xcol,
|
336 |
+
y=ycol,
|
337 |
+
color=color,
|
338 |
+
)
|
339 |
+
|
340 |
+
|
341 |
def download_table(
|
342 |
df: pd.DataFrame
|
343 |
) -> str:
|
344 |
df_hash = nm.hash(pd.util.hash_pandas_object(df).values)
|
345 |
+
filename = f"predicted-{df_hash}.csv"
|
346 |
df.to_csv(filename, index=False)
|
347 |
return gr.DownloadButton(value=filename, visible=True)
|
348 |
|
349 |
+
|
350 |
with gr.Blocks() as demo:
|
351 |
|
352 |
with open(HEADER_FILE, 'r') as f:
|
|
|
391 |
]),
|
392 |
list(MODEL_REPOS)[0],
|
393 |
list(EXTRA_METRICS)[:2],
|
394 |
+
], # cipro, ceftriaxone, cefiderocol, linezolid, gepotidacin
|
395 |
[
|
396 |
'\n'.join([
|
397 |
"C[C@H]1[C@H]([C@H](C[C@@H](O1)O[C@H]2C[C@@](CC3=C2C(=C4C(=C3O)C(=O)C5=C(C4=O)C(=CC=C5)OC)O)(C(=O)CO)O)N)O",
|
|
|
411 |
"COC1=CC(=CC(=C1OC)OC)CC2=CN=C(N=C2N)N",
|
412 |
"CC1=CC(=NO1)NS(=O)(=O)C2=CC=C(C=C2)N",
|
413 |
"C1[C@@H]([C@H]([C@@H]([C@H]([C@@H]1NC(=O)[C@H](CCN)O)O[C@@H]2[C@@H]([C@H]([C@@H]([C@H](O2)CO)O)N)O)O)O[C@@H]3[C@@H]([C@H]([C@@H]([C@H](O3)CN)O)O)O)N\nC1=CN=CC=C1C(=O)NN",
|
414 |
+
"C1=CN=CC=C1C(=O)NN ",
|
415 |
]),
|
416 |
list(MODEL_REPOS)[0],
|
417 |
list(EXTRA_METRICS)[:2],
|
|
|
433 |
"CC1=C(OC2=CC=CC=C12)CN(C)C(=O)/C=C/C3=CC4=C(NC(=O)CC4)N=C3",
|
434 |
"CC1=C(OC2=CC=CC=C12)CN(C)C(=O)/C=C/C3=CC4=C(NC(=O)[C@@H](C4)N)N=C3",
|
435 |
"CC1=C(OC2=CC=CC=C12)CN(C)C(=O)/C=C/C3=CC4=C(NC(=O)[C@H](CC4)[NH3+])N=C3.[Cl-]",
|
436 |
+
"C1=C(C(=O)NC(=O)N1)F",
|
437 |
+
"CCCCCCNC(=O)N1C=C(C(=O)NC1=O)F",
|
438 |
+
"C[C@@H]1OC[C@@H]2[C@@H](O1)[C@@H]([C@H]([C@@H](O2)O[C@H]3[C@H]4COC(=O)[C@@H]4[C@@H](C5=CC6=C(C=C35)OCO6)C7=CC(=C(C(=C7)OC)O)OC)O)O",
|
439 |
+
]),
|
440 |
+
list(MODEL_REPOS)[0],
|
441 |
+
list(EXTRA_METRICS)[:2],
|
442 |
+
], # Debio1452, Debio-1452-NH3, Fabimycin, 5-FU, Carmofur, Etoposide
|
443 |
+
[
|
444 |
+
'\n'.join([
|
445 |
+
"COC1=CC(=CC(=C1OC)OC)CC2=CN=C(N=C2N)N",
|
446 |
+
"CC(C)C1=CC=C(C=C1)CN2C=CC3=C2C=CC4=C3C(=NC(=N4)NC5CC5)N",
|
447 |
+
"C1=CC(=CC=C1CCC2=CNC3=C2C(=O)NC(=N3)N)C(=O)N[C@@H](CCC(=O)O)C(=O)O",
|
448 |
+
"CC1=C(C2=C(C=C1)N=C(NC2=O)N)SC3=CC=NC=C3",
|
449 |
+
"CN(CC1=CN=C2C(=N1)C(=NC(=N2)N)N)C3=CC=C(C=C3)C(=O)N[C@@H](CCC(=O)O)C(=O)O",
|
450 |
+
"CC1=NC2=C(C=C(C=C2)CN(C)C3=CC=C(S3)C(=O)N[C@@H](CCC(=O)O)C(=O)O)C(=O)N1",
|
451 |
+
]),
|
452 |
+
list(MODEL_REPOS)[0],
|
453 |
+
list(EXTRA_METRICS)[:2],
|
454 |
+
], # Trimethoprim, SCH79797, Pemetrexed, Nolatrexed, Methotrexate, Raltitrexed
|
455 |
+
[
|
456 |
+
'\n'.join([
|
457 |
+
"C[C@H]([C@@H](C(=O)NO)NC(=O)C1=CC=C(C=C1)C#CC2=CC=C(C=C2)CN3CCOCC3)O",
|
458 |
+
"CC(C)C1=CC=C(C=C1)CN2C=CC3=C2C=CC4=C3C(=NC(=N4)NC5CC5)N",
|
459 |
+
"C1=CC=C(C=C1)CNC2=NC(=NC3=CC=CC=C32)NCC4=CC=CC=C4",
|
460 |
+
"CC(C)(C)C1=CC=C(C=C1)C(=O)NC(=S)NC2=CC=C(C=C2)NC(=O)CCCCN(C)C",
|
461 |
+
"CCC1=C(C(=NC(=N1)N)N)C2=CC=C(C=C2)Cl",
|
462 |
+
"C1=CC(=CC=C1C(=O)N[C@@H](CCC(=O)O)C(=O)O)NCC2=CN=C3C(=N2)C(=NC(=N3)N)N",
|
463 |
]),
|
464 |
list(MODEL_REPOS)[0],
|
465 |
list(EXTRA_METRICS)[:2],
|
466 |
+
], # CHIR-090, SCH79797, DBeQ, Tenovin-6, Pyrimethamine, Aminopterin
|
467 |
|
468 |
],
|
469 |
example_labels=[
|
|
|
471 |
"Doxorubicin, Ampicillin, Amoxicillin, Meropenem, Tetracycline, Anhydrotetracycline",
|
472 |
"Halicin, Abaucin, Trimethoprim, Sulfamethoxazole, Amikacin, Isoniazid",
|
473 |
"Murepavadin, Vancomycin, Zosurabalpin, Plazomicin, Gentamicin, Rifampicin",
|
474 |
+
"Debio-1452, Debio-1452-NH3, Fabimycin, 5-FU, Carmofur, Etoposide",
|
475 |
+
"Trimethoprim, Pemetrexed, Nolatrexed, Methotrexate, Raltitrexed",
|
476 |
+
"CHIR-090, SCH79797, DBeQ, Tenovin-6, Pyrimethamine, Aminopterin"
|
477 |
],
|
478 |
inputs=[input_line, output_species_single, extra_metric],
|
479 |
cache_mode="eager",
|
|
|
517 |
outputs=download_single
|
518 |
)
|
519 |
|
520 |
+
with gr.Tab(f"Predict on structures from a file (max. {MAX_ROWS} rows, single species)"):
|
521 |
input_file = gr.File(
|
522 |
label="Upload a table of chemical compounds here",
|
523 |
file_types=[".xlsx", ".csv", ".tsv", ".txt"],
|
|
|
565 |
)
|
566 |
with gr.Row():
|
567 |
observed_col = gr.Dropdown(
|
568 |
+
label="Observed column (y-axis) for left plot",
|
569 |
choices=[],
|
570 |
value=None,
|
571 |
interactive=True,
|
572 |
visible=False,
|
573 |
)
|
574 |
color_col = gr.Dropdown(
|
575 |
+
label="Color for left plot",
|
576 |
+
choices=[],
|
577 |
+
value=None,
|
578 |
+
interactive=True,
|
579 |
+
visible=False,
|
580 |
+
)
|
581 |
+
|
582 |
+
any_x_col = gr.Dropdown(
|
583 |
+
label="x-axis for right plot",
|
584 |
+
choices=[],
|
585 |
+
value=None,
|
586 |
+
interactive=True,
|
587 |
+
visible=False,
|
588 |
+
)
|
589 |
+
any_y_col = gr.Dropdown(
|
590 |
+
label="y-axis for right plot",
|
591 |
+
choices=[],
|
592 |
+
value=None,
|
593 |
+
interactive=True,
|
594 |
+
visible=False,
|
595 |
+
)
|
596 |
+
any_color_col = gr.Dropdown(
|
597 |
+
label="Color for right plot",
|
598 |
choices=[],
|
599 |
value=None,
|
600 |
interactive=True,
|
|
|
607 |
file_examples = gr.Examples(
|
608 |
examples=[
|
609 |
[
|
610 |
+
"example-data/stokes2020-eco.csv",
|
611 |
"SMILES",
|
612 |
"Klebsiella pneumoniae",
|
613 |
"Mean_Inhibition",
|
614 |
"Klebsiella pneumoniae: Doubtscore",
|
615 |
+
list(EXTRA_METRICS)[:3],
|
616 |
+
],
|
617 |
+
[
|
618 |
+
"example-data/liu23-abau.csv",
|
619 |
+
"SMILES",
|
620 |
+
"Klebsiella pneumoniae",
|
621 |
+
"Mean",
|
622 |
+
"Klebsiella pneumoniae: Doubtscore",
|
623 |
+
list(EXTRA_METRICS)[:3],
|
624 |
+
],
|
625 |
+
[
|
626 |
+
"example-data/wong24-sau-tox-5000.csv",
|
627 |
+
"SMILES",
|
628 |
+
"Klebsiella pneumoniae",
|
629 |
+
"Mean",
|
630 |
+
"Klebsiella pneumoniae: Doubtscore",
|
631 |
+
list(EXTRA_METRICS)[:3],
|
632 |
+
],
|
633 |
],
|
634 |
example_labels=[
|
635 |
+
"E. coli training data from Stokes J. et al., Cell, 2020",
|
636 |
+
"A. baumannii training data from Liu, 2023",
|
637 |
+
"S. aureus and toxicity training data from Wong, 2024",
|
638 |
],
|
639 |
inputs=[input_file, input_column, output_species, observed_col, color_col, extra_metric_file],
|
640 |
cache_mode="eager",
|
641 |
)
|
642 |
+
with gr.Row():
|
643 |
+
pred_vs_observed = gr.ScatterPlot(
|
644 |
+
label="Prediction vs observed",
|
645 |
+
x_title="Predicted MIC (µM)",
|
646 |
+
y_title="Observed",
|
647 |
+
visible=False,
|
648 |
+
height=600,
|
649 |
+
)
|
650 |
+
plot_any_vs_any = gr.ScatterPlot(
|
651 |
+
label="Any vs any",
|
652 |
+
visible=False,
|
653 |
+
height=600,
|
654 |
+
)
|
655 |
+
|
656 |
+
load_data_action = {
|
657 |
+
"fn": load_input_data,
|
658 |
+
"inputs": [input_file],
|
659 |
+
"outputs": [input_data, input_column]
|
660 |
+
}
|
661 |
|
662 |
file_examples.load_input_event.then(
|
663 |
+
**load_data_action,
|
|
|
|
|
664 |
)
|
665 |
input_file.upload(
|
666 |
+
**load_data_action,
|
|
|
|
|
667 |
)
|
668 |
+
go2_click_event = go_button2.click(
|
669 |
predict_file,
|
670 |
inputs=[
|
671 |
input_data,
|
|
|
681 |
download_table,
|
682 |
inputs=input_data,
|
683 |
outputs=download
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
684 |
).then(
|
685 |
lambda: gr.Button(visible=True),
|
686 |
+
outputs=[plot_button]
|
687 |
)
|
688 |
+
|
689 |
+
for dropdown in [observed_col, color_col, any_color_col, any_x_col, any_y_col]:
|
690 |
+
go2_click_event.then(
|
691 |
+
partial(get_dropdown_options, _type="number"),
|
692 |
+
inputs=[input_data],
|
693 |
+
outputs=[dropdown],
|
694 |
+
)
|
695 |
|
696 |
plot_button.click(
|
697 |
plot_pred_vs_observed,
|
|
|
701 |
observed_col,
|
702 |
color_col,
|
703 |
],
|
704 |
+
outputs=[pred_vs_observed],
|
705 |
+
).then(
|
706 |
+
plot_x_vs_y,
|
707 |
+
inputs=[
|
708 |
+
input_data,
|
709 |
+
any_x_col,
|
710 |
+
any_y_col,
|
711 |
+
any_color_col,
|
712 |
+
],
|
713 |
+
outputs=[plot_any_vs_any],
|
714 |
)
|
715 |
|
716 |
if __name__ == "__main__":
|
example-data/liu23-abau.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
example-data/{stokes2020-eco-1000.csv → stokes2020-eco.csv}
RENAMED
The diff for this file is too large to render.
See raw diff
|
|
example-data/wong24-sau-tox-5000.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|