# Hugging Face file-viewer residue (not part of the program), kept for provenance:
#   gabcares's picture | Streamlit client source code | bfc7b8a verified | raw | history blame | 18.6 kB
import os
import time
import httpx
import string
import random
import datetime as dt
from dotenv import load_dotenv
import streamlit as st
import extra_streamlit_components as stx
import asyncio
from aiocache import cached, Cache
import pandas as pd
from typing import Optional, Callable
from config import ENV_PATH, BEST_MODELS, TEST_FILE, TEST_FILE_URL, HISTORY_FILE, markdown_table_all
from utils.navigation import navigation
from utils.footer import footer
from utils.janitor import Janitor
# Load environment variables from ENV_PATH (provides API_URL, read later via os.getenv)
load_dotenv(ENV_PATH)

# Set page configuration
st.set_page_config(
    page_title="Homepage",
    page_icon="🤖",  # restored from a mis-encoded byte sequence
    layout="wide",
    initial_sidebar_state='auto'
)
@cached(ttl=10, cache=Cache.MEMORY, namespace='streamlit_savedataset')
# NOTE: st.cache_data(show_spinner="Saving datasets...") is not used here because
# Streamlit's cache is yet to support async functions.
async def save_dataset(df: pd.DataFrame, filepath, csv=True) -> None:
    """Write ``df`` to ``filepath`` unless an identical copy is already stored there.

    Args:
        df: Frame to persist.
        filepath: Destination path; also read back to detect changes.
        csv: When true the file is handled as CSV, otherwise as Excel.
    """
    if os.path.isfile(filepath):
        # Compare against what is on disk and skip the write when nothing changed
        on_disk = pd.read_csv(filepath) if csv else pd.read_excel(filepath)
        if df.equals(on_disk):
            return

    # Either the file does not exist yet or its content differs
    if csv:
        df.to_csv(filepath, index=False)
    else:
        df.to_excel(filepath, index=False)
@cached(ttl=10, cache=Cache.MEMORY, namespace='streamlit_testdata')
async def get_test_data():
    """Return the test dataset as a ``(raw, cleaned)`` pair of DataFrames.

    The CSV at ``TEST_FILE_URL`` is tried first and, once read, mirrored to
    ``TEST_FILE`` through ``save_dataset``. If anything in that step raises,
    the local ``TEST_FILE`` copy is read instead.
    """
    try:
        raw = pd.read_csv(TEST_FILE_URL)
        await save_dataset(raw, TEST_FILE, csv=True)
    except Exception:
        # Best-effort fallback to the local copy
        raw = pd.read_csv(TEST_FILE)

    # Some house keeping: clean a copy so the raw frame stays untouched
    cleaned = Janitor().clean_dataframe(raw.copy())
    return raw, cleaned
# Model picker
async def select_model() -> str:
    """Render the model select box in the left of two columns and return the choice.

    The options come from ``BEST_MODELS``; the widget value is also available
    under ``st.session_state['selected_model']`` through its key.
    """
    left_column = st.columns(2)[0]
    with left_column:
        choice = st.selectbox(
            'Select a model', options=BEST_MODELS, key='selected_model')
    return choice
async def endpoint(model: str) -> str:
    """Return the prediction URL for ``model``.

    The URL is the ``API_URL`` environment variable followed by ``=`` and the
    model name.
    """
    base_url = os.getenv("API_URL")
    return "{}={}".format(base_url, model)
# Function for making a single prediction
async def make_prediction(model_endpoint) -> Optional[pd.DataFrame]:
    """Request a prediction for one patient and record the result.

    The patient row is chosen from ``st.session_state``:

    * search mode (``search_patient`` on and a ``search_patient_id`` chosen):
      the matching row(s) of the cleaned test data;
    * manual mode (neither search value set and ``manual_patient_id`` present):
      the values of the manual form widgets.

    Otherwise a toast and a warning are shown and no request is sent.

    On a 200 response the first prediction and probability are stored in
    session state, added to the frame together with the time and model used,
    and the frame is appended to ``HISTORY_FILE``.

    Args:
        model_endpoint: URL the patient data is POSTed to as JSON.

    Returns:
        The patient frame (with result columns when the request succeeded),
        or ``None`` when no patient was supplied.
    """
    df: Optional[pd.DataFrame] = None

    search_patient = st.session_state.get('search_patient', False)
    search_patient_id = st.session_state.get('search_patient_id', False)
    manual_patient_id = st.session_state.get('manual_patient_id', False)

    # A non-empty string id is wrapped in a list before it is passed to isin()
    if isinstance(search_patient_id, str) and search_patient_id:
        search_patient_id = [search_patient_id]

    if search_patient and search_patient_id:  # Search form: a patient was selected
        # Only this branch needs the test data, so it is fetched here
        _, df_test = await get_test_data()
        mask = df_test['id'].isin(search_patient_id)
        df = df_test[mask].copy()
    elif not (search_patient or search_patient_id) and manual_patient_id:  # Manual form
        columns = ['manual_patient_id', 'prg', 'pl', 'pr', 'sk',
                   'ts', 'm11', 'bd2', 'age', 'insurance']
        data = {c: [st.session_state.get(c)] for c in columns}
        # The form offers 'Yes'/'No'; the frame carries 1/0
        data['insurance'] = [1 if i == 'Yes' else 0 for i in data['insurance']]

        df = pd.DataFrame(data).rename(columns={'manual_patient_id': 'id'})

        columns_int = ['prg', 'pl', 'pr', 'sk', 'ts', 'age']
        columns_float = ['m11', 'bd2']
        df[columns_int] = df[columns_int].astype(int)
        df[columns_float] = df[columns_float].astype(float)
    else:  # Form did not send a patient
        message = 'You must choose valid patient(s) from the select box.'
        icon = '😞'
        st.toast(message, icon=icon)
        st.warning(message, icon=icon)

    if df is None:
        return df

    try:
        # Column-oriented JSON payload: {column: [values]}
        data = df.to_dict(orient='list')

        async with httpx.AsyncClient() as client:
            response = await client.post(model_endpoint, json=data, timeout=30)
            response.raise_for_status()  # Ensure we catch any HTTP errors

            if response.status_code == 200:
                pred_prob = response.json()['result']
                prediction = pred_prob['prediction'][0]
                probability = pred_prob['probability'][0]

                # Store results in session state
                st.session_state['prediction'] = prediction
                st.session_state['probability'] = probability

                df['prediction'] = prediction
                df['probability (%)'] = probability
                df['time_of_prediction'] = pd.Timestamp(dt.datetime.now())
                df['model_used'] = st.session_state['selected_model']

                # Append to the history; the header is written only on first creation
                df.to_csv(HISTORY_FILE, mode='a',
                          header=not os.path.isfile(HISTORY_FILE))
    except Exception as e:
        st.error(f'😞 Unable to connect to the API server. {e}')

    return df
async def convert_string(df: pd.DataFrame, string: str) -> str:
    """Return ``string`` upper-cased when every column name of ``df`` is upper case.

    If any column name is not upper case, ``string`` is returned unchanged.
    """
    for column in df.columns:
        if not column.isupper():
            return string
    return string.upper()
async def make_predictions(model_endpoint, df_uploaded=None, df_uploaded_clean=None) -> Optional[pd.DataFrame]:
    """Request predictions for several patients at once.

    The input frame is chosen from ``st.session_state``:

    * search mode (``search_patient`` on and ``patient_id_bulk`` non-empty):
      the matching rows of the cleaned test data;
    * upload mode (neither search value set and ``upload_bulk_predict`` set):
      a copy of ``df_uploaded_clean``.

    Otherwise a toast and a warning are shown and no request is sent.

    On a 200 response the result columns are added to the search frame (which
    is also appended to ``HISTORY_FILE``) or to the raw ``df_uploaded`` frame,
    and the resulting frame is stored in
    ``st.session_state['bulk_prediction_df']``.

    Args:
        model_endpoint: URL the patient data is POSTed to as JSON.
        df_uploaded: Uploaded frame as read from the file (gets the result columns).
        df_uploaded_clean: Cleaned version of the upload (sent to the API).

    Returns:
        The frame with predictions, the request frame when the call failed,
        or ``None`` when no input was supplied.
    """
    df: Optional[pd.DataFrame] = None

    search_patient = st.session_state.get('search_patient', False)
    patient_id_bulk = st.session_state.get('patient_id_bulk', False)
    upload_bulk_predict = st.session_state.get('upload_bulk_predict', False)

    if search_patient and patient_id_bulk:  # Search form: patients were selected
        _, df_test = await get_test_data()
        mask = df_test['id'].isin(patient_id_bulk)
        df = df_test[mask].copy()
    elif not (search_patient or patient_id_bulk) and upload_bulk_predict:  # Upload widget
        df = df_uploaded_clean.copy()
    else:  # Form did not send a patient
        message = 'You must choose valid patient(s) from the select box.'
        icon = '😞'
        st.toast(message, icon=icon)
        st.warning(message, icon=icon)

    if df is None:  # df should be set by form input or upload widget
        return df

    async def add_columns(frame, predictions, probabilities):
        # Column names follow the casing of the frame's existing columns
        frame[await convert_string(frame, 'sepsis')] = predictions
        frame[await convert_string(frame, 'probability_(%)')] = probabilities
        frame[await convert_string(frame, 'time_of_prediction')
              ] = pd.Timestamp(dt.datetime.now())
        frame[await convert_string(frame, 'model_used')
              ] = st.session_state['selected_model']
        return frame

    try:
        # Column-oriented JSON payload: {column: [values]}
        data = df.to_dict(orient='list')

        async with httpx.AsyncClient() as client:
            response = await client.post(model_endpoint, json=data, timeout=30)
            response.raise_for_status()  # Ensure we catch any HTTP errors

            if response.status_code == 200:
                pred_prob = response.json()['result']
                predictions = pred_prob['prediction']
                probabilities = pred_prob['probability']

                if search_patient:
                    df = await add_columns(df, predictions, probabilities)
                    # Save only known patients; header only on first creation
                    df.to_csv(HISTORY_FILE, mode='a',
                              header=not os.path.isfile(HISTORY_FILE))
                else:
                    # Results go on the raw upload, not the cleaned copy
                    df = await add_columns(df_uploaded, predictions, probabilities)

                # Store df with prediction results in session state
                st.session_state['bulk_prediction_df'] = df
    except Exception as e:
        st.error(f'😞 Unable to connect to the API server. {e}')

    return df
def on_click(func: Callable, model_endpoint: str):
    """Synchronously run ``await func(model_endpoint)`` on a fresh event loop.

    Used as a synchronous widget callback that drives an async handler.

    Args:
        func: Callable returning an awaitable when called with the endpoint.
        model_endpoint: Value forwarded to ``func``.

    Raises:
        Whatever ``func`` raises; the loop is closed before it propagates.
    """
    async def handle_click():
        await func(model_endpoint)

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(handle_click())
    finally:
        # Close the loop even when the handler raises, so it is never leaked
        loop.close()
async def search_patient_form(model_endpoint: str) -> None:
    """Render the patient search form for the active prediction mode.

    When ``st.session_state['sidebar']`` is ``'single_prediction'`` a select
    box (key ``search_patient_id``) is shown and submitting runs
    ``make_prediction``. Otherwise a multiselect (key ``patient_id_bulk``) is
    shown and submitting runs ``make_predictions``.

    Args:
        model_endpoint: URL handed to the prediction function on submit.
    """
    _, df_test = await get_test_data()
    known_ids = df_test['id'].unique().tolist()

    if st.session_state['sidebar'] == 'single_prediction':
        # A trailing blank entry is preselected so no patient is chosen by default
        patient_ids = known_ids + ['']
        with st.form('search_patient_id_form'):
            col1, _ = st.columns(2)
            with col1:
                st.write('#### Patient ID 🤒')
                st.selectbox(
                    'Search a patient', options=patient_ids, index=len(patient_ids)-1, key='search_patient_id')
            st.form_submit_button('Predict', type='primary', on_click=on_click, kwargs=dict(
                func=make_prediction, model_endpoint=model_endpoint))
    else:
        with st.form('search_patient_id_bulk_form'):
            col1, _ = st.columns(2)
            with col1:
                st.write('#### Patient ID 🤒')
                # The multiselect starts empty, so the blank placeholder is not
                # offered here: picking it would select no rows at all
                st.multiselect(
                    'Search a patient', options=known_ids, default=None, key='patient_id_bulk')
            st.form_submit_button('Predict', type='primary', on_click=on_click, kwargs=dict(
                func=make_predictions, model_endpoint=model_endpoint))
async def gen_random_patient_id() -> str:
    """Return a random id shaped ``ICU<6 digits>-gen-<4 lowercase letters>``."""
    digit_part = ''.join(random.choices(string.digits, k=6))
    letter_part = ''.join(random.choices(string.ascii_lowercase, k=4))
    return "ICU" + digit_part + "-gen-" + letter_part
async def manual_patient_form(model_endpoint) -> None:
    """Render the manual entry form for a single patient.

    The widget keys (``manual_patient_id``, ``age``, ``insurance``, ``m11``,
    ``pr``, ``prg``, ``pl``, ``sk``, ``ts``, ``bd2``) are the ones
    ``make_prediction`` reads back from ``st.session_state``. Submitting the
    form runs ``make_prediction`` through ``on_click``.

    Args:
        model_endpoint: URL handed to ``make_prediction`` on submit.
    """
    def measurement(label: str, key: str) -> None:
        # Every measurement widget shares the same numeric settings
        st.number_input(label, min_value=10.0,
                        format="%.2f", step=1.00, key=key)

    with st.form('manual_patient_form'):
        col1, col2, col3 = st.columns(3)

        with col1:
            st.write('### Patient Demographics 🛌')
            # The ID field is prefilled with a freshly generated random id
            st.text_input(
                'ID', value=await gen_random_patient_id(), key='manual_patient_id')
            st.number_input('Age: patients age (years)', min_value=0,
                            max_value=100, step=1, key='age')
            st.selectbox('Insurance: If a patient holds a valid insurance card', options=[
                'Yes', 'No'], key='insurance')

        with col2:
            st.write('### Vital Signs 🩺')
            measurement('BMI (weight in kg/(height in m)^2', 'm11')
            measurement('Blood Pressure (mm Hg)', 'pr')
            measurement('PRG (plasma glucose)', 'prg')

        with col3:
            st.write('### Blood Work 💉')
            measurement('PL: Blood Work Result-1 (mu U/ml)', 'pl')
            measurement('SK: Blood Work Result 2 (mm)', 'sk')
            measurement('TS: Blood Work Result-3 (mu U/ml)', 'ts')
            measurement('BD2: Blood Work Result-4 (mu U/ml)', 'bd2')

        st.form_submit_button('Predict', type='primary', on_click=on_click, kwargs=dict(
            func=make_prediction, model_endpoint=model_endpoint))
async def do_single_prediction(model_endpoint: str) -> None:
    """Show the search form when the ``search_patient`` toggle is on, else the manual form."""
    searching = st.session_state.get('search_patient', False)
    render_form = search_patient_form if searching else manual_patient_form
    await render_form(model_endpoint)
async def show_prediction() -> None:
    """Display the single prediction held in session state, then clear it.

    Reads ``prediction`` and ``probability`` from ``st.session_state``. With
    no prediction a placeholder heading is shown. Otherwise the result is
    shown as a warning (``'positive'``, compared case-insensitively) or a
    success message, repeated as a toast after five seconds, and both session
    values are reset to ``None``.
    """
    final_prediction = st.session_state.get('prediction', None)
    final_probability = st.session_state.get('probability', None)

    if final_prediction is None:
        st.markdown('#### Prediction will show below! 🔬')
        st.divider()
        return

    st.markdown('#### Prediction! 🔬')
    st.divider()

    if final_prediction.lower() == 'positive':
        st.toast("Sepsis alert!", icon='🦠')
        message = f"It is **{final_probability:.2f} %** likely that the patient will develop **sepsis.**"
        st.warning(message, icon='😞')
    else:
        st.toast("Continous monitoring", icon='🔬')
        message = f"The patient will **not** develop sepsis with a likelihood of **{final_probability:.2f}%**."
        st.success(message, icon='😊')

    # Non-blocking pause (this is a coroutine) before repeating the message as a toast
    await asyncio.sleep(5)
    st.toast(message)

    # Set prediction and probability to None
    st.session_state['prediction'] = None
    st.session_state['probability'] = None
# NOTE: st.cache_data(show_spinner=False) is not applied: caching results from
# async functions is buggy.
async def convert_df(df: pd.DataFrame):
    """Serialise ``df`` to CSV text, leaving out the index column."""
    csv_text = df.to_csv(index=False)
    return csv_text
async def bulk_upload_widget(model_endpoint: str) -> Optional[pd.DataFrame]:
    """Render the file uploader and run bulk predictions on a valid upload.

    The upload is read as CSV and, if that fails, as Excel. Its column set and
    dtypes must equal those of the raw test data; it is then cleaned with
    ``Janitor`` and sent through ``make_predictions``. Any failure in these
    steps shows a three-row data template (with a download button and the data
    dictionary) together with a warning.

    Args:
        model_endpoint: URL handed to ``make_predictions``.

    Returns:
        The frame returned by ``make_predictions``, or ``None`` when nothing
        was submitted or the upload was rejected.
    """
    uploaded_file = st.file_uploader(
        "Choose a CSV or Excel File", type=['csv', 'xls', 'xlsx'])

    uploaded = uploaded_file is not None

    upload_bulk_predict = st.button('Predict', type='primary',
                                    help='Upload a csv/excel file to make predictions', disabled=not uploaded, key='upload_bulk_predict')

    df = None
    if not (upload_bulk_predict and uploaded):
        return df

    df_test_raw, _ = await get_test_data()

    try:
        # The uploaded file is a file-like object, accepted by the pandas readers
        try:
            df = pd.read_csv(uploaded_file)
        except Exception:
            # The CSV attempt may have advanced the stream; start over for Excel
            uploaded_file.seek(0)
            df = pd.read_excel(uploaded_file)

        same_columns = set(df.columns) == set(df_test_raw.columns)
        if not same_columns or not df.dtypes.equals(df_test_raw.dtypes):
            raise ValueError('Uploaded file does not match the data template.')

        # Clean dataframe
        df_clean = Janitor().clean_dataframe(df)
        df = await make_predictions(
            model_endpoint, df_uploaded=df, df_uploaded_clean=df_clean)
    except Exception:
        # A rejected upload must not be returned as if it were a result
        df = None

        st.subheader('Data template')
        data_template = df_test_raw[:3]
        st.dataframe(data_template)
        csv = await convert_df(data_template)
        message_1 = 'Upload a valid csv or excel file.'
        message_2 = f"{message_1.split('.')[0]} with the columns and schema of the above data template."
        icon = '😞'
        st.toast(message_1, icon=icon)

        st.download_button(
            label='Download template',
            data=csv,
            file_name='Data template.csv',
            mime="text/csv",
            type='secondary',
            key='download-data-template'
        )
        st.info('Download the above template for use as a baseline structure.')

        # Display expander to show the data dictionary
        with st.expander("Expand to see the data dictionary", icon="💡"):
            st.subheader("Data dictionary")
            st.markdown(markdown_table_all)

        st.warning(message_2, icon=icon)

    return df
async def do_bulk_prediction(model_endpoint: str) -> None:
    """Show the bulk search form when ``search_patient`` is on, else the upload widget.

    Nothing is returned; ``make_predictions`` stores its result under
    ``st.session_state['bulk_prediction_df']``.
    """
    if not st.session_state.get('search_patient', False):
        # File uploader
        await bulk_upload_widget(model_endpoint)
        return

    await search_patient_form(model_endpoint)
async def show_bulk_predictions(df: pd.DataFrame) -> None:
    """Display bulk prediction results with a CSV download, then clear them.

    Does nothing when ``df`` is ``None``. Otherwise the frame is shown with all
    values rendered as strings, offered for download as
    ``Bulk prediction.csv``, and ``st.session_state['bulk_prediction_df']`` is
    reset to ``None``.

    Args:
        df: Frame holding the predictions, or ``None``.
    """
    if df is None:
        return

    st.subheader("Bulk predictions 🔮", divider=True)
    st.dataframe(df.astype(str))

    csv = await convert_df(df)
    message = 'The predictions are ready for download.'
    icon = '⬇️'
    st.toast(message, icon=icon)
    st.info(message, icon=icon)
    st.download_button(
        label='Download predictions',
        data=csv,
        file_name='Bulk prediction.csv',
        mime="text/csv",
        type='secondary',
        key='download-bulk-prediction'
    )

    # Set bulk prediction df to None
    st.session_state['bulk_prediction_df'] = None
async def sidebar(sidebar_type: str) -> None:
    """Record the active prediction mode under ``st.session_state['sidebar']``.

    Args:
        sidebar_type: Mode name; ``main`` passes ``'single_prediction'`` or
            ``'bulk_prediction'``.
    """
    st.session_state['sidebar'] = sidebar_type
async def main():
    """Render the prediction page: navigation, model picker, tabs and results.

    Tab ``'1'`` shows the single-prediction form and its result; tab ``'2'``
    shows the bulk form or uploader and the bulk results, which are taken from
    ``st.session_state['bulk_prediction_df']`` when ``do_bulk_prediction``
    returns nothing.
    """
    st.title("🤖 Predict Sepsis 🦠")

    # Navigation
    await navigation()

    st.sidebar.toggle("Looking for a patient?", value=st.session_state.get(
        'search_patient', False), key='search_patient')

    selected_model = await select_model()
    model_endpoint = await endpoint(selected_model)

    selected_predict_tab = st.session_state.get('selected_predict_tab')
    default = 1 if selected_predict_tab is None else selected_predict_tab

    with st.spinner('A little house keeping...'):
        # Pause only while 'sleep' is unset; it is zeroed right after the tab bar
        # is built. Non-blocking because this is a coroutine.
        await asyncio.sleep(st.session_state.get('sleep', 1.5))
        chosen_id = stx.tab_bar(data=[
            stx.TabBarItemData(id=1, title='🔬 Predict', description=''),
            stx.TabBarItemData(id=2, title='🔮 Bulk predict',
                               description=''),
        ], default=default)
        st.session_state['sleep'] = 0

    # The chosen tab id is compared as a string
    if chosen_id == '1':
        await sidebar('single_prediction')
        await do_single_prediction(model_endpoint)
        await show_prediction()
    elif chosen_id == '2':
        await sidebar('bulk_prediction')
        df_with_predictions = await do_bulk_prediction(model_endpoint)
        if df_with_predictions is None:
            df_with_predictions = st.session_state.get(
                'bulk_prediction_df', None)
        await show_bulk_predictions(df_with_predictions)

    # Add footer
    await footer()
# Run the async page entry point when this file is executed as the main script
if __name__ == "__main__":
    asyncio.run(main())