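"""Streamlit page for sepsis prediction.

Lets a user select a model, then make single predictions (from a patient search
or a manual entry form) or bulk predictions (from searched patients or an
uploaded CSV/Excel file) against a prediction API, and offers the results for
download.
"""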
import os
import time
import httpx
import string
import random
import datetime as dt
from dotenv import load_dotenv
import streamlit as st
import extra_streamlit_components as stx
import asyncio
from aiocache import cached, Cache
import pandas as pd
from typing import Optional, Callable
from config import ENV_PATH, BEST_MODELS, TEST_FILE, TEST_FILE_URL, HISTORY_FILE, markdown_table_all
from utils.navigation import navigation
from utils.footer import footer
from utils.janitor import Janitor
# Load ENV
load_dotenv(ENV_PATH)  # API_URL

# Set page configuration
st.set_page_config(
    page_title="Homepage",
    page_icon="๐ค",
    layout="wide",
    initial_sidebar_state='auto'
)
# @st.cache_data(show_spinner="Saving datasets...")  # Streamlit cache does not yet support async functions
async def save_dataset(df: pd.DataFrame, filepath, csv=True) -> None:
    async def save(df: pd.DataFrame, file):
        return df.to_csv(file, index=False) if csv else df.to_excel(file, index=False)

    async def read(file):
        return pd.read_csv(file) if csv else pd.read_excel(file)

    async def same_dfs(df: pd.DataFrame, df2: pd.DataFrame):
        return df.equals(df2)

    if not os.path.isfile(filepath):  # Save if the file does not exist
        await save(df, filepath)
    else:  # Save only if the data have changed
        df_old = await read(filepath)
        if not await same_dfs(df, df_old):
            await save(df, filepath)
async def get_test_data():
    try:
        df_test_raw = pd.read_csv(TEST_FILE_URL)
        await save_dataset(df_test_raw, TEST_FILE, csv=True)
    except Exception:
        df_test_raw = pd.read_csv(TEST_FILE)

    # Some housekeeping: clean the df
    df_test = df_test_raw.copy()
    janitor = Janitor()
    df_test = janitor.clean_dataframe(df_test)  # Cleaned
    return df_test_raw, df_test
# Function for selecting models
async def select_model() -> str:
    col1, _ = st.columns(2)
    with col1:
        selected_model = st.selectbox(
            'Select a model', options=BEST_MODELS, key='selected_model')
    return selected_model
async def endpoint(model: str) -> str:
    api_url = os.getenv("API_URL")
    model_endpoint = f"{api_url}={model}"
    return model_endpoint
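# NOTE (assumption): API_URL is expected to already end with the query-parameter
# name (e.g. "https://<host>/predict?model"), so appending "=<model>" above
# yields a complete endpoint such as "https://<host>/predict?model=<model>".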
# Function for making a single prediction
async def make_prediction(model_endpoint) -> Optional[pd.DataFrame]:
    test_data = await get_test_data()
    _, df_test = test_data

    df: Optional[pd.DataFrame] = None
    search_patient = st.session_state.get('search_patient', False)
    search_patient_id = st.session_state.get('search_patient_id', False)
    manual_patient_id = st.session_state.get('manual_patient_id', False)
    if isinstance(search_patient_id, str) and search_patient_id:  # And not an empty string
        search_patient_id = [search_patient_id]
    if search_patient and search_patient_id:  # Search form df and a patient was selected
        mask = df_test['id'].isin(search_patient_id)
        df_form = df_test[mask]
        df = df_form.copy()
    elif not (search_patient or search_patient_id) and manual_patient_id:  # Manual form df
        columns = ['manual_patient_id', 'prg', 'pl', 'pr', 'sk',
                   'ts', 'm11', 'bd2', 'age', 'insurance']
        data = {c: [st.session_state.get(c)] for c in columns}
        data['insurance'] = [1 if i == 'Yes' else 0 for i in data['insurance']]

        # Make a DataFrame
        df = pd.DataFrame(data).rename(
            columns={'manual_patient_id': 'id'})
        columns_int = ['prg', 'pl', 'pr', 'sk', 'ts', 'age']
        columns_float = ['m11', 'bd2']
        df[columns_int] = df[columns_int].astype(int)
        df[columns_float] = df[columns_float].astype(float)
    else:  # Form did not send a patient
        message = 'You must choose valid patient(s) from the select box.'
        icon = '๐'
        st.toast(message, icon=icon)
        st.warning(message, icon=icon)

    if df is not None:
        try:
            # JSON data
            data = df.to_dict(orient='list')

            # Send POST request with JSON data using the json parameter
            async with httpx.AsyncClient() as client:
                response = await client.post(model_endpoint, json=data, timeout=30)
                response.raise_for_status()  # Ensure we catch any HTTP errors

            if response.status_code == 200:
                pred_prob = response.json()['result']
                prediction = pred_prob['prediction'][0]
                probability = pred_prob['probability'][0]

                # Store results in session state
                st.session_state['prediction'] = prediction
                st.session_state['probability'] = probability
                df['prediction'] = prediction
                df['probability (%)'] = probability
                df['time_of_prediction'] = pd.Timestamp(dt.datetime.now())
                df['model_used'] = st.session_state['selected_model']
                df.to_csv(HISTORY_FILE, mode='a',
                          header=not os.path.isfile(HISTORY_FILE))
        except Exception as e:
            st.error(f'๐ Unable to connect to the API server. {e}')
    return df
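# Both make_prediction (above) and make_predictions (below) assume the API
# returns a JSON body shaped roughly like the following; this is an assumption
# inferred from the keys they read, not a documented contract:
# {"result": {"prediction": ["Positive"], "probability": [87.5]}}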
async def convert_string(df: pd.DataFrame, text: str) -> str:
    return text.upper() if all(col.isupper() for col in df.columns) else text
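# e.g. if every column in df is uppercase, convert_string(df, 'sepsis') returns
# 'SEPSIS'; otherwise the text is returned unchanged, so the result columns
# added below match the letter case of the uploaded file's columns.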
async def make_predictions(model_endpoint, df_uploaded=None, df_uploaded_clean=None) -> Optional[pd.DataFrame]:
    df: Optional[pd.DataFrame] = None
    search_patient = st.session_state.get('search_patient', False)
    patient_id_bulk = st.session_state.get('patient_id_bulk', False)
    upload_bulk_predict = st.session_state.get('upload_bulk_predict', False)

    if search_patient and patient_id_bulk:  # Search form df and patients were selected
        _, df_test = await get_test_data()
        mask = df_test['id'].isin(patient_id_bulk)
        df_bulk: pd.DataFrame = df_test[mask]
        df = df_bulk.copy()
    elif not (search_patient or patient_id_bulk) and upload_bulk_predict:  # Upload widget df
        df = df_uploaded_clean.copy()
    else:  # Form did not send a patient
        message = 'You must choose valid patient(s) from the select box.'
        icon = '๐'
        st.toast(message, icon=icon)
        st.warning(message, icon=icon)

    if df is not None:  # df should be set by form input or upload widget
        try:
            # JSON data
            data = df.to_dict(orient='list')

            # Send POST request with JSON data using the json parameter
            async with httpx.AsyncClient() as client:
                response = await client.post(model_endpoint, json=data, timeout=30)
                response.raise_for_status()  # Ensure we catch any HTTP errors

            if response.status_code == 200:
                pred_prob = response.json()['result']
                predictions = pred_prob['prediction']
                probabilities = pred_prob['probability']

                # Add sepsis, probability, time, and model-used columns to the uploaded df or the form df
                async def add_columns(df):
                    df[await convert_string(df, 'sepsis')] = predictions
                    df[await convert_string(df, 'probability_(%)')] = probabilities
                    df[await convert_string(df, 'time_of_prediction')] = pd.Timestamp(dt.datetime.now())
                    df[await convert_string(df, 'model_used')] = st.session_state['selected_model']
                    return df

                # Form df if search patient is true, otherwise df from the uploaded data
                if search_patient:
                    df = await add_columns(df)
                    df.to_csv(HISTORY_FILE, mode='a',
                              header=not os.path.isfile(HISTORY_FILE))  # Save only known patients
                else:
                    df = await add_columns(df_uploaded)  # Raw, no cleaning

                # Store df with prediction results in session state
                st.session_state['bulk_prediction_df'] = df
        except Exception as e:
            st.error(f'๐ Unable to connect to the API server. {e}')
    return df
def on_click(func: Callable, model_endpoint: str):
    async def handle_click():
        await func(model_endpoint)

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(handle_click())
    loop.close()
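# Streamlit runs on_click callbacks synchronously, so on_click spins up a
# throwaway event loop to drive the async prediction coroutines to completion
# inside the callback.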
async def search_patient_form(model_endpoint: str) -> None:
    test_data = await get_test_data()
    _, df_test = test_data

    patient_ids = df_test['id'].unique().tolist() + ['']
    if st.session_state['sidebar'] == 'single_prediction':
        with st.form('search_patient_id_form'):
            col1, _ = st.columns(2)
            with col1:
                st.write('#### Patient ID ๐ค')
                st.selectbox(
                    'Search a patient', options=patient_ids, index=len(patient_ids) - 1, key='search_patient_id')
            st.form_submit_button('Predict', type='primary', on_click=on_click, kwargs=dict(
                func=make_prediction, model_endpoint=model_endpoint))
    else:
        with st.form('search_patient_id_bulk_form'):
            col1, _ = st.columns(2)
            with col1:
                st.write('#### Patient ID ๐ค')
                st.multiselect(
                    'Search a patient', options=patient_ids, default=None, key='patient_id_bulk')
            st.form_submit_button('Predict', type='primary', on_click=on_click, kwargs=dict(
                func=make_predictions, model_endpoint=model_endpoint))
async def gen_random_patient_id() -> str:
    numbers = ''.join(random.choices(string.digits, k=6))
    letters = ''.join(random.choices(string.ascii_lowercase, k=4))
    return f"ICU{numbers}-gen-{letters}"
async def manual_patient_form(model_endpoint) -> None:
    with st.form('manual_patient_form'):
        col1, col2, col3 = st.columns(3)

        with col1:
            st.write('### Patient Demographics ๐')
            st.text_input(
                'ID', value=await gen_random_patient_id(), key='manual_patient_id')
            st.number_input("Age: patient's age (years)", min_value=0,
                            max_value=100, step=1, key='age')
            st.selectbox('Insurance: whether the patient holds a valid insurance card', options=[
                'Yes', 'No'], key='insurance')

        with col2:
            st.write('### Vital Signs ๐ฉบ')
            st.number_input('BMI (weight in kg/(height in m)^2)', min_value=10.0,
                            format="%.2f", step=1.00, key='m11')
            st.number_input(
                'Blood Pressure (mm Hg)', min_value=10.0, format="%.2f", step=1.00, key='pr')
            st.number_input(
                'PRG (plasma glucose)', min_value=10.0, format="%.2f", step=1.00, key='prg')

        with col3:
            st.write('### Blood Work ๐')
            st.number_input(
                'PL: Blood Work Result-1 (mu U/ml)', min_value=10.0, format="%.2f", step=1.00, key='pl')
            st.number_input(
                'SK: Blood Work Result-2 (mm)', min_value=10.0, format="%.2f", step=1.00, key='sk')
            st.number_input(
                'TS: Blood Work Result-3 (mu U/ml)', min_value=10.0, format="%.2f", step=1.00, key='ts')
            st.number_input(
                'BD2: Blood Work Result-4 (mu U/ml)', min_value=10.0, format="%.2f", step=1.00, key='bd2')

        st.form_submit_button('Predict', type='primary', on_click=on_click, kwargs=dict(
            func=make_prediction, model_endpoint=model_endpoint))
async def do_single_prediction(model_endpoint: str) -> None:
    if st.session_state.get('search_patient', False):
        await search_patient_form(model_endpoint)
    else:
        await manual_patient_form(model_endpoint)
async def show_prediction() -> None:
    final_prediction = st.session_state.get('prediction', None)
    final_probability = st.session_state.get('probability', None)

    if final_prediction is None:
        st.markdown('#### Prediction will show below! ๐ฌ')
        st.divider()
    else:
        st.markdown('#### Prediction! ๐ฌ')
        st.divider()

        if final_prediction.lower() == 'positive':
            st.toast("Sepsis alert!", icon='๐ฆ ')
            message = f"It is **{final_probability:.2f}%** likely that the patient will develop **sepsis**."
            st.warning(message, icon='๐')
            time.sleep(5)
            st.toast(message)
        else:
            st.toast("Continuous monitoring", icon='๐ฌ')
            message = f"The patient will **not** develop sepsis, with a likelihood of **{final_probability:.2f}%**."
            st.success(message, icon='๐')
            time.sleep(5)
            st.toast(message)

    # Reset prediction and probability to None
    st.session_state['prediction'] = None
    st.session_state['probability'] = None
# @st.cache_data(show_spinner=False)  # Caching results from async functions is still buggy in Streamlit
async def convert_df(df: pd.DataFrame):
    return df.to_csv(index=False)
async def bulk_upload_widget(model_endpoint: str) -> Optional[pd.DataFrame]:
    uploaded_file = st.file_uploader(
        "Choose a CSV or Excel File", type=['csv', 'xls', 'xlsx'])

    uploaded = uploaded_file is not None
    upload_bulk_predict = st.button('Predict', type='primary',
                                    help='Upload a csv/excel file to make predictions', disabled=not uploaded, key='upload_bulk_predict')

    df = None
    if upload_bulk_predict and uploaded:
        df_test_raw, _ = await get_test_data()

        # The UploadedFile is a file-like object, so pandas can read it directly
        try:
            try:
                df = pd.read_csv(uploaded_file)
            except Exception:
                df = pd.read_excel(uploaded_file)

            df_columns = set(df.columns)
            df_test_columns = set(df_test_raw.columns)
            df_schema = df.dtypes
            df_test_schema = df_test_raw.dtypes

            if df_columns != df_test_columns or not df_schema.equals(df_test_schema):
                df = None
                raise Exception
            else:
                # Clean dataframe
                janitor = Janitor()
                df_clean = janitor.clean_dataframe(df)

                df = await make_predictions(
                    model_endpoint, df_uploaded=df, df_uploaded_clean=df_clean)

        except Exception:
            st.subheader('Data template')
            data_template = df_test_raw[:3]
            st.dataframe(data_template)

            csv = await convert_df(data_template)
            message_1 = 'Upload a valid csv or excel file.'
            message_2 = f"{message_1.split('.')[0]} with the columns and schema of the above data template."
            icon = '๐'
            st.toast(message_1, icon=icon)

            st.download_button(
                label='Download template',
                data=csv,
                file_name='Data template.csv',
                mime="text/csv",
                type='secondary',
                key='download-data-template'
            )
            st.info('Download the above template for use as a baseline structure.')

            # Display an expander to show the data dictionary
            with st.expander("Expand to see the data dictionary", icon="๐ก"):
                st.subheader("Data dictionary")
                st.markdown(markdown_table_all)
            st.warning(message_2, icon=icon)

    return df
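# If the uploaded file's columns or dtypes differ from the raw test-data
# template, the widget discards the upload and falls back to showing the
# template, a download button for it, and the data dictionary instead of
# calling the API.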
async def do_bulk_prediction(model_endpoint: str) -> Optional[pd.DataFrame]:
    if st.session_state.get('search_patient', False):
        await search_patient_form(model_endpoint)
        return None
    else:
        # File uploader
        return await bulk_upload_widget(model_endpoint)
async def show_bulk_predictions(df: pd.DataFrame) -> None:
    if df is not None:
        st.subheader("Bulk predictions ๐ฎ", divider=True)
        st.dataframe(df.astype(str))

        csv = await convert_df(df)
        message = 'The predictions are ready for download.'
        icon = 'โฌ๏ธ'
        st.toast(message, icon=icon)
        st.info(message, icon=icon)
        st.download_button(
            label='Download predictions',
            data=csv,
            file_name='Bulk prediction.csv',
            mime="text/csv",
            type='secondary',
            key='download-bulk-prediction'
        )

        # Set bulk prediction df to None
        st.session_state['bulk_prediction_df'] = None
async def sidebar(sidebar_type: str) -> None:
    st.session_state.update({'sidebar': sidebar_type})
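# The 'sidebar' flag set here is what search_patient_form reads to decide
# between the single-patient selectbox and the bulk multiselect.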
async def main():
    st.title("๐ค Predict Sepsis ๐ฆ ")

    # Navigation
    await navigation()

    st.sidebar.toggle("Looking for a patient?", value=st.session_state.get(
        'search_patient', False), key='search_patient')

    selected_model = await select_model()
    model_endpoint = await endpoint(selected_model)

    selected_predict_tab = st.session_state.get('selected_predict_tab')
    default = 1 if selected_predict_tab is None else selected_predict_tab

    with st.spinner('A little housekeeping...'):
        time.sleep(st.session_state.get('sleep', 1.5))
        chosen_id = stx.tab_bar(data=[
            stx.TabBarItemData(id=1, title='๐ฌ Predict', description=''),
            stx.TabBarItemData(id=2, title='๐ฎ Bulk predict',
                               description=''),
        ], default=default)
        st.session_state['sleep'] = 0

    if chosen_id == '1':
        await sidebar('single_prediction')
        await do_single_prediction(model_endpoint)
        await show_prediction()
    elif chosen_id == '2':
        await sidebar('bulk_prediction')
        df_with_predictions = await do_bulk_prediction(model_endpoint)
        if df_with_predictions is None:
            df_with_predictions = st.session_state.get(
                'bulk_prediction_df', None)
        await show_bulk_predictions(df_with_predictions)

    # Add footer
    await footer()
if __name__ == "__main__":
    asyncio.run(main())