"""Gradio demo for schemist."""
from typing import Iterable, List, Optional, Union
import csv
from io import TextIOWrapper
import itertools
import json
import os
import sys
csv.field_size_limit(sys.maxsize)
from carabiner import cast, print_err
from carabiner.pd import read_table
from duvida.autoclass import AutoModelBox
import gradio as gr
import nemony as nm
import numpy as np
import pandas as pd
from rdkit.Chem import Draw, Mol
from schemist.converting import (
_FROM_FUNCTIONS,
convert_string_representation,
_x2mol,
)
from schemist.tables import converter
import torch
from duvida.stateless.config import config
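# App-wide settings: Gradio theme, torch device, cache location, and input limits.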
THEME = gr.themes.Default()
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CACHE = "./cache"
MAX_ROWS = 1000
BATCH_SIZE = 16
HEADER_FILE = os.path.join("sources", "header.md")
with open("repos.json", "r") as f:
MODEL_REPOS = json.load(f)
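# Load a pretrained duvida modelbox for each species listed in repos.json.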
MODELBOXES = {
key: AutoModelBox.from_pretrained(val, cache_dir=os.path.join(CACHE, "duvida"))
for key, val in MODEL_REPOS.items()
}
for modelbox in MODELBOXES.values():
    modelbox.to(DEVICE)
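# Optional per-candidate metrics (uncertainty estimates and training-set similarity),
# each computed on demand from a modelbox.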
EXTRA_METRICS = {
"log10(variance)": lambda modelbox, candidates: modelbox.prediction_variance(candidates=candidates, batch_size=BATCH_SIZE, cache=CACHE).map(lambda x: {modelbox._variance_key: torch.log10(x[modelbox._variance_key])}),
"Tanimoto nearest neighbor to training data": lambda modelbox, candidates: modelbox.tanimoto_nn(candidates=candidates, batch_size=BATCH_SIZE),
"Doubtscore": lambda modelbox, candidates: modelbox.doubtscore(candidates=candidates, cache=CACHE, batch_size=BATCH_SIZE).map(lambda x: {"doubtscore": torch.log10(x["doubtscore"])}),
"Information sensitivity (approx.)": lambda modelbox, candidates: modelbox.information_sensitivity(candidates=candidates, batch_size=BATCH_SIZE, optimality_approximation=True, approximator="squared_jacobian", cache=CACHE).map(lambda x: {"information sensitivity": torch.log10(x["information sensitivity"])}),
}
with open(os.path.join("example-data", "examples.json"), "r") as f:
EXAMPLES = json.load(f)
def get_dropdown_options(df, _type=str):
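    """Build a Gradio Dropdown of the non-numeric (default) or numeric columns of `df`,
    pre-selecting the first available column."""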
if _type == str:
cols = list(df.select_dtypes(exclude=[np.number]))
else:
cols = list(df.select_dtypes([np.number]))
    non_none = [col for col in cols if col is not None]
    if len(non_none) > 0:
        default_value = non_none[0]
    else:
        default_value = ""
print_err(f"Dropdown default value is {default_value}")
return gr.Dropdown(
choices=cols,
interactive=True,
value=default_value,
visible=True,
allow_custom_value=True,
)
def load_input_data(file: Union[TextIOWrapper, str], return_pd: bool = False) -> pd.DataFrame:
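    """Read an uploaded table (up to MAX_ROWS rows) and return it as a Gradio Dataframe
    together with a Dropdown of its non-numeric columns; optionally also return the
    raw pandas DataFrame."""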
file = file if isinstance(file, str) else file.name
print_err(f"Loading {file}")
df = read_table(file, nrows=MAX_ROWS)
print_err(df.head())
if return_pd:
return (df, gr.Dataframe(value=df, visible=True)), get_dropdown_options(df, str)
else:
return gr.Dataframe(value=df, visible=True), get_dropdown_options(df, str)
def _clean_split_input(strings: str) -> List[str]:
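    """Split pasted text on newlines and commas, keep the part after any trailing
    colon-separated prefix, and strip surrounding whitespace."""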
return [
s2.split(":")[-1].strip()
for s in strings.split("\n")
for s2 in s.split(",")
]
def _convert_input(
strings: str,
input_representation: str = 'smiles',
output_representation: Union[Iterable[str], str] = 'smiles'
) -> List[str]:
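    """Convert pasted structure strings to the requested representation(s), returning
    a dict mapping each output representation to a list of strings."""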
strings = _clean_split_input(strings)
converted = convert_string_representation(
strings=strings,
input_representation=input_representation,
output_representation=output_representation,
)
return {key: list(map(str, cast(val, to=list))) for key, val in converted.items()}
def convert_one(
strings: str,
input_representation: str = 'smiles',
output_representation: Union[Iterable[str], str] = 'smiles',
):
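    """Convert pasted structure strings (one per line) by building a single-column
    DataFrame and delegating to `convert_file`."""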
output_representation = cast(output_representation, to=list)
for rep in output_representation:
message = f"Converting from {input_representation} to {rep}..."
gr.Info(message, duration=3)
df = pd.DataFrame({
input_representation: _clean_split_input(strings),
})
return convert_file(
df=df,
column=input_representation,
input_representation=input_representation,
output_representation=output_representation,
)
def _prediction_loop(
df: pd.DataFrame,
predict: Union[Iterable[str], str] = 'smiles',
extra_metrics: Optional[Union[Iterable[str], str]] = None
) -> pd.DataFrame:
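    """Run each selected species model over `df`, adding predicted MIC columns
    (µM and µg/mL) plus any requested extra metrics. Returns the names of the
    added columns and the augmented DataFrame."""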
species_to_predict = cast(predict, to=list)
prediction_cols = []
if extra_metrics is None:
extra_metrics = []
else:
extra_metrics = cast(extra_metrics, to=list)
for species in species_to_predict:
message = f"Predicting for species: {species}"
print_err(message)
gr.Info(message, duration=3)
this_modelbox = MODELBOXES[species]
this_features = this_modelbox._input_cols
this_labels = this_modelbox._label_cols
this_prediction_input = (
df
.rename(columns={
"smiles": this_features[0],
})
.assign(**{label: np.nan for label in this_labels})
)
print(this_prediction_input)
prediction = this_modelbox.predict(
data=this_prediction_input,
features=this_features,
labels=this_labels,
aggregator="mean",
cache=CACHE,
).with_format("numpy")["__prediction__"].flatten()
print(prediction)
this_col = f"{species}: predicted MIC (µM)"
df[this_col] = np.power(10., -prediction) * 1e6
prediction_cols.append(this_col)
this_col = f"{species}: predicted MIC (µg / mL)"
df[this_col] = np.power(10., -prediction) * 1e3 * df["mwt"]
prediction_cols.append(this_col)
for extra_metric in extra_metrics:
message = f"Calculating {extra_metric} for species: {species}"
print_err(message)
gr.Info(message, duration=10)
# this_modelbox._input_training_data = this_modelbox._input_training_data.remove_columns([this_modelbox._in_key])
this_col = f"{species}: {extra_metric}"
prediction_cols.append(this_col)
print(">>>", this_modelbox._input_training_data)
print(">>>", this_modelbox._input_training_data.format)
print(">>>", this_modelbox._in_key, this_modelbox._out_key)
this_extra = (
EXTRA_METRICS[extra_metric](
this_modelbox,
this_prediction_input,
)
.with_format("numpy")
)
df[this_col] = this_extra[this_extra.column_names[-1]]
return prediction_cols, df
def predict_one(
strings: str,
input_representation: str = 'smiles',
predict: Union[Iterable[str], str] = 'smiles',
extra_metrics: Optional[Union[Iterable[str], str]] = None,
return_pd: bool = False
):
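    """Convert pasted structures, predict MICs for the selected species, and return
    the results as a Gradio DataFrame (optionally also as a pandas DataFrame)."""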
prediction_df = convert_one(
strings=strings,
input_representation=input_representation,
output_representation=['id', 'pubchem_name', 'pubchem_id', 'smiles', 'inchikey', "mwt", "clogp"],
)
prediction_cols, prediction_df = _prediction_loop(
prediction_df,
predict=predict,
extra_metrics=extra_metrics,
)
df = prediction_df[
['id', 'pubchem_name', 'pubchem_id']
+ prediction_cols
+ ['smiles', 'inchikey', "mwt", "clogp"]
]
gradio_opts = {
"label": "Predictions",
"value": df,
"pinned_columns": 3,
"visible": True,
"wrap": True,
"column_widths": [120] * 3 + [250] * (prediction_df.shape[1] - 3),
}
if return_pd:
return df, gr.DataFrame(**gradio_opts)
else:
return gr.DataFrame(**gradio_opts)
def convert_file(
df: pd.DataFrame,
column: str = 'smiles',
input_representation: str = 'smiles',
output_representation: Union[str, Iterable[str]] = 'smiles'
):
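    """Convert the structures in `column` of `df` to the requested representations
    using schemist's `converter`, reporting the number of conversion errors."""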
output_representation = cast(output_representation, to=list)
message = f"Converting from {input_representation} to {', '.join(output_representation)}..."
gr.Info(message, duration=5)
print_err(message)
print_err(df.head())
errors, df = converter(
df=df,
column=column,
input_representation=input_representation,
output_representation=output_representation,
)
df = df[
output_representation +
[col for col in df if col not in output_representation]
]
    all_err = sum(errors.values())
message = (
f"Converted {df.shape[0]} molecules from "
f"{input_representation} to {output_representation} "
f"with {all_err} errors!"
)
print_err(message)
gr.Info(message, duration=5)
return df
def predict_file(
df: pd.DataFrame,
column: str = 'smiles',
input_representation: str = 'smiles',
predict: str = 'smiles',
predict2: Optional[str] = None,
extra_metrics: Optional[Union[Iterable[str], str]] = None,
return_pd: bool = False
):
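    """Predict MICs (and optional extra metrics) for an uploaded table, truncated to
    MAX_ROWS rows, and return the result table plus numeric-column dropdowns for
    plotting."""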
predict = cast(predict, to=list)
if predict2 is not None and predict2 in MODELBOXES:
predict += cast(predict2, to=list)
if extra_metrics is None:
extra_metrics = []
else:
extra_metrics = cast(extra_metrics, to=list)
if df.shape[0] > MAX_ROWS:
message = f"Truncating input to {MAX_ROWS} rows"
print_err(message)
gr.Info(message, duration=15)
df = df.iloc[:MAX_ROWS]
prediction_df = convert_file(
df,
column=column,
input_representation=input_representation,
output_representation=["id", "smiles", "inchikey", "mwt", "clogp"],
)
prediction_cols, prediction_df = _prediction_loop(
prediction_df,
predict=predict,
extra_metrics=extra_metrics,
)
left_cols = ['id', 'inchikey']
end_cols = ["smiles", "mwt", "clogp"]
main_cols = set(
left_cols
+ end_cols
+ [column]
+ prediction_cols
)
other_cols = list(set(prediction_df) - main_cols)
return_cols = (
left_cols
+ [column]
+ prediction_cols
+ other_cols
+ end_cols
)
deduplicated_cols = []
for col in return_cols:
        if col not in deduplicated_cols:
deduplicated_cols.append(col)
prediction_df = prediction_df[deduplicated_cols]
plot_dropdown = get_dropdown_options(prediction_df, _type="number")
plot_dropdown = (plot_dropdown,) * 5
gradio_opts = {
"label": "Predictions",
"value": prediction_df,
"pinned_columns": 3,
"visible": True,
"wrap": True,
"column_widths": [120] * 3 + [250] * (prediction_df.shape[1] - 3),
}
if return_pd:
return ((prediction_df, gr.Dataframe(**gradio_opts)),) + (plot_dropdown,)
else:
return (gr.Dataframe(**gradio_opts),) + (plot_dropdown,)
def draw_one(
df,
smiles_col: str = "smiles",
legends: Optional[Union[str, Iterable[str]]] = None
):
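    """Render the molecules in `smiles_col` of `df` as an RDKit grid image, labelled
    with the identifier columns given (or defaulted) in `legends`."""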
    if legends is None:
        legends = ["inchikey", "id", "pubchem_name"]
    else:
        legends = cast(legends, to=list)
message = f"Drawing {df.shape[0]} molecules..."
gr.Info(message, duration=2)
_ids = {col: df[col].tolist() for col in legends}
mols = cast(_x2mol(df[smiles_col], "smiles"), to=list)
if isinstance(mols, Mol):
mols = [mols]
return Draw.MolsToGridImage(
mols,
molsPerRow=min(5, len(mols)),
subImgSize=(600, 600),
legends=[
"\n".join(
_x if _x is not None else ""
for _x in items
) for items in zip(*_ids.values())
],
)
def log10_if_all_positive(df, col):
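    """Log10-transform `df[col]` in place when all its values are positive; return
    the axis title and the (possibly transformed) DataFrame."""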
if np.all(df[col] > 0.):
df[col] = np.log10(df[col])
title = f"log10[ {col} ]"
else:
title = col
return title, df
def plot_x_vs_y(
df,
x: str,
y: str,
color: Optional[str] = None,
):
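    """Scatter-plot `x` against `y` (log10-scaled where all values are positive),
    optionally colored by another column."""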
message = f"Plotting x={x}, y={y}, color={color}..."
gr.Info(message, duration=10)
print_err(df.head())
y_title = y
cols = ["id", "inchikey", "smiles", "mwt", "clogp", x, y]
if color is not None and color not in cols:
cols.append(color)
cols = list(set(cols))
x_title, df = log10_if_all_positive(df, x)
y_title, df = log10_if_all_positive(df, y)
    if color is not None:
        color_title, df = log10_if_all_positive(df, color)
    else:
        color_title = None
return gr.ScatterPlot(
value=df[cols],
x=x,
y=y,
color=color,
x_title=x_title,
y_title=y_title,
color_title=color_title,
tooltip="all",
visible=True,
)
def plot_pred_vs_observed(
df,
species: str,
observed: str,
color: Optional[str] = None,
):
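    """Plot the observed column against the predicted MIC (µM) column for `species`."""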
print_err(df.head())
xcol = f"{species}: predicted MIC (µM)"
ycol = observed
return plot_x_vs_y(
df,
x=xcol,
y=ycol,
color=color,
)
def download_table(
df: pd.DataFrame
) -> str:
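    """Write `df` to a cached CSV named by a hash of its contents and return a
    visible DownloadButton pointing at it."""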
df_hash = nm.hash(pd.util.hash_pandas_object(df).values)
filename = os.path.join(CACHE, "downloads", f"predicted-{df_hash}.csv")
    os.makedirs(os.path.dirname(filename), exist_ok=True)
df.to_csv(filename, index=False)
return gr.DownloadButton(value=filename, visible=True)
def _predict_then_draw_then_download(
strings: str,
input_representation: str = 'smiles',
predict: Union[Iterable[str], str] = 'smiles',
extra_metrics: Optional[Union[Iterable[str], str]] = None,
smiles_col: str = "smiles",
legends: Optional[Union[str, Iterable[str]]] = None
):
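    """Pipeline for the line-input tab: predict, draw the structures, and prepare
    the download button."""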
df, gr_df = predict_one(
strings=strings,
input_representation=input_representation,
predict=predict,
extra_metrics=extra_metrics,
return_pd=True,
)
img = draw_one(
df,
smiles_col="smiles",
)
return gr_df, img, download_table(df)
def _load_then_predict_then_download_then_reveal_plot(
file: str,
column: str = 'smiles',
input_representation: str = 'smiles',
predict: str = 'smiles',
predict2: Optional[str] = "",
extra_metrics: Optional[Union[Iterable[str], str]] = None
):
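    """Pipeline for the file-input tab: load the table, predict, prepare the
    download button, and populate the plot dropdowns."""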
(df, df_gr), col_opts = load_input_data(
file,
return_pd=True,
)
(df, df_gr), plot_opts = predict_file(
df,
column=column,
input_representation=input_representation,
predict=predict,
predict2=None if predict2 == "" else predict2,
extra_metrics=extra_metrics,
return_pd=True,
)
print_err(df.head())
return (
df_gr,
download_table(df),
) + plot_opts
def _initial_setup():
"""Set up blocks.
"""
print_err(f"Duvida config is {config}")
print_err(f"Default torch device is {DEVICE}")
line_inputs = {
"format": gr.Dropdown(
label="Input string format",
choices=list(_FROM_FUNCTIONS),
value="smiles",
interactive=True,
),
"species": gr.CheckboxGroup(
label="Species for prediction",
choices=list(MODEL_REPOS),
value=list(MODEL_REPOS)[:1],
interactive=True,
),
"extras": gr.CheckboxGroup(
label="Extra metrics (Doubtscore & Information Sensitivity can increase calculation time to a couple of minutes!)",
choices=list(EXTRA_METRICS),
value=list(EXTRA_METRICS)[:2],
interactive=True,
),
"strings": gr.Textbox(
label="Input",
placeholder="Paste your molecule here, one per line.",
lines=2,
interactive=True,
submit_btn=True,
),
}
output_line = gr.DataFrame(
label="Predictions (scroll left and right)",
interactive=False,
visible=True,
)
download_single = gr.DownloadButton(
label="Download predictions",
visible=True,
)
drawing = gr.Image(label="Chemical structures")
file_inputs = {
"file": gr.File(
label="Upload a table of chemical compounds here",
file_types=[".xlsx", ".csv", ".tsv", ".txt"],
),
"column": gr.Dropdown(
label="Input column name",
choices=[],
allow_custom_value=True,
visible=True,
interactive=True,
),
"format": gr.Dropdown(
label="Input string format",
choices=list(_FROM_FUNCTIONS),
value="smiles",
interactive=True,
visible=True,
),
"species": [
gr.Dropdown(
label="Species 1 for prediction",
choices=list(MODEL_REPOS),
value=list(MODEL_REPOS)[0],
interactive=True,
allow_custom_value=True,
),
gr.Dropdown(
label="Species 2 for prediction",
choices=list(MODEL_REPOS),
value=None,
interactive=True,
allow_custom_value=True,
),
],
"extras": gr.CheckboxGroup(
label="Extra metrics (Information Sensitivity can increase calculation time)",
choices=list(EXTRA_METRICS),
value=list(EXTRA_METRICS)[:2],
interactive=True,
),
}
input_dataframe = gr.Dataframe(
label="Input data",
max_height=500,
visible=True,
interactive=False,
show_fullscreen_button=True,
show_search="filter",
max_chars=45,
)
download = gr.DownloadButton(
label="Download predictions",
visible=False,
)
plot_button = gr.Button(
value="Plot!",
visible=False,
)
left_plot_inputs = {
"observed": gr.Dropdown(
label="Observed column (y-axis) for left plot",
choices=[],
value=None,
interactive=True,
visible=True,
allow_custom_value=True,
),
"color": gr.Dropdown(
label="Color for left plot",
choices=[],
value=None,
interactive=True,
visible=True,
allow_custom_value=True,
)
}
right_plot_inputs = {
"x": gr.Dropdown(
label="x-axis for right plot",
choices=[],
value=None,
interactive=True,
visible=True,
allow_custom_value=True,
),
"y": gr.Dropdown(
label="y-axis for right plot",
choices=[],
value=None,
interactive=True,
visible=True,
allow_custom_value=True,
),
"color": gr.Dropdown(
label="Color for right plot",
choices=[],
value=None,
interactive=True,
visible=True,
allow_custom_value=True,
)
}
plots = {
"left": gr.ScatterPlot(
height=500,
visible=False,
),
"right": gr.ScatterPlot(
height=500,
visible=False,
),
}
return (
line_inputs,
output_line,
download_single,
drawing,
file_inputs,
input_dataframe,
download,
plot_button,
left_plot_inputs,
right_plot_inputs,
plots,
)
if __name__ == "__main__":
(
line_inputs,
output_line,
download_single,
drawing,
file_inputs,
input_dataframe,
download,
plot_button,
left_plot_inputs,
right_plot_inputs,
plots,
) = _initial_setup()
with gr.Blocks(theme=THEME) as demo:
with open(HEADER_FILE, 'r') as f:
header_md = f.read()
gr.Markdown(header_md)
with gr.Tab(label="Paste one per line"):
examples = gr.Examples(
examples=[
[
"\n".join(eg["strings"]),
"smiles",
eg["species"],
list(EXTRA_METRICS)[:3],
]
for eg in EXAMPLES["line input examples"]
],
example_labels=[
eg["label"] for eg in EXAMPLES["line input examples"]
],
examples_per_page=100,
inputs=[
line_inputs["strings"],
line_inputs["format"],
line_inputs["species"],
line_inputs["extras"],
],
fn=_predict_then_draw_then_download,
outputs=[
output_line,
drawing,
download_single,
],
cache_examples=True,
cache_mode="eager",
)
for val in line_inputs.values():
val.render()
# with gr.Row():
output_line.render()
download_single.render()
drawing.render()
line_inputs["strings"].submit(
fn=_predict_then_draw_then_download,
inputs=[
line_inputs["strings"],
line_inputs["format"],
line_inputs["species"],
line_inputs["extras"],
],
outputs=[
output_line,
drawing,
download_single,
],
)
with gr.Tab(f"Predict on structures from a file (max. {MAX_ROWS} rows, ≤ 2 species)"):
plot_dropdowns = list(itertools.chain(
left_plot_inputs.values(),
right_plot_inputs.values(),
))
file_examples = gr.Examples(
examples=[
[
eg["file"],
eg["column"],
"smiles",
eg["species"],
"",
list(EXTRA_METRICS)[:3],
] for eg in EXAMPLES["file examples"]
],
example_labels=[
eg["label"] for eg in EXAMPLES["file examples"]
],
fn=_load_then_predict_then_download_then_reveal_plot,
inputs=[
file_inputs["file"],
file_inputs["column"],
file_inputs["format"],
*file_inputs["species"],
file_inputs["extras"],
],
outputs=[
input_dataframe,
download,
*plot_dropdowns,
],
cache_examples=True, ## appears to cause CSV load error
cache_mode="eager",
)
file_inputs["file"].render()
with gr.Row():
for key in ("column", "format"):
file_inputs[key].render()
with gr.Row():
for item in file_inputs["species"]:
item.render()
file_inputs["extras"].render()
go_button2 = gr.Button(value="Predict!")
input_dataframe.render()
download.render()
with gr.Row():
for val in left_plot_inputs.values():
val.render()
with gr.Row():
for val in right_plot_inputs.values():
val.render()
plot_button.render()
with gr.Row():
for val in plots.values():
val.render()
file_inputs["file"].upload(
fn=load_input_data,
inputs=file_inputs["file"],
outputs=[
input_dataframe,
file_inputs["column"],
],
)
go2_click_event = go_button2.click(
_load_then_predict_then_download_then_reveal_plot,
inputs=[
file_inputs["file"],
file_inputs["column"],
file_inputs["format"],
*file_inputs["species"],
file_inputs["extras"],
],
outputs=[
input_dataframe,
download,
*plot_dropdowns,
],
scroll_to_output=True,
).then(
lambda: gr.Button(visible=True),
outputs=[plot_button],
js=True,
)
file_examples.load_input_event.then(
lambda: gr.Button(visible=True),
outputs=[plot_button],
js=True,
)
plot_button.click(
plot_pred_vs_observed,
inputs=[
input_dataframe,
file_inputs["species"][0],
left_plot_inputs["observed"],
left_plot_inputs["color"],
],
outputs=[plots["left"]],
scroll_to_output=True,
).then(
plot_x_vs_y,
inputs=[
input_dataframe,
right_plot_inputs["x"],
right_plot_inputs["y"],
right_plot_inputs["color"],
],
outputs=[plots["right"]],
)
demo.queue()
demo.launch(share=True)