gabcares committed on
Commit
e563ea1
·
verified ·
1 Parent(s): bfc7b8a

Update app.py

Browse files

**Update toast display to 1 sec**

Files changed (1) hide show
  1. app.py +484 -484
app.py CHANGED
@@ -1,484 +1,484 @@
1
- import os
2
- import time
3
- import httpx
4
- import string
5
- import random
6
- import datetime as dt
7
- from dotenv import load_dotenv
8
-
9
- import streamlit as st
10
- import extra_streamlit_components as stx
11
-
12
- import asyncio
13
- from aiocache import cached, Cache
14
-
15
- import pandas as pd
16
- from typing import Optional, Callable
17
-
18
- from config import ENV_PATH, BEST_MODELS, TEST_FILE, TEST_FILE_URL, HISTORY_FILE, markdown_table_all
19
-
20
- from utils.navigation import navigation
21
- from utils.footer import footer
22
- from utils.janitor import Janitor
23
-
24
-
25
- # Load ENV
26
- load_dotenv(ENV_PATH) # API_URL
27
-
28
- # Set page configuration
29
- st.set_page_config(
30
- page_title="Homepage",
31
- page_icon="๐Ÿค–",
32
- layout="wide",
33
- initial_sidebar_state='auto'
34
- )
35
-
36
-
37
- @cached(ttl=10, cache=Cache.MEMORY, namespace='streamlit_savedataset')
38
- # @st.cache_data(show_spinner="Saving datasets...") # Streamlit cache is yet to support async functions
39
- async def save_dataset(df: pd.DataFrame, filepath, csv=True) -> None:
40
- async def save(df: pd.DataFrame, file):
41
- return df.to_csv(file, index=False) if csv else df.to_excel(file, index=False)
42
-
43
- async def read(file):
44
- return pd.read_csv(file) if csv else pd.read_excel(file)
45
-
46
- async def same_dfs(df: pd.DataFrame, df2: pd.DataFrame):
47
- return df.equals(df2)
48
-
49
- if not os.path.isfile(filepath): # Save if file does not exists
50
- await save(df, filepath)
51
- else: # Save if data are not same
52
- df_old = await read(filepath)
53
- if not await same_dfs(df, df_old):
54
- await save(df, filepath)
55
-
56
-
57
- @cached(ttl=10, cache=Cache.MEMORY, namespace='streamlit_testdata')
58
- async def get_test_data():
59
- try:
60
- df_test_raw = pd.read_csv(TEST_FILE_URL)
61
- await save_dataset(df_test_raw, TEST_FILE, csv=True)
62
- except Exception:
63
- df_test_raw = pd.read_csv(TEST_FILE)
64
-
65
- # Some house keeping, clean df
66
- df_test = df_test_raw.copy()
67
- janitor = Janitor()
68
- df_test = janitor.clean_dataframe(df_test) # Cleaned
69
-
70
- return df_test_raw, df_test
71
-
72
-
73
- # Function for selecting models
74
- async def select_model() -> str:
75
- col1, _ = st.columns(2)
76
- with col1:
77
- selected_model = st.selectbox(
78
- 'Select a model', options=BEST_MODELS, key='selected_model')
79
-
80
- return selected_model
81
-
82
-
83
- async def endpoint(model: str) -> str:
84
- api_url = os.getenv("API_URL")
85
- model_endpoint = f"{api_url}={model}"
86
- return model_endpoint
87
-
88
-
89
- # Function for making prediction
90
- async def make_prediction(model_endpoint) -> Optional[pd.DataFrame]:
91
-
92
- test_data = await get_test_data()
93
- _, df_test = test_data
94
-
95
- df: pd.DataFrame = None
96
- search_patient = st.session_state.get('search_patient', False)
97
- search_patient_id = st.session_state.get('search_patient_id', False)
98
- manual_patient_id = st.session_state.get('manual_patient_id', False)
99
- if isinstance(search_patient_id, str) and search_patient_id: # And not empty string
100
- search_patient_id = [search_patient_id]
101
- if search_patient and search_patient_id: # Search Form df and a patient was selected
102
- mask = df_test['id'].isin(search_patient_id)
103
- df_form = df_test[mask]
104
- df = df_form.copy()
105
- elif not (search_patient or search_patient_id) and manual_patient_id: # Manual form df
106
- columns = ['manual_patient_id', 'prg', 'pl', 'pr', 'sk',
107
- 'ts', 'm11', 'bd2', 'age', 'insurance']
108
- data = {c: [st.session_state.get(c)] for c in columns}
109
- data['insurance'] = [1 if i == 'Yes' else 0 for i in data['insurance']]
110
-
111
- # Make a DataFrame
112
- df = pd.DataFrame(data).rename(
113
- columns={'manual_patient_id': 'id'})
114
- columns_int = ['prg', 'pl', 'pr', 'sk', 'ts', 'age']
115
- columns_float = ['m11', 'bd2']
116
-
117
- df[columns_int] = df[columns_int].astype(int)
118
- df[columns_float] = df[columns_float].astype(float)
119
- else: # Form did not send a patient
120
- message = 'You must choose valid patient(s) from the select box.'
121
- icon = '๐Ÿ˜ž'
122
- st.toast(message, icon=icon)
123
- st.warning(message, icon=icon)
124
-
125
- if df is not None:
126
- try:
127
- # JSON data
128
- data = df.to_dict(orient='list')
129
-
130
- # Send POST request with JSON data using the json parameter
131
- async with httpx.AsyncClient() as client:
132
- response = await client.post(model_endpoint, json=data, timeout=30)
133
- response.raise_for_status() # Ensure we catch any HTTP errors
134
-
135
- if (response.status_code == 200):
136
- pred_prob = (response.json()['result'])
137
- prediction = pred_prob['prediction'][0]
138
- probability = pred_prob['probability'][0]
139
-
140
- # Store results in session state
141
- st.session_state['prediction'] = prediction
142
- st.session_state['probability'] = probability
143
- df['prediction'] = prediction
144
- df['probability (%)'] = probability
145
- df['time_of_prediction'] = pd.Timestamp(dt.datetime.now())
146
- df['model_used'] = st.session_state['selected_model']
147
-
148
- df.to_csv(HISTORY_FILE, mode='a',
149
- header=not os.path.isfile(HISTORY_FILE))
150
- except Exception as e:
151
- st.error(f'๐Ÿ˜ž Unable to connect to the API server. {e}')
152
-
153
- return df
154
-
155
-
156
- async def convert_string(df: pd.DataFrame, string: str) -> str:
157
- return string.upper() if all(col.isupper() for col in df.columns) else string
158
-
159
-
160
- async def make_predictions(model_endpoint, df_uploaded=None, df_uploaded_clean=None) -> Optional[pd.DataFrame]:
161
-
162
- df: pd.DataFrame = None
163
- search_patient = st.session_state.get('search_patient', False)
164
- patient_id_bulk = st.session_state.get('patient_id_bulk', False)
165
- upload_bulk_predict = st.session_state.get('upload_bulk_predict', False)
166
- if search_patient and patient_id_bulk: # Search Form df and a patient was selected
167
- _, df_test = await get_test_data()
168
- mask = df_test['id'].isin(patient_id_bulk)
169
- df_bulk: pd.DataFrame = df_test[mask]
170
- df = df_bulk.copy()
171
-
172
- elif not (search_patient or patient_id_bulk) and upload_bulk_predict: # Upload widget df
173
- df = df_uploaded_clean.copy()
174
- else: # Form did not send a patient
175
- message = 'You must choose valid patient(s) from the select box.'
176
- icon = '๐Ÿ˜ž'
177
- st.toast(message, icon=icon)
178
- st.warning(message, icon=icon)
179
-
180
- if df is not None: # df should be set by form input or upload widget
181
- try:
182
- # JSON data
183
- data = df.to_dict(orient='list')
184
-
185
- # Send POST request with JSON data using the json parameter
186
- async with httpx.AsyncClient() as client:
187
- response = await client.post(model_endpoint, json=data, timeout=30)
188
- response.raise_for_status() # Ensure we catch any HTTP errors
189
-
190
- if (response.status_code == 200):
191
- pred_prob = (response.json()['result'])
192
- predictions = pred_prob['prediction']
193
- probabilities = pred_prob['probability']
194
-
195
- # Add columns sepsis, probability, time, and model used to uploaded df and form df
196
-
197
- async def add_columns(df):
198
- df[await convert_string(df, 'sepsis')] = predictions
199
- df[await convert_string(df, 'probability_(%)')] = probabilities
200
- df[await convert_string(df, 'time_of_prediction')
201
- ] = pd.Timestamp(dt.datetime.now())
202
- df[await convert_string(df, 'model_used')
203
- ] = st.session_state['selected_model']
204
-
205
- return df
206
-
207
- # Form df if search patient is true or df from Uploaded data
208
- if search_patient:
209
- df = await add_columns(df)
210
-
211
- df.to_csv(HISTORY_FILE, mode='a', header=not os.path.isfile(
212
- HISTORY_FILE)) # Save only known patients
213
-
214
- else:
215
- df = await add_columns(df_uploaded) # Raw, No cleaning
216
-
217
- # Store df with prediction results in session state
218
- st.session_state['bulk_prediction_df'] = df
219
- except Exception as e:
220
- st.error(f'๐Ÿ˜ž Unable to connect to the API server. {e}')
221
-
222
- return df
223
-
224
-
225
- def on_click(func: Callable, model_endpoint: str):
226
- async def handle_click():
227
- await func(model_endpoint)
228
-
229
- loop = asyncio.new_event_loop()
230
- asyncio.set_event_loop(loop)
231
- loop.run_until_complete(handle_click())
232
- loop.close()
233
-
234
-
235
- async def search_patient_form(model_endpoint: str) -> None:
236
- test_data = await get_test_data()
237
- _, df_test = test_data
238
-
239
- patient_ids = df_test['id'].unique().tolist()+['']
240
- if st.session_state['sidebar'] == 'single_prediction':
241
- with st.form('search_patient_id_form'):
242
- col1, _ = st.columns(2)
243
- with col1:
244
- st.write('#### Patient ID ๐Ÿค’')
245
- st.selectbox(
246
- 'Search a patient', options=patient_ids, index=len(patient_ids)-1, key='search_patient_id')
247
- st.form_submit_button('Predict', type='primary', on_click=on_click, kwargs=dict(
248
- func=make_prediction, model_endpoint=model_endpoint))
249
- else:
250
- with st.form('search_patient_id_bulk_form'):
251
- col1, _ = st.columns(2)
252
- with col1:
253
- st.write('#### Patient ID ๐Ÿค’')
254
- st.multiselect(
255
- 'Search a patient', options=patient_ids, default=None, key='patient_id_bulk')
256
- st.form_submit_button('Predict', type='primary', on_click=on_click, kwargs=dict(
257
- func=make_predictions, model_endpoint=model_endpoint))
258
-
259
-
260
- async def gen_random_patient_id() -> str:
261
- numbers = ''.join(random.choices(string.digits, k=6))
262
- letters = ''.join(random.choices(string.ascii_lowercase, k=4))
263
- return f"ICU{numbers}-gen-{letters}"
264
-
265
-
266
- async def manual_patient_form(model_endpoint) -> None:
267
- with st.form('manual_patient_form'):
268
-
269
- col1, col2, col3 = st.columns(3)
270
-
271
- with col1:
272
- st.write('### Patient Demographics ๐Ÿ›Œ')
273
- st.text_input(
274
- 'ID', value=await gen_random_patient_id(), key='manual_patient_id')
275
- st.number_input('Age: patients age (years)', min_value=0,
276
- max_value=100, step=1, key='age')
277
- st.selectbox('Insurance: If a patient holds a valid insurance card', options=[
278
- 'Yes', 'No'], key='insurance')
279
-
280
- with col2:
281
- st.write('### Vital Signs ๐Ÿฉบ')
282
- st.number_input('BMI (weight in kg/(height in m)^2', min_value=10.0,
283
- format="%.2f", step=1.00, key='m11')
284
- st.number_input(
285
- 'Blood Pressure (mm Hg)', min_value=10.0, format="%.2f", step=1.00, key='pr')
286
- st.number_input(
287
- 'PRG (plasma glucose)', min_value=10.0, format="%.2f", step=1.00, key='prg')
288
-
289
- with col3:
290
- st.write('### Blood Work ๐Ÿ’‰')
291
- st.number_input(
292
- 'PL: Blood Work Result-1 (mu U/ml)', min_value=10.0, format="%.2f", step=1.00, key='pl')
293
- st.number_input(
294
- 'SK: Blood Work Result 2 (mm)', min_value=10.0, format="%.2f", step=1.00, key='sk')
295
- st.number_input(
296
- 'TS: Blood Work Result-3 (mu U/ml)', min_value=10.0, format="%.2f", step=1.00, key='ts')
297
- st.number_input(
298
- 'BD2: Blood Work Result-4 (mu U/ml)', min_value=10.0, format="%.2f", step=1.00, key='bd2')
299
-
300
- st.form_submit_button('Predict', type='primary', on_click=on_click, kwargs=dict(
301
- func=make_prediction, model_endpoint=model_endpoint))
302
-
303
-
304
- async def do_single_prediction(model_endpoint: str) -> None:
305
- if st.session_state.get('search_patient', False):
306
- await search_patient_form(model_endpoint)
307
- else:
308
- await manual_patient_form(model_endpoint)
309
-
310
-
311
- async def show_prediction() -> None:
312
- final_prediction = st.session_state.get('prediction', None)
313
- final_probability = st.session_state.get('probability', None)
314
-
315
- if final_prediction is None:
316
- st.markdown('#### Prediction will show below! ๐Ÿ”ฌ')
317
- st.divider()
318
- else:
319
- st.markdown('#### Prediction! ๐Ÿ”ฌ')
320
- st.divider()
321
- if final_prediction.lower() == 'positive':
322
- st.toast("Sepsis alert!", icon='๐Ÿฆ ')
323
- message = f"It is **{final_probability:.2f} %** likely that the patient will develop **sepsis.**"
324
- st.warning(message, icon='๐Ÿ˜ž')
325
- time.sleep(5)
326
- st.toast(message)
327
- else:
328
- st.toast("Continous monitoring", icon='๐Ÿ”ฌ')
329
- message = f"The patient will **not** develop sepsis with a likelihood of **{final_probability:.2f}%**."
330
- st.success(message, icon='๐Ÿ˜Š')
331
- time.sleep(5)
332
- st.toast(message)
333
-
334
- # Set prediction and probability to None
335
- st.session_state['prediction'] = None
336
- st.session_state['probability'] = None
337
-
338
-
339
- # @st.cache_data(show_spinner=False) Caching results from async functions buggy
340
- async def convert_df(df: pd.DataFrame):
341
- return df.to_csv(index=False)
342
-
343
-
344
- async def bulk_upload_widget(model_endpoint: str) -> None:
345
- uploaded_file = st.file_uploader(
346
- "Choose a CSV or Excel File", type=['csv', 'xls', 'xlsx'])
347
-
348
- uploaded = uploaded_file is not None
349
-
350
- upload_bulk_predict = st.button('Predict', type='primary',
351
- help='Upload a csv/excel file to make predictions', disabled=not uploaded, key='upload_bulk_predict')
352
- df = None
353
- if upload_bulk_predict and uploaded:
354
- df_test_raw, _ = await get_test_data()
355
- # Uploadfile is a "file-like" object is accepted
356
- try:
357
- try:
358
- df = pd.read_csv(uploaded_file)
359
- except Exception:
360
- df = pd.read_excel(uploaded_file)
361
-
362
- df_columns = set(df.columns)
363
- df_test_columns = set(df_test_raw.columns)
364
- df_schema = df.dtypes
365
- df_test_schema = df_test_raw.dtypes
366
-
367
- if df_columns != df_test_columns or not df_schema.equals(df_test_schema):
368
- df = None
369
- raise Exception
370
- else:
371
- # Clean dataframe
372
- janitor = Janitor()
373
- df_clean = janitor.clean_dataframe(df)
374
-
375
- df = await make_predictions(
376
- model_endpoint, df_uploaded=df, df_uploaded_clean=df_clean)
377
-
378
- except Exception:
379
- st.subheader('Data template')
380
- data_template = df_test_raw[:3]
381
- st.dataframe(data_template)
382
- csv = await convert_df(data_template)
383
- message_1 = 'Upload a valid csv or excel file.'
384
- message_2 = f"{message_1.split('.')[0]} with the columns and schema of the above data template."
385
- icon = '๐Ÿ˜ž'
386
- st.toast(message_1, icon=icon)
387
-
388
- st.download_button(
389
- label='Download template',
390
- data=csv,
391
- file_name='Data template.csv',
392
- mime="text/csv",
393
- type='secondary',
394
- key='download-data-template'
395
- )
396
- st.info('Download the above template for use as a baseline structure.')
397
-
398
- # Display explander to show the data dictionary
399
- with st.expander("Expand to see the data dictionary", icon="๐Ÿ’ก"):
400
- st.subheader("Data dictionary")
401
- st.markdown(markdown_table_all)
402
- st.warning(message_2, icon=icon)
403
-
404
- return df
405
-
406
-
407
- async def do_bulk_prediction(model_endpoint: str) -> None:
408
- if st.session_state.get('search_patient', False):
409
- await search_patient_form(model_endpoint)
410
- else:
411
- # File uploader
412
- await bulk_upload_widget(model_endpoint)
413
-
414
-
415
- async def show_bulk_predictions(df: pd.DataFrame) -> None:
416
- if df is not None:
417
- st.subheader("Bulk predictions ๐Ÿ”ฎ", divider=True)
418
- st.dataframe(df.astype(str))
419
-
420
- csv = await convert_df(df)
421
- message = 'The predictions are ready for download.'
422
- icon = 'โฌ‡๏ธ'
423
- st.toast(message, icon=icon)
424
- st.info(message, icon=icon)
425
- st.download_button(
426
- label='Download predictions',
427
- data=csv,
428
- file_name='Bulk prediction.csv',
429
- mime="text/csv",
430
- type='secondary',
431
- key='download-bulk-prediction'
432
- )
433
-
434
- # Set bulk prediction df to None
435
- st.session_state['bulk_prediction_df'] = None
436
-
437
-
438
- async def sidebar(sidebar_type: str) -> st.sidebar:
439
- return st.session_state.update({'sidebar': sidebar_type})
440
-
441
-
442
- async def main():
443
- st.title("๐Ÿค– Predict Sepsis ๐Ÿฆ ")
444
-
445
- # Navigation
446
- await navigation()
447
-
448
- st.sidebar.toggle("Looking for a patient?", value=st.session_state.get(
449
- 'search_patient', False), key='search_patient')
450
-
451
- selected_model = await select_model()
452
- model_endpoint = await endpoint(selected_model)
453
-
454
- selected_predict_tab = st.session_state.get('selected_predict_tab')
455
- default = 1 if selected_predict_tab is None else selected_predict_tab
456
-
457
- with st.spinner('A little house keeping...'):
458
- time.sleep(st.session_state.get('sleep', 1.5))
459
- chosen_id = stx.tab_bar(data=[
460
- stx.TabBarItemData(id=1, title='๐Ÿ”ฌ Predict', description=''),
461
- stx.TabBarItemData(id=2, title='๐Ÿ”ฎ Bulk predict',
462
- description=''),
463
- ], default=default)
464
- st.session_state['sleep'] = 0
465
-
466
- if chosen_id == '1':
467
- await sidebar('single_prediction')
468
- await do_single_prediction(model_endpoint)
469
- await show_prediction()
470
-
471
- elif chosen_id == '2':
472
- await sidebar('bulk_prediction')
473
- df_with_predictions = await do_bulk_prediction(model_endpoint)
474
- if df_with_predictions is None:
475
- df_with_predictions = st.session_state.get(
476
- 'bulk_prediction_df', None)
477
- await show_bulk_predictions(df_with_predictions)
478
-
479
- # Add footer
480
- await footer()
481
-
482
-
483
- if __name__ == "__main__":
484
- asyncio.run(main())
 
1
+ import os
2
+ import time
3
+ import httpx
4
+ import string
5
+ import random
6
+ import datetime as dt
7
+ from dotenv import load_dotenv
8
+
9
+ import streamlit as st
10
+ import extra_streamlit_components as stx
11
+
12
+ import asyncio
13
+ from aiocache import cached, Cache
14
+
15
+ import pandas as pd
16
+ from typing import Optional, Callable
17
+
18
+ from config import ENV_PATH, BEST_MODELS, TEST_FILE, TEST_FILE_URL, HISTORY_FILE, markdown_table_all
19
+
20
+ from utils.navigation import navigation
21
+ from utils.footer import footer
22
+ from utils.janitor import Janitor
23
+
24
+
25
# Load environment variables from ENV_PATH (API_URL is read from there later)
load_dotenv(ENV_PATH)

# Set page configuration
st.set_page_config(
    page_title="Homepage",
    page_icon="🤖",  # repaired from mis-decoded UTF-8 bytes
    layout="wide",
    initial_sidebar_state='auto'
)
35
+
36
+
37
# NOTE: st.cache_data is not used here; Streamlit's cache did not support
# async functions when this was written.
@cached(ttl=10, cache=Cache.MEMORY, namespace='streamlit_savedataset')
async def save_dataset(df: pd.DataFrame, filepath, csv=True) -> None:
    """Write ``df`` to ``filepath`` unless an identical copy is already there.

    The file is written as CSV when ``csv`` is true and as Excel otherwise.
    An existing file is read back and compared with ``df``; it is rewritten
    only when the two frames differ.
    """
    if os.path.isfile(filepath):
        on_disk = pd.read_csv(filepath) if csv else pd.read_excel(filepath)
        if df.equals(on_disk):
            return  # stored data already matches, skip the write

    if csv:
        df.to_csv(filepath, index=False)
    else:
        df.to_excel(filepath, index=False)
55
+
56
+
57
@cached(ttl=10, cache=Cache.MEMORY, namespace='streamlit_testdata')
async def get_test_data():
    """Load the test dataset and return it as a ``(raw, cleaned)`` pair.

    The copy at ``TEST_FILE_URL`` is tried first and mirrored to
    ``TEST_FILE``; if either step raises, the local ``TEST_FILE`` is read
    instead. The cleaned frame is produced by ``Janitor.clean_dataframe``
    from a copy of the raw frame.
    """
    try:
        raw = pd.read_csv(TEST_FILE_URL)
        await save_dataset(raw, TEST_FILE, csv=True)
    except Exception:
        # Fall back to the local file
        raw = pd.read_csv(TEST_FILE)

    # House keeping: clean a copy so the raw frame is returned untouched
    cleaned = Janitor().clean_dataframe(raw.copy())

    return raw, cleaned
71
+
72
+
73
# Model picker
async def select_model() -> str:
    """Render the model select box in the left of two columns and return the choice."""
    left_column, _ = st.columns(2)
    with left_column:
        choice = st.selectbox(
            'Select a model', options=BEST_MODELS, key='selected_model')
    return choice
81
+
82
+
83
async def endpoint(model: str) -> str:
    """Build the prediction URL by appending ``=<model>`` to the ``API_URL`` env value."""
    base_url = os.getenv("API_URL")
    return "{}={}".format(base_url, model)
87
+
88
+
89
# Function for making a single prediction
async def make_prediction(model_endpoint) -> Optional[pd.DataFrame]:
    """Request a sepsis prediction for one patient and log the result.

    The patient row comes either from the test data (search form) or from
    the manual-entry widget values held in ``st.session_state``. On an
    HTTP 200 response the prediction and probability are stored in session
    state, added to the row, and the row is appended to ``HISTORY_FILE``.

    Args:
        model_endpoint: URL the patient data is POSTed to as JSON.

    Returns:
        The patient DataFrame (with result columns when the request
        succeeded), or ``None`` when no valid patient was supplied.
    """
    df: Optional[pd.DataFrame] = None
    search_patient = st.session_state.get('search_patient', False)
    search_patient_id = st.session_state.get('search_patient_id', False)
    manual_patient_id = st.session_state.get('manual_patient_id', False)

    # The select box yields one non-empty string; isin() needs a list
    if isinstance(search_patient_id, str) and search_patient_id:
        search_patient_id = [search_patient_id]

    if search_patient and search_patient_id:  # Search form, patient selected
        # Test data is only needed on this path, so load it lazily
        _, df_test = await get_test_data()
        mask = df_test['id'].isin(search_patient_id)
        df = df_test[mask].copy()
    elif not (search_patient or search_patient_id) and manual_patient_id:  # Manual form
        columns = ['manual_patient_id', 'prg', 'pl', 'pr', 'sk',
                   'ts', 'm11', 'bd2', 'age', 'insurance']
        data = {c: [st.session_state.get(c)] for c in columns}
        data['insurance'] = [1 if i == 'Yes' else 0 for i in data['insurance']]

        # Make a DataFrame
        df = pd.DataFrame(data).rename(
            columns={'manual_patient_id': 'id'})
        columns_int = ['prg', 'pl', 'pr', 'sk', 'ts', 'age']
        columns_float = ['m11', 'bd2']

        df[columns_int] = df[columns_int].astype(int)
        df[columns_float] = df[columns_float].astype(float)
    else:  # Form did not send a patient
        message = 'You must choose valid patient(s) from the select box.'
        icon = '😞'
        st.toast(message, icon=icon)
        st.warning(message, icon=icon)

    if df is not None:
        try:
            # JSON data
            data = df.to_dict(orient='list')

            # Send POST request with JSON data using the json parameter
            async with httpx.AsyncClient() as client:
                response = await client.post(model_endpoint, json=data, timeout=30)
                response.raise_for_status()  # Ensure we catch any HTTP errors

            if response.status_code == 200:
                pred_prob = response.json()['result']
                prediction = pred_prob['prediction'][0]
                probability = pred_prob['probability'][0]

                # Store results in session state
                st.session_state['prediction'] = prediction
                st.session_state['probability'] = probability
                df['prediction'] = prediction
                df['probability (%)'] = probability
                df['time_of_prediction'] = pd.Timestamp(dt.datetime.now())
                df['model_used'] = st.session_state['selected_model']

                # Append to the history file; write the header only on first creation
                df.to_csv(HISTORY_FILE, mode='a',
                          header=not os.path.isfile(HISTORY_FILE))
        except Exception as e:
            st.error(f'😞 Unable to connect to the API server. {e}')

    return df
154
+
155
+
156
async def convert_string(df: pd.DataFrame, string: str) -> str:
    """Return ``string`` upper-cased when every column name of ``df`` is upper case.

    If any column name is not upper case, ``string`` is returned unchanged.
    """
    for column_name in df.columns:
        if not column_name.isupper():
            return string
    return string.upper()
158
+
159
+
160
async def make_predictions(model_endpoint, df_uploaded=None, df_uploaded_clean=None) -> Optional[pd.DataFrame]:
    """Request sepsis predictions for several patients at once.

    The input rows come either from the test data (bulk search form) or
    from an uploaded file. On an HTTP 200 response the result columns are
    added and the frame is stored under ``st.session_state['bulk_prediction_df']``;
    rows selected through the search form are also appended to ``HISTORY_FILE``.

    Args:
        model_endpoint: URL the data is POSTed to as JSON.
        df_uploaded: Raw uploaded frame; receives the result columns.
        df_uploaded_clean: Cleaned copy of the upload; this is what is sent.

    Returns:
        The resulting DataFrame, or ``None`` when no valid input was supplied.
    """
    df: Optional[pd.DataFrame] = None
    search_patient = st.session_state.get('search_patient', False)
    patient_id_bulk = st.session_state.get('patient_id_bulk', False)
    upload_bulk_predict = st.session_state.get('upload_bulk_predict', False)

    if search_patient and patient_id_bulk:  # Search form, patient(s) selected
        _, df_test = await get_test_data()
        mask = df_test['id'].isin(patient_id_bulk)
        df = df_test[mask].copy()
    elif (not (search_patient or patient_id_bulk) and upload_bulk_predict
          and df_uploaded_clean is not None):  # Upload widget
        # The None guard avoids an AttributeError when no cleaned frame was passed
        df = df_uploaded_clean.copy()
    else:  # Form did not send a patient
        message = 'You must choose valid patient(s) from the select box.'
        icon = '😞'
        st.toast(message, icon=icon)
        st.warning(message, icon=icon)

    if df is not None:  # df should be set by form input or upload widget
        try:
            # JSON data
            data = df.to_dict(orient='list')

            # Send POST request with JSON data using the json parameter
            async with httpx.AsyncClient() as client:
                response = await client.post(model_endpoint, json=data, timeout=30)
                response.raise_for_status()  # Ensure we catch any HTTP errors

            if response.status_code == 200:
                pred_prob = response.json()['result']
                predictions = pred_prob['prediction']
                probabilities = pred_prob['probability']

                # Add columns sepsis, probability, time, and model used.
                # Column names follow the casing of the target frame.
                async def add_columns(df):
                    df[await convert_string(df, 'sepsis')] = predictions
                    df[await convert_string(df, 'probability_(%)')] = probabilities
                    df[await convert_string(df, 'time_of_prediction')
                       ] = pd.Timestamp(dt.datetime.now())
                    df[await convert_string(df, 'model_used')
                       ] = st.session_state['selected_model']

                    return df

                # Form df if search patient is true, else df from uploaded data
                if search_patient:
                    df = await add_columns(df)

                    df.to_csv(HISTORY_FILE, mode='a', header=not os.path.isfile(
                        HISTORY_FILE))  # Save only known patients
                else:
                    df = await add_columns(df_uploaded)  # Raw, no cleaning

                # Store df with prediction results in session state
                st.session_state['bulk_prediction_df'] = df
        except Exception as e:
            st.error(f'😞 Unable to connect to the API server. {e}')

    return df
223
+
224
+
225
def on_click(func: Callable, model_endpoint: str):
    """Run the coroutine function ``func(model_endpoint)`` to completion.

    Intended as a synchronous widget callback: a fresh event loop is
    created, used once, and always closed, even when ``func`` raises.

    Args:
        func: Coroutine function taking the model endpoint.
        model_endpoint: Value forwarded to ``func``.
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(func(model_endpoint))
    finally:
        # Close on the error path too, so a failing callback does not leak the loop
        loop.close()
233
+
234
+
235
async def search_patient_form(model_endpoint: str) -> None:
    """Render the patient search form for the active prediction mode.

    In single-prediction mode a select box (defaulting to the trailing
    empty option) is shown and submission runs ``make_prediction``;
    otherwise a multiselect is shown and submission runs ``make_predictions``.

    Args:
        model_endpoint: URL forwarded to the prediction callback.
    """
    _, df_test = await get_test_data()

    # The trailing empty string is the "nothing selected" option
    patient_ids = df_test['id'].unique().tolist() + ['']

    if st.session_state['sidebar'] == 'single_prediction':
        with st.form('search_patient_id_form'):
            col1, _ = st.columns(2)
            with col1:
                st.write('#### Patient ID 🤒')
                st.selectbox(
                    'Search a patient', options=patient_ids, index=len(patient_ids)-1, key='search_patient_id')
            st.form_submit_button('Predict', type='primary', on_click=on_click, kwargs=dict(
                func=make_prediction, model_endpoint=model_endpoint))
    else:
        with st.form('search_patient_id_bulk_form'):
            col1, _ = st.columns(2)
            with col1:
                st.write('#### Patient ID 🤒')
                st.multiselect(
                    'Search a patient', options=patient_ids, default=None, key='patient_id_bulk')
            st.form_submit_button('Predict', type='primary', on_click=on_click, kwargs=dict(
                func=make_predictions, model_endpoint=model_endpoint))
258
+
259
+
260
async def gen_random_patient_id() -> str:
    """Return a random ID shaped ``ICU<6 digits>-gen-<4 lowercase letters>``."""
    digit_part = random.choices(string.digits, k=6)
    letter_part = random.choices(string.ascii_lowercase, k=4)
    return "ICU" + "".join(digit_part) + "-gen-" + "".join(letter_part)
264
+
265
+
266
async def manual_patient_form(model_endpoint) -> None:
    """Render the manual patient-entry form.

    Each widget stores its value in ``st.session_state`` under its key;
    submitting the form runs ``make_prediction`` through ``on_click``.

    Args:
        model_endpoint: URL forwarded to the prediction callback.
    """
    # Settings shared by every float-valued input
    float_opts = dict(min_value=10.0, format="%.2f", step=1.00)

    with st.form('manual_patient_form'):

        col1, col2, col3 = st.columns(3)

        with col1:
            st.write('### Patient Demographics 🛌')
            st.text_input(
                'ID', value=await gen_random_patient_id(), key='manual_patient_id')
            st.number_input('Age: patients age (years)', min_value=0,
                            max_value=100, step=1, key='age')
            st.selectbox('Insurance: If a patient holds a valid insurance card',
                         options=['Yes', 'No'], key='insurance')

        with col2:
            st.write('### Vital Signs 🩺')
            # Label fixed: the original was missing its closing parenthesis
            st.number_input(
                'BMI (weight in kg/(height in m)^2)', key='m11', **float_opts)
            st.number_input(
                'Blood Pressure (mm Hg)', key='pr', **float_opts)
            st.number_input(
                'PRG (plasma glucose)', key='prg', **float_opts)

        with col3:
            st.write('### Blood Work 💉')
            st.number_input(
                'PL: Blood Work Result-1 (mu U/ml)', key='pl', **float_opts)
            st.number_input(
                'SK: Blood Work Result 2 (mm)', key='sk', **float_opts)
            st.number_input(
                'TS: Blood Work Result-3 (mu U/ml)', key='ts', **float_opts)
            st.number_input(
                'BD2: Blood Work Result-4 (mu U/ml)', key='bd2', **float_opts)

        st.form_submit_button('Predict', type='primary', on_click=on_click, kwargs=dict(
            func=make_prediction, model_endpoint=model_endpoint))
302
+
303
+
304
async def do_single_prediction(model_endpoint: str) -> None:
    """Show the search form when the ``search_patient`` toggle is on, else the manual form."""
    searching = st.session_state.get('search_patient', False)
    render_form = search_patient_form if searching else manual_patient_form
    await render_form(model_endpoint)
309
+
310
+
311
async def show_prediction() -> None:
    """Display the latest single prediction stored in session state.

    Reads ``prediction`` and ``probability`` from ``st.session_state``.
    With no prediction a placeholder heading is shown. Otherwise a toast
    and a warning (positive) or success (negative) box are shown, followed
    one second later by a second toast carrying the full message. Both
    session values are reset to ``None`` afterwards.
    """
    final_prediction = st.session_state.get('prediction', None)
    final_probability = st.session_state.get('probability', None)

    if final_prediction is None:
        st.markdown('#### Prediction will show below! 🔬')
        st.divider()
    else:
        st.markdown('#### Prediction! 🔬')
        st.divider()
        if final_prediction.lower() == 'positive':
            st.toast("Sepsis alert!", icon='🦠')
            message = f"It is **{final_probability:.2f} %** likely that the patient will develop **sepsis.**"
            st.warning(message, icon='😞')
        else:
            st.toast("Continuous monitoring", icon='🔬')
            message = f"The patient will **not** develop sepsis with a likelihood of **{final_probability:.2f}%**."
            st.success(message, icon='😊')

        # Same 1 s delay on both branches (the positive branch still used 5 s);
        # asyncio.sleep avoids blocking the event loop from inside a coroutine.
        await asyncio.sleep(1)
        st.toast(message)

    # Set prediction and probability to None
    st.session_state['prediction'] = None
    st.session_state['probability'] = None
337
+
338
+
339
# NOTE: st.cache_data is left off here; caching results of async functions was buggy
async def convert_df(df: pd.DataFrame):
    """Serialise ``df`` to CSV text without the index column."""
    csv_text = df.to_csv(index=False)
    return csv_text
342
+
343
+
344
async def bulk_upload_widget(model_endpoint: str) -> Optional[pd.DataFrame]:
    """Render the bulk upload widget and run predictions on a valid upload.

    The uploaded file is read as CSV, then as Excel if that fails. Its
    column set and dtypes must match the raw test data; otherwise a data
    template, a download button and the data dictionary are shown.

    Args:
        model_endpoint: URL forwarded to ``make_predictions``.

    Returns:
        Whatever ``make_predictions`` returns for a valid upload, or
        ``None`` when nothing was submitted or the upload was rejected.
    """
    uploaded_file = st.file_uploader(
        "Choose a CSV or Excel File", type=['csv', 'xls', 'xlsx'])

    uploaded = uploaded_file is not None

    upload_bulk_predict = st.button('Predict', type='primary',
                                    help='Upload a csv/excel file to make predictions', disabled=not uploaded, key='upload_bulk_predict')
    df = None
    if upload_bulk_predict and uploaded:
        df_test_raw, _ = await get_test_data()
        # The uploaded file is a file-like object, accepted by both readers
        try:
            try:
                df = pd.read_csv(uploaded_file)
            except Exception:
                # Rewind: the failed CSV attempt may have consumed the stream
                uploaded_file.seek(0)
                df = pd.read_excel(uploaded_file)

            same_columns = set(df.columns) == set(df_test_raw.columns)
            if not same_columns or not df.dtypes.equals(df_test_raw.dtypes):
                raise ValueError(
                    'Uploaded file does not match the data template.')

            # Clean dataframe
            df_clean = Janitor().clean_dataframe(df)

            df = await make_predictions(
                model_endpoint, df_uploaded=df, df_uploaded_clean=df_clean)

        except Exception:
            # Never hand back a frame that failed validation or processing
            df = None

            st.subheader('Data template')
            data_template = df_test_raw[:3]
            st.dataframe(data_template)
            csv = await convert_df(data_template)
            message_1 = 'Upload a valid csv or excel file.'
            message_2 = f"{message_1.split('.')[0]} with the columns and schema of the above data template."
            icon = '😞'
            st.toast(message_1, icon=icon)

            st.download_button(
                label='Download template',
                data=csv,
                file_name='Data template.csv',
                mime="text/csv",
                type='secondary',
                key='download-data-template'
            )
            st.info('Download the above template for use as a baseline structure.')

            # Display expander to show the data dictionary
            with st.expander("Expand to see the data dictionary", icon="💡"):
                st.subheader("Data dictionary")
                st.markdown(markdown_table_all)
            st.warning(message_2, icon=icon)

    return df
405
+
406
+
407
async def do_bulk_prediction(model_endpoint: str) -> None:
    """Show the search form when the ``search_patient`` toggle is on, else the file uploader.

    The rendered widget's return value is not forwarded; this function
    always returns ``None``.
    """
    searching = st.session_state.get('search_patient', False)
    render = search_patient_form if searching else bulk_upload_widget
    await render(model_endpoint)
413
+
414
+
415
async def show_bulk_predictions(df: Optional[pd.DataFrame]) -> None:
    """Display bulk prediction results and offer them as a CSV download.

    Nothing is rendered when ``df`` is ``None``. The stored
    ``bulk_prediction_df`` session value is reset to ``None`` afterwards.

    Args:
        df: Frame holding the predictions, or ``None``.
    """
    if df is not None:
        st.subheader("Bulk predictions 🔮", divider=True)
        # Every column is cast to str for display
        st.dataframe(df.astype(str))

        csv = await convert_df(df)
        message = 'The predictions are ready for download.'
        icon = '⬇️'
        st.toast(message, icon=icon)
        st.info(message, icon=icon)
        st.download_button(
            label='Download predictions',
            data=csv,
            file_name='Bulk prediction.csv',
            mime="text/csv",
            type='secondary',
            key='download-bulk-prediction'
        )

    # Set bulk prediction df to None
    st.session_state['bulk_prediction_df'] = None
436
+
437
+
438
async def sidebar(sidebar_type: str) -> None:
    """Record the active prediction mode under the ``'sidebar'`` session key.

    Args:
        sidebar_type: Mode name, e.g. ``'single_prediction'`` or ``'bulk_prediction'``.
    """
    st.session_state['sidebar'] = sidebar_type
440
+
441
+
442
async def main():
    """Build the prediction page: navigation, model picker, tabs and footer."""
    st.title("🤖 Predict Sepsis 🦠")

    # Navigation
    await navigation()

    st.sidebar.toggle("Looking for a patient?", value=st.session_state.get(
        'search_patient', False), key='search_patient')

    selected_model = await select_model()
    model_endpoint = await endpoint(selected_model)

    selected_predict_tab = st.session_state.get('selected_predict_tab')
    default = 1 if selected_predict_tab is None else selected_predict_tab

    with st.spinner('A little house keeping...'):
        # Pause 1.5 s unless 'sleep' is already stored; it is set to 0 just below
        time.sleep(st.session_state.get('sleep', 1.5))
        chosen_id = stx.tab_bar(data=[
            stx.TabBarItemData(id=1, title='🔬 Predict', description=''),
            stx.TabBarItemData(id=2, title='🔮 Bulk predict',
                               description=''),
        ], default=default)
        st.session_state['sleep'] = 0

    # The selected tab id is compared as a string
    if chosen_id == '1':
        await sidebar('single_prediction')
        await do_single_prediction(model_endpoint)
        await show_prediction()

    elif chosen_id == '2':
        await sidebar('bulk_prediction')
        df_with_predictions = await do_bulk_prediction(model_endpoint)
        if df_with_predictions is None:
            # Fall back to the frame stored by make_predictions
            df_with_predictions = st.session_state.get(
                'bulk_prediction_df', None)
        await show_bulk_predictions(df_with_predictions)

    # Add footer
    await footer()
481
+
482
+
483
# Script entry point: run the async main() to completion
if __name__ == "__main__":
    asyncio.run(main())