"""
This file defines the layout of the app, including the header, the sidebar, and the
tabs in the main content area.
"""
#--------------------------------------------------------------------------------------- | |
# Imports | |
from typing import Callable

import pandas as pd
import streamlit as st
import streamlit.components.v1 as components
import yaml
from PIL import Image

from src.app.constants import (
    citation_text,
    data_text,
    example_nottrustworthy_text,
    example_trustworthy_text,
    few_shot_learning_text,
    mhnfs_text,
    summary_text,
    trust_text,
    under_the_hood_text,
    usage_text,
)
from src.data_preprocessing.create_descriptors import handle_inputs
#--------------------------------------------------------------------------------------- | |
# Global variables
# Upper bound on the number of molecules accepted per input (query, active support
# set, inactive support set). Larger inputs are rejected in the results tab.
MAX_INPUT_LENGTH = 20
#--------------------------------------------------------------------------------------- | |
# Functions | |
class LayoutMaker():
    """
    This class bundles all design choices regarding the layout of the app. It is meant
    to be used in the main file to build the header, the sidebar, and the main content
    area.

    NOTE(review): several `icon=` literals in this copy of the file look mis-encoded
    (they are not single emoji characters). They are reproduced unchanged here;
    please confirm them against the original file encoding.
    """

    def __init__(self):
        # Raw user inputs (query, active and inactive support set), filled by the
        # sidebar boxes
        self.inputs = {}
        # Outputs of `handle_inputs` per input key, offered for download as YML
        self.inputs_lists = {}

        # Initialize prediction storage
        self.predictions = None

        # Return values of the buttons, keyed by button name
        self.buttons = {}

        # Content
        self.summary_text = summary_text
        self.mhnfs_text = mhnfs_text
        self.citation_text = citation_text
        self.few_shot_learning_text = few_shot_learning_text
        self.under_the_hood_text = under_the_hood_text
        self.usage_text = usage_text
        self.data_text = data_text
        self.trust_text = trust_text
        self.example_trustworthy_text = example_trustworthy_text
        self.example_nottrustworthy_text = example_nottrustworthy_text

        # Example prediction tables shown in the "Examples" tab
        self.df_trustworthy = pd.read_csv(
            "./assets/example_csv/predictions/trustworthy_example.csv"
        )
        self.df_nottrustworthy = pd.read_csv(
            "./assets/example_csv/predictions/nottrustworthy_example.csv"
        )

        self.max_input_length = MAX_INPUT_LENGTH

    # ----------------------------------------------------------------------------------
    # Public layout builders
    def make_sidebar(self):
        """
        This function defines the sidebar of the app. It includes the logo, the query
        box, the two support set boxes, and the predict / reset buttons.

        Returns:
            The tuple `(inputs, buttons)`: the dictionary with the stored inputs
            (keys "query", "support_active", "support_inactive") and the dictionary
            with the button states (keys "predict", "reset").
        """
        with st.sidebar:
            # Logo
            st.image(Image.open("./assets/logo.png"))
            st.divider()

            # Query box
            self._make_query_box()
            st.divider()

            # Support set actives box
            self._make_active_support_set_box()
            st.divider()

            # Support set inactives box
            self._make_inactive_support_set_box()
            st.divider()

            # Predict buttons
            self.buttons["predict"] = st.button("Predict...")
            self.buttons["reset"] = st.button("Reset")

        return self.inputs, self.buttons

    def make_header(self):
        """
        This function defines the header of the app. It consists only of a png image
        in which the title and an overview is given.
        """
        with st.container():
            st.image(Image.open("./assets/header.png"))

    def make_main_content_area(self,
                               predictor,
                               inputs,
                               buttons,
                               create_prediction_df: Callable,
                               create_molecule_grid_plot: Callable):
        """
        This function defines the main content area, which consists of four tabs:
        "Predictions", "Paper / Cite", "Additional Information", and "Examples".

        Args:
            predictor: Object which is passed on unchanged to `create_prediction_df`.
            inputs: Dictionary with the keys "query", "support_active", and
                "support_inactive".
            buttons: Dictionary with the keys "predict" and "reset".
            create_prediction_df: Called as `create_prediction_df(predictor, query,
                support_active, support_inactive)`; its result is shown as a table.
            create_molecule_grid_plot: Called with the prediction table; its result
                is rendered as HTML.
        """
        results_tab, paper_tab, explanations_tab, examples_tab = st.tabs(
            ["Predictions", "Paper / Cite", "Additional Information", "Examples"]
        )

        # Results tab
        with results_tab:
            self._fill_tab_with_results_content(predictor,
                                                inputs,
                                                buttons,
                                                create_prediction_df,
                                                create_molecule_grid_plot)

        # Paper tab
        with paper_tab:
            self._fill_paper_and_citation_tab()

        # More explanations tab
        with explanations_tab:
            self._fill_more_explanations_tab()

        # Examples tab
        with examples_tab:
            self._fill_examples_tab()

    # ----------------------------------------------------------------------------------
    # Sidebar boxes
    def _make_smiles_input_box(self, radio_key, textbox_key, textbox_label,
                               textbox_default, uploader_key, uploader_label):
        """
        Shared implementation of the three sidebar input boxes: a radio button lets
        the user choose between a text box and a CSV upload.

        Returns:
            The text box content (str), the uploaded CSV as a DataFrame, or None if
            "CSV upload" is selected but no file has been uploaded yet.
        """
        with st.container():
            input_choice = st.radio(
                "Input your data in SMILES notation via:",
                ["Text box", "CSV upload"],
                key=radio_key,
            )

            if input_choice == "Text box":
                return st.text_area(
                    label=textbox_label,
                    label_visibility="hidden",
                    key=textbox_key,
                    value=textbox_default,
                )

            # "CSV upload" is the only other radio option
            uploaded_file = st.file_uploader(
                key=uploader_key,
                label=uploader_label,
                label_visibility="hidden",
            )
            if uploaded_file is None:
                return None
            return pd.read_csv(uploaded_file)

    def _make_query_box(self):
        """
        This function
        a) defines the query box and
        b) stores the query input in the inputs dictionary
        """
        st.info(":blue[Molecules to predict:]", icon="β")

        # Update storage
        self.inputs["query"] = self._make_smiles_input_box(
            radio_key=None,  # the query radio button has no explicit key
            textbox_key="query_textbox",
            textbox_label="SMILES input for query molecules",
            textbox_default=("CC(C)Sc1nc(C(C)(C)C)nc(OCC(=O)O)c1C#N, "
                             "Cc1nc(NCc2cccnc2)cc(=O)n1CC(=O)O"),
            uploader_key="query_csv",
            uploader_label="CSV upload for query mols",
        )

    def _make_active_support_set_box(self):
        """
        This function
        a) defines the active support set box and
        b) stores the active support set input in the inputs dictionary
        """
        st.info(":blue[Known active molecules:]", icon="β¨")

        # Update storage
        self.inputs["support_active"] = self._make_smiles_input_box(
            radio_key="active_input_choice",
            textbox_key="active_textbox",
            textbox_label="SMILES input for active support set molecules",
            textbox_default=("Cc1nc(NCC2CCCCC2)c(C#N)c(=O)n1CC(=O)O, "
                             "CSc1nc(C(C)C)nc(OCC(=O)O)c1C#N"),
            uploader_key="support_active_csv",
            uploader_label="CSV upload for active support set molecules",
        )

    def _make_inactive_support_set_box(self):
        """
        This function
        a) defines the inactive support set box and
        b) stores the inactive support set input in the inputs dictionary
        """
        st.info(":blue[Known inactive molecules:]", icon="β¨")

        # Update storage
        self.inputs["support_inactive"] = self._make_smiles_input_box(
            radio_key="inactive_input_choice",
            textbox_key="inactive_textbox",
            textbox_label="SMILES input for inactive support set molecules",
            textbox_default=("CSc1nc(C)nc(OCC(=O)O)c1C#N, "
                             "CSc1nc(C)n(CC(=O)O)c(=O)c1C#N"),
            uploader_key="support_inactive_csv",
            uploader_label="CSV upload for inactive support set molecules",
        )

    # ----------------------------------------------------------------------------------
    # Tabs
    def _fill_tab_with_results_content(self, predictor, inputs, buttons,
                                       create_prediction_df, create_molecule_grid_plot):
        """
        This function fills the "Predictions" tab: it always shows the summary and,
        depending on the pressed button, either runs the prediction workflow or
        resets the session state.
        """
        with st.container():
            # Info
            st.info(":blue[Summary:]", icon="π")
            st.markdown(self.summary_text)

            # Results
            st.info(":blue[Results:]", icon="π¨βπ»")

            if buttons["predict"]:
                self._show_predictions(predictor,
                                       inputs,
                                       create_prediction_df,
                                       create_molecule_grid_plot)
            elif buttons["reset"]:
                self._reset()

    def _show_predictions(self, predictor, inputs, create_prediction_df,
                          create_molecule_grid_plot):
        """
        Validates the inputs and, if they are fine, shows the prediction table, the
        download buttons, and the molecule grid plot. Validation problems are
        reported via `st.error` and end the function early.
        """
        # Check 1: Are all inputs provided?
        required_inputs = ("query", "support_active", "support_inactive")
        if any(inputs[name] is None for name in required_inputs):
            st.error("You didn't provide all necessary inputs.\n\n"
                     "Please provide all three necessary inputs via the "
                     "sidebar and hit the predict button again.")
            return

        # Check 2: Less than max allowed molecules provided?
        largest_input = 0
        for name, raw_input in inputs.items():
            molecules = handle_inputs(raw_input)
            self.inputs_lists[name] = molecules
            largest_input = max(largest_input, len(molecules))

        if largest_input > self.max_input_length:
            st.error("You provided too many molecules. The number of "
                     "molecules for each input is restricted to "
                     f"{self.max_input_length}.\n\n"
                     "For larger screenings, we suggest to clone the repo "
                     "and to run the model locally.")
            return

        # Progress bar
        progress_bar = st.progress(
            50,
            text=("I'm predicting activities. This might "
                  "need some minutes. Please wait..."),
        )

        # Results table
        df = self._predict_and_create_results_table(predictor,
                                                    inputs,
                                                    create_prediction_df)
        progress_bar.progress(100, text="Done. Here are the results:")
        st.dataframe(df, use_container_width=True)

        # Only the two middle columns are used so that the buttons appear centered
        _, col_predictions, col_inputs, _ = st.columns([1, 1, 1, 1])

        # Provide download button for predictions
        with col_predictions:
            self.buttons["download_results"] = st.download_button(
                "Download predictions as CSV",
                self._convert_df_to_binary(df),
                file_name="predictions.csv",
            )

        # Provide download button for inputs. The YML content is handed over
        # in-memory; no file has to be written on the server for this.
        with col_inputs:
            self.buttons["download_inputs"] = st.download_button(
                "Download inputs as YML",
                self._convert_to_yml(self.inputs_lists),
                file_name="inputs.yml",
            )

        st.divider()

        # Results grid
        st.info(":blue[Grid plot of the predicted molecules:]", icon="π")
        mol_html_grid = create_molecule_grid_plot(df)
        components.html(mol_html_grid, height=1000, scrolling=True)

    def _fill_paper_and_citation_tab(self):
        """
        This function fills the "Paper / Cite" tab with the paper summary, the model
        overview image, and the citation text.
        """
        st.info(":blue[**Paper: Context-enriched molecule representations improve "
                "few-shot drug discovery**]", icon="π")
        st.markdown(self.mhnfs_text, unsafe_allow_html=True)
        st.image("./assets/mhnfs_overview.png")
        self._add_spacing(3)

        st.info(":blue[**Cite us / BibTex**]", icon="π")
        st.markdown(self.citation_text)

    def _fill_more_explanations_tab(self):
        """
        This function fills the "Additional Information" tab with its five text
        sections.
        """
        st.info(":blue[**Under the hood**]", icon="βοΈ")
        st.markdown(self.under_the_hood_text, unsafe_allow_html=True)
        self._add_spacing(2)

        st.info(":blue[**About few-shot learning and the model MHNfs**]", icon="π―")
        st.markdown(self.few_shot_learning_text, unsafe_allow_html=True)
        self._add_spacing(2)

        st.info(":blue[**Usage**]", icon="ποΈ")
        st.markdown(self.usage_text, unsafe_allow_html=True)
        self._add_spacing(2)

        st.info(":blue[**How to provide the data**]", icon="π")
        st.markdown(self.data_text, unsafe_allow_html=True)
        self._add_spacing(2)

        st.info(":blue[**When to trust the predictions**]", icon="π")
        st.markdown(self.trust_text, unsafe_allow_html=True)

    def _fill_examples_tab(self):
        """
        This function fills the "Examples" tab with one trustworthy and one not
        trustworthy example, each consisting of a text, a table, and a plot.
        """
        st.info(":blue[**Example for trustworthy predictions**]", icon="β ")
        st.markdown(self.example_trustworthy_text, unsafe_allow_html=True)
        st.dataframe(self.df_trustworthy, use_container_width=True)
        st.markdown("**Plot: Predictions for active and inactive molecules (model AUC="
                    "0.96**)")
        st.image(Image.open(
            "./assets/example_csv/predictions/trustworthy_example.png"
        ))
        self._add_spacing(2)

        st.info(":blue[**Example for not trustworthy predictions**]", icon="βοΈ")
        st.markdown(self.example_nottrustworthy_text, unsafe_allow_html=True)
        st.dataframe(self.df_nottrustworthy, use_container_width=True)
        st.markdown("**Plot: Predictions for active and inactive molecules (model AUC="
                    "0.42**)")
        st.image(Image.open(
            "./assets/example_csv/predictions/nottrustworthy_example.png"
        ))

    # ----------------------------------------------------------------------------------
    # Helpers
    def _add_spacing(self, n_lines):
        """
        Writes `n_lines` empty strings to the page, which act as vertical spacing.
        """
        for _ in range(n_lines):
            st.write("")

    def _predict_and_create_results_table(self,
                                          predictor,
                                          inputs,
                                          create_prediction_df: Callable):
        """
        Calls `create_prediction_df` with the predictor and the three inputs (query,
        active support set, inactive support set) and returns its result unchanged.
        """
        return create_prediction_df(predictor,
                                    inputs["query"],
                                    inputs["support_active"],
                                    inputs["support_inactive"])

    def _reset(self):
        """
        Removes every entry from the streamlit session state.
        """
        # Copy the keys first: the session state is modified while looping
        for key in list(st.session_state.keys()):
            st.session_state.pop(key)

    # NOTE(review): `_self` instead of `self` is kept from the original signatures;
    # presumably a leftover from streamlit caching decorators - confirm.
    def _convert_df_to_binary(_self, df):
        """
        Returns the data frame as UTF-8 encoded CSV bytes (without the index column).
        """
        return df.to_csv(index=False).encode("utf-8")

    def _convert_to_yml(_self, inputs):
        """
        Returns the inputs serialized as a YAML string.
        """
        return yaml.dump(inputs)