# mic-predict / app.py
# Author: Eachan Johnson (commit 7f2265f, "Tidy code")
"""Gradio demo for schemist."""
from typing import Iterable, List, Optional, Union
from functools import partial
from io import TextIOWrapper
import json
import os
# os.environ["COMMANDLINE_ARGS"] = "--no-gradio-queue"
from carabiner import cast, print_err
from carabiner.pd import read_table
from duvida.autoclass import AutoModelBox
import gradio as gr
import nemony as nm
import numpy as np
import pandas as pd
from rdkit.Chem import Draw, Mol
from schemist.converting import (
_FROM_FUNCTIONS,
convert_string_representation,
_x2mol,
)
from schemist.tables import converter
import torch
# Run models on the GPU when one is available.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CACHE = "./cache"
MAX_ROWS = 4000
BATCH_SIZE = 32
HEADER_FILE = os.path.join("sources", "header.md")

# Species name -> model repository identifier passed to
# AutoModelBox.from_pretrained (presumably Hugging Face Hub repo ids).
with open("repos.json", "r", encoding="utf-8") as f:
    MODEL_REPOS = json.load(f)

MODELBOXES = {
    key: AutoModelBox.from_pretrained(val, cache_dir=CACHE)
    for key, val in MODEL_REPOS.items()
}
# `.to()` is called only for its effect of placing each model on DEVICE; a
# plain loop avoids building and discarding a throwaway list.
for _modelbox in MODELBOXES.values():
    _modelbox.to(DEVICE)

# Extra per-compound metrics, keyed by the label shown in the UI. Each entry
# takes (modelbox, candidates) and returns a dataset; variance, doubtscore and
# information sensitivity are log10-transformed, the Tanimoto score is not.
EXTRA_METRICS = {
    "log10(variance)": lambda modelbox, candidates: (
        modelbox.prediction_variance(
            candidates=candidates,
            batch_size=BATCH_SIZE,
            cache=CACHE,
        )
        .map(lambda x: {
            modelbox._variance_key: torch.log10(x[modelbox._variance_key]),
        })
    ),
    "Tanimoto nearest neighbor to training data": lambda modelbox, candidates: (
        modelbox.tanimoto_nn(
            candidates=candidates,
            batch_size=BATCH_SIZE,
        )
    ),
    "Doubtscore": lambda modelbox, candidates: (
        modelbox.doubtscore(
            candidates=candidates,
            cache=CACHE,
            batch_size=BATCH_SIZE,
        )
        .map(lambda x: {"doubtscore": torch.log10(x["doubtscore"])})
    ),
    "Information sensitivity (approx.)": lambda modelbox, candidates: (
        modelbox.information_sensitivity(
            candidates=candidates,
            batch_size=BATCH_SIZE,
            optimality_approximation=True,
            approximator="squared_jacobian",
            cache=CACHE,
        )
        .map(lambda x: {
            "information sensitivity": torch.log10(x["information sensitivity"]),
        })
    ),
}
def get_dropdown_options(df, _type=str) -> "gr.Dropdown":
    """Build a visible dropdown offering the columns of *df* of one kind.

    Parameters
    ----------
    df : pd.DataFrame
        Table whose column names become the choices.
    _type : optional
        ``str`` (the default) selects non-numeric columns; any other value
        (the UI passes ``"number"``) selects numeric columns.

    Returns
    -------
    gr.Dropdown
        Interactive, visible dropdown preselecting the first matching column,
        or no selection when *df* has no column of the requested kind.
    """
    if _type is str:
        cols = list(df.select_dtypes(exclude=[np.number]))
    else:
        cols = list(df.select_dtypes(include=[np.number]))
    # Guard against tables with no matching column (cols[0] would IndexError).
    default = cols[0] if cols else None
    return gr.Dropdown(choices=cols, interactive=True, value=default, visible=True)
def load_input_data(
    file: Union[TextIOWrapper, str],
) -> "tuple[gr.Dataframe, gr.Dropdown]":
    """Read an uploaded table and reveal it together with a column picker.

    Parameters
    ----------
    file : str or file-like
        Path to the table, or an object whose ``.name`` attribute is that path
        (presumably the wrapper handed over by the upload component).

    Returns
    -------
    tuple of (gr.Dataframe, gr.Dropdown)
        A visible dataframe holding the table, and a dropdown offering the
        table's non-numeric columns as candidate structure columns.
    """
    path = file if isinstance(file, str) else file.name
    print_err(f"Loading {path}")
    df = read_table(path)
    print_err(df.head())
    return gr.Dataframe(value=df, visible=True), get_dropdown_options(df, str)
def _clean_split_input(strings: str) -> List[str]:
return [s2.strip() for s in strings.split("\n") for s2 in s.split(",")]
def _convert_input(
    strings: str,
    input_representation: str = 'smiles',
    output_representation: Union[Iterable[str], str] = 'smiles'
) -> dict[str, List[str]]:
    """Convert pasted molecule strings into one or more representations.

    Parameters
    ----------
    strings : str
        Raw pasted text; split into molecules by ``_clean_split_input``.
    input_representation : str
        Representation of the pasted strings.
    output_representation : str or iterable of str
        Representation(s) to convert to.

    Returns
    -------
    dict
        Maps each key of the ``convert_string_representation`` result
        (presumably the requested output representations) to a list of the
        converted values rendered with ``str``.
    """
    molecules = _clean_split_input(strings)
    converted = convert_string_representation(
        strings=molecules,
        input_representation=input_representation,
        output_representation=output_representation,
    )
    return {
        key: [str(value) for value in cast(values, to=list)]
        for key, values in converted.items()
    }
def convert_one(
    strings: str,
    input_representation: str = 'smiles',
    output_representation: Union[Iterable[str], str] = 'smiles'
):
    """Convert pasted molecule strings into a table of representations.

    Parameters
    ----------
    strings : str
        Raw pasted text; split into molecules by ``_clean_split_input``.
    input_representation : str
        Representation of the pasted strings; also used as the name of the
        temporary input column.
    output_representation : str or iterable of str
        Representation(s) to convert to.

    Returns
    -------
    pd.DataFrame
        The table produced by ``convert_file``.
    """
    df = pd.DataFrame({
        input_representation: _clean_split_input(strings),
    })
    # No progress notifications here: convert_file already reports each
    # output representation, so emitting them here too duplicated every toast.
    return convert_file(
        df=df,
        column=input_representation,
        input_representation=input_representation,
        output_representation=output_representation,
    )
def _prediction_loop(
    df: pd.DataFrame,
    predict: Union[Iterable[str], str] = 'smiles',
    extra_metrics: Optional[Union[Iterable[str], str]] = None
) -> tuple[List[str], pd.DataFrame]:
    """Add predicted-MIC (and optional extra metric) columns for each species.

    Parameters
    ----------
    df : pd.DataFrame
        Converted input table. Expected to contain a ``"smiles"`` column
        (renamed to each model's first feature column) and an ``"mwt"``
        column (used for the µg/mL conversion).
    predict : str or iterable of str
        Species name(s); each must be a key of ``MODELBOXES``. Repeated names
        are predicted only once.
    extra_metrics : str or iterable of str, optional
        Keys of ``EXTRA_METRICS`` to compute for every species.

    Returns
    -------
    tuple of (list of str, pd.DataFrame)
        Names of the added columns in insertion order, and a copy of *df*
        carrying those columns.
    """
    # Order-preserving de-duplication: choosing a species twice must not
    # produce duplicate output columns.
    species_to_predict = list(dict.fromkeys(cast(predict, to=list)))
    extra_metrics = [] if extra_metrics is None else cast(extra_metrics, to=list)
    # Work on a copy so new columns never land on (or warn about) a slice or
    # column selection of a caller's frame.
    df = df.copy()
    prediction_cols = []

    for species in species_to_predict:
        message = f"Predicting for species: {species}"
        print_err(message)
        gr.Info(message, duration=3)

        modelbox = MODELBOXES[species]
        features = modelbox._input_cols
        labels = modelbox._label_cols
        prediction_input = (
            df
            .rename(columns={"smiles": features[0]})
            # Labels are unknown at prediction time; the columns are filled
            # with NaN, presumably because the model's data pipeline expects them.
            .assign(**{label: np.nan for label in labels})
        )
        prediction = modelbox.predict(
            data=prediction_input,
            features=features,
            labels=labels,
            aggregator="mean",
            cache=CACHE,
        ).with_format("numpy")["__prediction__"].flatten()

        # The model output is interpreted as -log10(MIC in mol/L):
        # mol/L * 1e6 -> µM; mol/L * mwt (g/mol) -> g/L == mg/mL, * 1e3 -> µg/mL.
        molar_mic = np.power(10., -prediction)
        micromolar_col = f"{species}: predicted MIC (µM)"
        df[micromolar_col] = molar_mic * 1e6
        prediction_cols.append(micromolar_col)
        mass_col = f"{species}: predicted MIC (µg / mL)"
        df[mass_col] = molar_mic * 1e3 * df["mwt"]
        prediction_cols.append(mass_col)

        for extra_metric in extra_metrics:
            message = f"Calculating {extra_metric} for species: {species}"
            print_err(message)
            gr.Info(message, duration=10)
            metric_col = f"{species}: {extra_metric}"
            prediction_cols.append(metric_col)
            metric_data = (
                EXTRA_METRICS[extra_metric](modelbox, prediction_input)
                .with_format("numpy")
            )
            # Assumes the metric value is the last column of the returned
            # dataset — TODO confirm for each EXTRA_METRICS entry.
            df[metric_col] = metric_data[metric_data.column_names[-1]]

    return prediction_cols, df
def predict_one(
    strings: str,
    input_representation: str = 'smiles',
    predict: Union[Iterable[str], str] = 'smiles',
    extra_metrics: Optional[Union[Iterable[str], str]] = None
):
    """Predict MICs for molecules pasted into the textbox.

    Parameters
    ----------
    strings : str
        Raw pasted text (one molecule per line or comma-separated).
    input_representation : str
        Representation of the pasted strings.
    predict : str or iterable of str
        Species to predict for.
    extra_metrics : str or iterable of str, optional
        Extra metric names to compute per species.

    Returns
    -------
    gr.DataFrame
        Visible table: identifiers first, then prediction columns, then
        structure and property columns.
    """
    id_cols = ['id', 'pubchem_name', 'pubchem_id']
    structure_cols = ['smiles', 'inchikey', "mwt", "clogp"]

    converted = convert_one(
        strings=strings,
        input_representation=input_representation,
        output_representation=id_cols + structure_cols,
    )
    added_cols, predicted = _prediction_loop(
        converted,
        predict=predict,
        extra_metrics=extra_metrics,
    )
    column_order = [*id_cols, *added_cols, *structure_cols]
    return gr.DataFrame(predicted[column_order], visible=True)
def convert_file(
    df: pd.DataFrame,
    column: str = 'smiles',
    input_representation: str = 'smiles',
    output_representation: Union[str, Iterable[str]] = 'smiles'
):
    """Convert the structures in one column of *df* to other representations.

    Parameters
    ----------
    df : pd.DataFrame
        Input table.
    column : str
        Name of the column holding the input structures.
    input_representation : str
        Representation of the strings in *column*.
    output_representation : str or iterable of str
        Representation(s) to add as columns.

    Returns
    -------
    pd.DataFrame
        The converted table, with the requested representation columns first
        (in request order) followed by every other column.
    """
    output_representation = cast(output_representation, to=list)
    print_err(df.head())
    # Exactly one notification per requested representation.
    for rep in output_representation:
        message = f"Converting from {input_representation} to {rep}..."
        print_err(message)
        gr.Info(message, duration=10)

    errors, df = converter(
        df=df,
        column=column,
        input_representation=input_representation,
        output_representation=output_representation,
    )
    requested = set(output_representation)
    df = df[output_representation + [col for col in df if col not in requested]]

    # The values of `errors` are summed as per-key error counts.
    all_err = sum(errors.values())
    message = (
        f"Converted {df.shape[0]} molecules from "
        f"{input_representation} to {output_representation} "
        f"with {all_err} errors!"
    )
    print_err(message)
    gr.Info(message, duration=5)
    return df
def predict_file(
    df: pd.DataFrame,
    column: str = 'smiles',
    input_representation: str = 'smiles',
    predict: str = 'smiles',
    predict2: Optional[str] = None,
    extra_metrics: Optional[Union[Iterable[str], str]] = None
):
    """Predict MICs for every structure in an uploaded table.

    Parameters
    ----------
    df : pd.DataFrame
        Uploaded table; only the first ``MAX_ROWS`` rows are used.
    column : str
        Name of the column holding the input structures.
    input_representation : str
        Representation of the strings in *column*.
    predict : str
        First species to predict for (a key of ``MODELBOXES``).
    predict2 : str, optional
        Second species; skipped when empty/``None`` or equal to *predict*.
    extra_metrics : str or iterable of str, optional
        Keys of ``EXTRA_METRICS`` to compute per species.

    Returns
    -------
    pd.DataFrame
        Columns ordered as id, inchikey, *column*, prediction columns, the
        remaining input columns, then smiles, mwt, clogp. Each column name is
        selected once, even when *column* coincides with one of those names.
    """
    species = cast(predict, to=list)
    if predict2:
        # Concatenate into a new list: `+=` could mutate the caller's list.
        species = species + cast(predict2, to=list)
    # The same species may be picked in both dropdowns; predict it only once.
    species = list(dict.fromkeys(species))
    extra_metrics = [] if extra_metrics is None else cast(extra_metrics, to=list)

    if df.shape[0] > MAX_ROWS:
        message = f"Truncating input to {MAX_ROWS} rows"
        print_err(message)
        gr.Info(message, duration=15)
        df = df.iloc[:MAX_ROWS]

    prediction_df = convert_file(
        df,
        column=column,
        input_representation=input_representation,
        output_representation=["id", "smiles", "inchikey", "mwt", "clogp"],
    )
    prediction_cols, prediction_df = _prediction_loop(
        prediction_df,
        predict=species,
        extra_metrics=extra_metrics,
    )

    leading_cols = ['id', 'inchikey', column]
    trailing_cols = ['smiles', "mwt", "clogp"]
    main_cols = set(leading_cols + trailing_cols + prediction_cols)
    other_cols = [col for col in prediction_df if col not in main_cols]
    # dict.fromkeys drops repeats (e.g. column == "smiles") while keeping order;
    # otherwise the same column would be selected twice.
    ordered_cols = list(dict.fromkeys(
        leading_cols + prediction_cols + other_cols + trailing_cols
    ))
    return prediction_df[ordered_cols]
def draw_one(
    strings: Union[Iterable[str], str],
    input_representation: str = 'smiles'
):
    """Render a grid image of the pasted molecules.

    Parameters
    ----------
    strings : str
        Raw pasted text (one molecule per line or comma-separated).
    input_representation : str
        Representation of the pasted strings.

    Returns
    -------
    Image returned by ``rdkit.Chem.Draw.MolsToGridImage``; each tile's legend
    joins that molecule's inchikey, id and pubchem_name conversions.
    """
    molecules = _clean_split_input(strings)
    # Count the parsed molecules, not the raw input object.
    message = f"Drawing {len(molecules)} molecules..."
    gr.Info(message, duration=10)
    _ids = _convert_input(
        strings,
        input_representation,
        ["inchikey", "id", "pubchem_name"],
    )
    mols = _x2mol(molecules, input_representation)
    # A single input may come back as a bare Mol; wrap it before any list cast.
    mols = [mols] if isinstance(mols, Mol) else cast(mols, to=list)
    return Draw.MolsToGridImage(
        mols,
        # At least one per row: molsPerRow=0 would divide by zero.
        molsPerRow=max(1, min(3, len(mols))),
        subImgSize=(450, 450),
        legends=["\n".join(items) for items in zip(*_ids.values())],
    )
def log10_if_all_positive(df, col):
    """Log10-transform column *col* when all of its non-missing values are positive.

    Parameters
    ----------
    df : pd.DataFrame
        Source table; it is not modified.
    col : str or None
        Column to consider. ``None`` (e.g. no color column chosen) is passed
        through untouched.

    Returns
    -------
    tuple of (str or None, pd.DataFrame)
        Axis title (``"log10[ col ]"`` if transformed, otherwise *col*) and a
        DataFrame whose *col* is transformed when applicable. Non-numeric,
        all-missing or non-positive columns are returned unchanged.
    """
    if col is None:
        return col, df
    values = df[col]
    if not pd.api.types.is_numeric_dtype(values):
        return col, df
    # Ignore missing entries so a single NaN does not disable log scaling.
    observed = values.dropna()
    if observed.empty or not (observed > 0.).all():
        return col, df
    transformed = df.copy()
    transformed[col] = np.log10(values)
    return f"log10[ {col} ]", transformed
def plot_x_vs_y(
    df,
    x: str,
    y: str,
    color: Optional[str] = None,
):
    """Build a scatter plot of column *y* against column *x* of *df*.

    Each plotted column (x, y and, if given, color) is passed through
    ``log10_if_all_positive`` exactly once, so choosing the same column for
    two roles never applies log10 twice.

    Parameters
    ----------
    df : pd.DataFrame
        Prediction table; must contain id, inchikey, smiles, mwt, clogp and
        the plotted columns.
    x, y : str
        Columns for the axes.
    color : str, optional
        Column used for point color; ``None`` for no coloring.

    Returns
    -------
    gr.ScatterPlot
        Visible scatter plot with all columns available in the tooltip.
    """
    message = f"Plotting x={x}, y={y}, color={color}..."
    gr.Info(message, duration=10)
    print_err(df.head())

    plotted = [x, y] if color is None else [x, y, color]
    cols = list(dict.fromkeys(["id", "inchikey", "smiles", "mwt", "clogp", *plotted]))
    titles = {}
    for col in dict.fromkeys(plotted):
        titles[col], df = log10_if_all_positive(df, col)

    return gr.ScatterPlot(
        value=df[cols],
        x=x,
        y=y,
        color=color,
        x_title=titles[x],
        y_title=titles[y],
        color_title=titles.get(color),
        tooltip="all",
        visible=True,
    )
def plot_pred_vs_observed(
    df,
    species: str,
    observed: str,
    color: Optional[str] = None,
):
    """Scatter-plot one species' predicted MIC (µM) against an observed column.

    Parameters
    ----------
    df : pd.DataFrame
        Prediction table containing ``"<species>: predicted MIC (µM)"``.
    species : str
        Species whose predictions go on the x-axis.
    observed : str
        Column plotted on the y-axis.
    color : str, optional
        Column used for point color.

    Returns
    -------
    gr.ScatterPlot
        Result of ``plot_x_vs_y``.
    """
    print_err(df.head())
    predicted_col = f"{species}: predicted MIC (µM)"
    return plot_x_vs_y(df, x=predicted_col, y=observed, color=color)
def download_table(
    df: pd.DataFrame
) -> "gr.DownloadButton":
    """Write *df* to CSV and reveal a download button for it.

    The file is written to the current working directory as
    ``predicted-<hash>.csv``, where the hash is derived from the table's
    row hashes (``pd.util.hash_pandas_object``), so the name tracks content.

    Parameters
    ----------
    df : pd.DataFrame
        Table to export (written without its index).

    Returns
    -------
    gr.DownloadButton
        Visible button pointing at the written file.
    """
    df_hash = nm.hash(pd.util.hash_pandas_object(df).values)
    filename = f"predicted-{df_hash}.csv"
    df.to_csv(filename, index=False)
    return gr.DownloadButton(value=filename, visible=True)
with gr.Blocks() as demo:
    # Page header text lives in a separate markdown file.
    with open(HEADER_FILE, 'r', encoding="utf-8") as f:
        header_md = f.read()
    gr.Markdown(header_md)

    # ------------------------------------------------------------------
    # Tab 1: predictions for molecules pasted into a textbox.
    # ------------------------------------------------------------------
    with gr.Tab(label="Paste one per line"):
        input_format_single = gr.Dropdown(
            label="Input string format",
            choices=list(_FROM_FUNCTIONS),
            value="smiles",
            interactive=True,
        )
        input_line = gr.Textbox(
            label="Input",
            placeholder="Paste your molecule here, one per line",
            lines=2,
            interactive=True,
            submit_btn=True,
        )
        output_species_single = gr.CheckboxGroup(
            label="Species for prediction",
            choices=list(MODEL_REPOS),
            value=list(MODEL_REPOS)[:1],
            interactive=True,
        )
        extra_metric = gr.CheckboxGroup(
            label="Extra metrics (Doubtscore & Information Sensitivity can increase calculation time to a couple of minutes!)",
            choices=list(EXTRA_METRICS),
            value=list(EXTRA_METRICS)[:2],
            interactive=True,
        )
        # SMILES containing "\" (double-bond stereo) are raw strings so the
        # backslash is not parsed as an invalid escape sequence; values are
        # unchanged.
        examples = gr.Examples(
            examples=[
                [
                    '\n'.join([
                        "C1CC1N2C=C(C(=O)C3=CC(=C(C=C32)N4CCNCC4)F)C(=O)O",
                        r"CN1C(=NC(=O)C(=O)N1)SCC2=C(N3[C@@H]([C@@H](C3=O)NC(=O)/C(=N\OC)/C4=CSC(=N4)N)SC2)C(=O)O",
                        r"CC(C)(C(=O)O)O/N=C(/C1=CSC(=N1)N)\C(=O)N[C@H]2[C@@H]3N(C2=O)C(=C(CS3)C[N+]4(CCCC4)CCNC(=O)C5=C(C(=C(C=C5)O)O)Cl)C(=O)[O-]",
                        "CC(=O)NC[C@H]1CN(C(=O)O1)C2=CC(=C(C=C2)N3CCOCC3)F",
                        "C1CC2=CC(=NC=C2OC1)CNC3CCN(CC3)C[C@@H]4CN5C(=O)C=CC6=C5N4C(=O)C=N6",
                    ]),
                    "Yersinia pestis",
                    list(EXTRA_METRICS)[:2],
                ],  # cipro, ceftriaxone, cefiderocol, linezolid, gepotidacin
                [
                    '\n'.join([
                        "C[C@H]1[C@H]([C@H](C[C@@H](O1)O[C@H]2C[C@@](CC3=C2C(=C4C(=C3O)C(=O)C5=C(C4=O)C(=CC=C5)OC)O)(C(=O)CO)O)N)O",
                        "CC1([C@@H](N2[C@H](S1)[C@@H](C2=O)NC(=O)[C@@H](C3=CC=CC=C3)N)C(=O)O)C",
                        "CC1([C@@H](N2[C@H](S1)[C@@H](C2=O)NC(=O)[C@@H](C3=CC=C(C=C3)O)N)C(=O)O)C",
                        "C[C@@H]1[C@@H]2[C@H](C(=O)N2C(=C1S[C@H]3C[C@H](NC3)C(=O)N(C)C)C(=O)O)[C@@H](C)O",
                        "C[C@@]1([C@H]2C[C@H]3[C@@H](C(=O)C(=C([C@]3(C(=O)C2=C(C4=C1C=CC=C4O)O)O)O)C(=O)N)N(C)C)O",
                        "CC1=C2C=CC=C(C2=C(C3=C1C[C@H]4[C@@H](C(=O)C(=C([C@]4(C3=O)O)O)C(=O)N)N(C)C)O)O",
                    ]),
                    "Staphylococcus aureus",
                    list(EXTRA_METRICS)[:2],
                ],  # doxorubicin, ampicillin, amoxicillin, meropenem, tetracycline, anhydrotetracycline
                [
                    '\n'.join([
                        "C1=C(SC(=N1)SC2=NN=C(S2)N)[N+](=O)[O-]",
                        "C1CN(CCC12C3=CC=CC=C3NC(=O)O2)CCC4=CC=C(C=C4)C(F)(F)F",
                        "COC1=CC(=CC(=C1OC)OC)CC2=CN=C(N=C2N)N",
                        "CC1=CC(=NO1)NS(=O)(=O)C2=CC=C(C=C2)N",
                        "C1[C@@H]([C@H]([C@@H]([C@H]([C@@H]1NC(=O)[C@H](CCN)O)O[C@@H]2[C@@H]([C@H]([C@@H]([C@H](O2)CO)O)N)O)O)O[C@@H]3[C@@H]([C@H]([C@@H]([C@H](O3)CN)O)O)O)N",
                        "C1=CN=CC=C1C(=O)NN",
                    ]),
                    ["Escherichia coli", "Acinetobacter baumannii"],
                    list(EXTRA_METRICS)[:2],
                ],  # Halicin, Abaucin, Trimethoprim, Sulfamethoxazole, Amikacin, Isoniazid
                [
                    '\n'.join([
                        "CC[C@H](C)[C@H]1C(=O)N[C@H](C(=O)N[C@H](C(=O)N[C@@H](C(=O)N[C@H](C(=O)N[C@H](C(=O)N[C@H](C(=O)N[C@H](C(=O)N[C@H](C(=O)N[C@H](C(=O)N2CCC[C@@H]2C(=O)N3CCC[C@H]3C(=O)N[C@H](C(=O)N[C@H](C(=O)N1)CC4=CNC5=CC=CC=C54)[C@@H](C)O)CO)C)CCN)CCN)CC6=CNC7=CC=CC=C76)CCN)CCN)CCCN)CCN",
                        "C[C@H]1[C@H]([C@@](C[C@@H](O1)O[C@@H]2[C@H]([C@@H]([C@H](O[C@H]2OC3=C4C=C5C=C3OC6=C(C=C(C=C6)[C@H]([C@H](C(=O)N[C@H](C(=O)N[C@H]5C(=O)N[C@@H]7C8=CC(=C(C=C8)O)C9=C(C=C(C=C9O)O)[C@H](NC(=O)[C@H]([C@@H](C1=CC(=C(O4)C=C1)Cl)O)NC7=O)C(=O)O)CC(=O)N)NC(=O)[C@@H](CC(C)C)NC)O)Cl)CO)O)O)(C)N)O",
                        "CN1[C@H](C(=O)NCC2=C(C=CC=C2SC3=C(CN[C@H](C(=O)N[C@H](C1=O)CCCCN)CCCN)C=CC=N3)C4=CC=C(C=C4)C(=O)O)CC5=CNC6=CC=CC=C65",
                        "C[C@@]1(CO[C@@H]([C@@H]([C@H]1NC)O)O[C@H]2[C@@H](C[C@@H]([C@H]([C@@H]2O)O[C@@H]3[C@@H](CC=C(O3)CNCCO)N)N)NC(=O)[C@H](CCN)O)O",
                        "CC(C1CCC(C(O1)OC2C(CC(C(C2O)OC3C(C(C(CO3)(C)O)NC)O)N)N)N)NC",
                        r"C[C@H]1/C=C/C=C(\C(=O)NC2=C(C(=C3C(=C2O)C(=C(C4=C3C(=O)[C@](O4)(O/C=C/[C@@H]([C@H]([C@H]([C@@H]([C@@H]([C@@H]([C@H]1O)C)O)C)OC(=O)C)C)OC)C)C)O)O)/C=N/N5CCN(CC5)C)/C",
                    ]),
                    "Acinetobacter baumannii",
                    list(EXTRA_METRICS)[:2],
                ],  # murepavadin, vancomycin, zosurabalpin, plazomicin, Gentamicin, rifampicin
                [
                    '\n'.join([
                        "CC1=C(OC2=CC=CC=C12)CN(C)C(=O)/C=C/C3=CC4=C(NC(=O)CC4)N=C3",
                        "CC1=C(OC2=CC=CC=C12)CN(C)C(=O)/C=C/C3=CC4=C(NC(=O)[C@@H](C4)N)N=C3",
                        "CC1=C(OC2=CC=CC=C12)CN(C)C(=O)/C=C/C3=CC4=C(NC(=O)[C@H](CC4)[NH3+])N=C3.[Cl-]",
                        "C1=C(C(=O)NC(=O)N1)F",
                        "CCCCCCNC(=O)N1C=C(C(=O)NC1=O)F",
                        "C[C@@H]1OC[C@@H]2[C@@H](O1)[C@@H]([C@H]([C@@H](O2)O[C@H]3[C@H]4COC(=O)[C@@H]4[C@@H](C5=CC6=C(C=C35)OCO6)C7=CC(=C(C(=C7)OC)O)OC)O)O",
                    ]),
                    "Escherichia coli",
                    list(EXTRA_METRICS)[:2],
                ],  # Debio1452, Debio-1452-NH3, Fabimycin, 5-FU, Carmofur, Etoposide
                [
                    '\n'.join([
                        "COC1=CC(=CC(=C1OC)OC)CC2=CN=C(N=C2N)N",
                        "CC(C)C1=CC=C(C=C1)CN2C=CC3=C2C=CC4=C3C(=NC(=N4)NC5CC5)N",
                        "C1=CC(=CC=C1CCC2=CNC3=C2C(=O)NC(=N3)N)C(=O)N[C@@H](CCC(=O)O)C(=O)O",
                        "CC1=C(C2=C(C=C1)N=C(NC2=O)N)SC3=CC=NC=C3",
                        "CN(CC1=CN=C2C(=N1)C(=NC(=N2)N)N)C3=CC=C(C=C3)C(=O)N[C@@H](CCC(=O)O)C(=O)O",
                        "CC1=NC2=C(C=C(C=C2)CN(C)C3=CC=C(S3)C(=O)N[C@@H](CCC(=O)O)C(=O)O)C(=O)N1",
                    ]),
                    "Klebsiella pneumoniae",
                    list(EXTRA_METRICS)[:2],
                ],  # Trimethoprim, SCH79797, Pemetrexed, Nolatrexed, Methotrexate, Raltitrexed
                [
                    '\n'.join([
                        "C[C@H]([C@@H](C(=O)NO)NC(=O)C1=CC=C(C=C1)C#CC2=CC=C(C=C2)CN3CCOCC3)O",
                        "CC(C)C1=CC=C(C=C1)CN2C=CC3=C2C=CC4=C3C(=NC(=N4)NC5CC5)N",
                        "C1=CC=C(C=C1)CNC2=NC(=NC3=CC=CC=C32)NCC4=CC=CC=C4",
                        "CC(C)(C)C1=CC=C(C=C1)C(=O)NC(=S)NC2=CC=C(C=C2)NC(=O)CCCCN(C)C",
                        "CCC1=C(C(=NC(=N1)N)N)C2=CC=C(C=C2)Cl",
                        "C1=CC(=CC=C1C(=O)N[C@@H](CCC(=O)O)C(=O)O)NCC2=CN=C3C(=N2)C(=NC(=N3)N)N",
                    ]),
                    "Klebsiella pneumoniae",
                    list(EXTRA_METRICS)[:2],
                ],  # CHIR-090, SCH79797, DBeQ, Tenovin-6, Pyrimethamine, Aminopterin
            ],
            example_labels=[
                "_Y. pestis_ (plague) vs Ciprofloxacin, Ceftriaxone, Cefiderocol, Linezolid, Gepotidacin",
                "_S. aureus_ vs Doxorubicin, Ampicillin, Amoxicillin, Meropenem, Tetracycline, Anhydrotetracycline",
                "_E. coli_ and _A. baumannii_ vs Halicin, Abaucin, Trimethoprim, Sulfamethoxazole, Amikacin, Isoniazid",
                "_A. baumannii_ vs Murepavadin, Vancomycin, Zosurabalpin, Plazomicin, Gentamicin, Rifampicin",
                "_E. coli_ vs Debio-1452, Debio-1452-NH3, Fabimycin, 5-FU, Carmofur, Etoposide",
                "_K. pneumoniae_ vs Trimethoprim, SCH79797, Pemetrexed, Nolatrexed, Methotrexate, Raltitrexed",
                "_K. pneumoniae_ vs CHIR-090, SCH79797, DBeQ, Tenovin-6, Pyrimethamine, Aminopterin"
            ],
            inputs=[input_line, output_species_single, extra_metric],
            cache_mode="eager",
        )
        download_single = gr.DownloadButton(
            label="Download predictions",
            visible=False,
        )
        output_line = gr.DataFrame(
            label="Predictions",
            interactive=False,
            visible=False,
        )
        drawing = gr.Image(label="Chemical structures")

        # On submit: predict, then draw the structures, then offer a download.
        gr.on(
            [
                input_line.submit,
            ],
            fn=predict_one,
            inputs=[
                input_line,
                input_format_single,
                output_species_single,
                extra_metric,
            ],
            outputs={
                output_line,
            }
        ).then(
            draw_one,
            inputs=[
                input_line,
                input_format_single,
            ],
            outputs=drawing,
        ).then(
            download_table,
            inputs=output_line,
            outputs=download_single
        )

    # ------------------------------------------------------------------
    # Tab 2: predictions for a whole uploaded table, plus plotting.
    # ------------------------------------------------------------------
    with gr.Tab(f"Predict on structures from a file (max. {MAX_ROWS} rows, ≤ 2 species)"):
        input_file = gr.File(
            label="Upload a table of chemical compounds here",
            file_types=[".xlsx", ".csv", ".tsv", ".txt"],
        )
        # NOTE(review): the source's indentation was lost; which components
        # sit inside each gr.Row below was reconstructed — confirm the layout.
        with gr.Row():
            input_column = gr.Dropdown(
                label="Input column name",
                choices=[],
                allow_custom_value=True,
                visible=False,
            )
            input_format = gr.Dropdown(
                label="Input string format",
                choices=list(_FROM_FUNCTIONS),
                value="smiles",
                interactive=True,
                visible=True,
            )
        output_species = [
            gr.Dropdown(
                label="Species 1 for prediction",
                choices=list(MODEL_REPOS),
                value=list(MODEL_REPOS)[0],
                interactive=True,
            ),
            gr.Dropdown(
                label="Species 2 for prediction",
                choices=list(MODEL_REPOS),
                value=None,
                interactive=True,
            ),
        ]
        extra_metric_file = gr.CheckboxGroup(
            label="Extra metrics (Information Sensitivity can increase calculation time)",
            choices=list(EXTRA_METRICS),
            value=list(EXTRA_METRICS)[:2],
            interactive=True,
        )
        go_button2 = gr.Button(
            value="Predict!",
        )
        download = gr.DownloadButton(
            label="Download predictions",
            visible=False,
        )
        # Shows the uploaded table and is overwritten with the predictions.
        input_data = gr.Dataframe(
            label="Input data",
            max_height=500,
            visible=False,
            interactive=False,
        )
        with gr.Row():
            observed_col = gr.Dropdown(
                label="Observed column (y-axis) for left plot",
                choices=[],
                value=None,
                interactive=True,
                visible=False,
            )
            color_col = gr.Dropdown(
                label="Color for left plot",
                choices=[],
                value=None,
                interactive=True,
                visible=False,
            )
        with gr.Row():
            any_x_col = gr.Dropdown(
                label="x-axis for right plot",
                choices=[],
                value=None,
                interactive=True,
                visible=False,
            )
            any_y_col = gr.Dropdown(
                label="y-axis for right plot",
                choices=[],
                value=None,
                interactive=True,
                visible=False,
            )
            any_color_col = gr.Dropdown(
                label="Color for right plot",
                choices=[],
                value=None,
                interactive=True,
                visible=False,
            )
        plot_button = gr.Button(
            value="Plot!",
            visible=False,
        )
        file_examples = gr.Examples(
            examples=[
                [
                    "example-data/stokes2020-eco.csv",
                    "SMILES",
                    "Escherichia coli",
                    "Mean_Growth",
                    "Escherichia coli: Doubtscore",
                    list(EXTRA_METRICS)[:3],
                ],
                [
                    "example-data/liu23-abau.csv",
                    "SMILES",
                    "Acinetobacter baumannii",
                    "Mean",
                    "Acinetobacter baumannii: Doubtscore",
                    list(EXTRA_METRICS)[:3],
                ],
                [
                    "example-data/wong24-sau-tox-5000.csv",
                    "SMILES",
                    "Staphylococcus aureus",
                    "Mean",
                    "Staphylococcus aureus: Doubtscore",
                    list(EXTRA_METRICS)[:3],
                ],
            ],
            example_labels=[
                "E. coli training data from Stokes J. et al., Cell, 2020",
                "A. baumannii training data from Liu, 2023",
                "S. aureus and toxicity training data from Wong, 2024",
            ],
            inputs=[input_file, input_column, output_species[0], observed_col, color_col, extra_metric_file],
            cache_mode="eager",
        )
        with gr.Row():
            pred_vs_observed = gr.ScatterPlot(
                label="Prediction vs observed",
                x_title="Predicted MIC (µM)",
                y_title="Observed",
                visible=False,
                height=600,
            )
            plot_any_vs_any = gr.ScatterPlot(
                label="Any vs any",
                visible=False,
                height=600,
            )

        # Loading a table (upload or example) shows it and fills the column picker.
        load_data_action = {
            "fn": load_input_data,
            "inputs": [input_file],
            "outputs": [input_data, input_column]
        }
        file_examples.load_input_event.then(
            **load_data_action,
        )
        input_file.upload(
            **load_data_action,
        )

        # Predict, then offer a download and reveal the plot button.
        go2_click_event = go_button2.click(
            predict_file,
            inputs=[
                input_data,
                input_column,
                input_format,
                *output_species,
                extra_metric_file,
            ],
            outputs={
                input_data,
            }
        ).then(
            download_table,
            inputs=input_data,
            outputs=download
        ).then(
            lambda: gr.Button(visible=True),
            outputs=[plot_button]
        )
        # After prediction, every plot dropdown offers the numeric columns.
        for dropdown in [observed_col, color_col, any_color_col, any_x_col, any_y_col]:
            go2_click_event.then(
                partial(get_dropdown_options, _type="number"),
                inputs=[input_data],
                outputs=[dropdown],
            )
        plot_button.click(
            plot_pred_vs_observed,
            inputs=[
                input_data,
                output_species[0],
                observed_col,
                color_col,
            ],
            outputs=[pred_vs_observed],
        ).then(
            plot_x_vs_y,
            inputs=[
                input_data,
                any_x_col,
                any_y_col,
                any_color_col,
            ],
            outputs=[plot_any_vs_any],
        )
if __name__ == "__main__":
    # Blocks.queue() returns the same Blocks instance, so launch chains onto it.
    demo.queue().launch(share=True)