"""Streamlit demo page for the ADRD dementia-etiology model.

The user pastes a JSON dictionary of input features (or samples a random case
from the NACC testing dataset). The features are renamed/recoded with the NACC
variable mapping, the ADRD model is run on them, and the predicted
probabilities are shown next to the form. A table describing every input
feature is rendered at the bottom of the page.
"""

import json
import os
import pickle
import random

import pandas as pd
import streamlit as st

# Model checkpoint locations, tried in order. The second path is a fallback
# used when the checkpoint is not next to the app.
CKPT_PATHS = (
    './ckpt_swinunetr_stripped_MNI.pt',
    '../adrd_tool_copied_from_sahana/dev/ckpt/ckpt_swinunetr_stripped_MNI.pt',
)
TEST_CSV = './data/test.csv'
META_CSV = './data/input_meta_info.csv'
MAPPING_PKL = './data/nacc_variable_mappings.pkl'

# set page configuration to wide mode (must be the first Streamlit call)
st.set_page_config(layout="wide")
# Custom HTML/CSS hook; the payload is currently only whitespace.
st.markdown(""" """, unsafe_allow_html=True)


@st.cache_resource
def load_model():
    """Load the ADRD model on CPU from the first checkpoint in CKPT_PATHS that exists.

    Cached with ``st.cache_resource`` so the checkpoint is read once per server
    process instead of on every script rerun.

    Raises:
        FileNotFoundError: if none of the candidate checkpoint files exists.
    """
    import adrd  # imported here so it is only required when the model is loaded

    for ckpt_path in CKPT_PATHS:
        if os.path.isfile(ckpt_path):
            # Let genuine load errors (corrupt file, version mismatch) propagate
            # instead of being masked by silently trying the next path.
            return adrd.model.ADRDModel.from_ckpt(ckpt_path, device='cpu')
    raise FileNotFoundError(
        'No model checkpoint found; looked in: ' + ', '.join(CKPT_PATHS)
    )


@st.cache_resource
def load_nacc_data():
    """Load the NACC testing dataset used by the "Random case" button (cached)."""
    from data.dataset_csv import CSVDataset

    return CSVDataset(dat_file=TEST_CSV, cnf_file=META_CSV)


@st.cache_resource
def load_nacc_mapping():
    """Load the NACC variable mapping table consumed by ``convert_dictionary`` (cached).

    NOTE: ``pickle.load`` can execute arbitrary code. This file is read from a
    fixed local path, never from user input, and must come from a trusted source.
    """
    with open(MAPPING_PKL, 'rb') as file:
        return pickle.load(file)


def predict_proba(data_dict):
    """Run the model on a single case and return its prediction entry.

    The case is wrapped in a one-element batch. ``[1]`` picks the second item
    of ``model.predict_proba``'s result, presumably the probability output
    (TODO confirm against the adrd API). ``[0]`` picks the only case in the batch.
    """
    return model.predict_proba([data_dict])[1][0]


def convert_dictionary(original_dict, mappings):
    """Rename and recode user-entered NACC features into the model's input format.

    Args:
        original_dict: feature name -> value, as parsed from the user's JSON.
        mappings: feature name -> ``(new_key, transform_map)``. A value found in
            ``transform_map`` is replaced by the mapped value; any other value
            is kept unchanged.

    Returns:
        A new dict keyed by the mapped feature names. Features whose name is not
        in ``mappings`` are dropped.
    """
    transformed_dict = {}
    for key, value in original_dict.items():
        if key not in mappings:
            continue
        new_key, transform_map = mappings[key]
        try:
            needs_transform = value in transform_map
        except TypeError:
            # Unhashable JSON values (e.g. lists) cannot be looked up; keep as-is.
            needs_transform = False
        transformed_dict[new_key] = transform_map[value] if needs_transform else value
    return transformed_dict


def _rerun():
    """Trigger a script rerun with whichever rerun API this Streamlit version provides."""
    # st.rerun replaced the deprecated (and later removed) st.experimental_rerun.
    rerun = getattr(st, 'rerun', None) or st.experimental_rerun
    rerun()


model = load_model()
dat_tst = load_nacc_data()  # load NACC testing data (cached; do not reload uncached)
nacc_mapping = load_nacc_mapping()

# initialize session state for the text input if it's not already set
if 'input_text' not in st.session_state:
    st.session_state.input_text = ""

# section 1
st.markdown("#### About")
st.markdown(
    "Differential diagnosis of dementia remains a challenge in neurology due to "
    "symptom overlap across etiologies, yet it is crucial for formulating early, "
    "personalized management strategies. Here, we present an AI model that "
    "harnesses a broad array of data, including demographics, individual and "
    "family medical history, medication use, neuropsychological assessments, "
    "functional evaluations, and multimodal neuroimaging, to identify the "
    "etiologies contributing to dementia in individuals."
)

# section 2
st.markdown("#### Demo")
st.markdown(
    "Please enter the input features in the textbox below, formatted as a JSON "
    "dictionary. Click the \"**Random case**\" button to populate the textbox "
    "with a randomly selected case from the NACC testing dataset. "
    "Use the \"**Predict**\" button to submit your input to the model, which "
    "will then provide probability predictions for mental status and all 10 "
    "etiologies."
)

# layout: input form on the left, predictions on the right
layout_l, layout_r = st.columns([1, 1])

# create a form for user input
with layout_l:
    with st.form("json_input_form"):
        json_input = st.text_area(
            "Please enter JSON-formatted input features:",
            value=st.session_state.input_text,
            height=300,
        )
        # three columns place the two submit buttons at opposite ends
        left_col, middle_col, right_col = st.columns([3, 4, 1])
        with left_col:
            sample_button = st.form_submit_button("Random case")
        with right_col:
            submit_button = st.form_submit_button("Predict")

if sample_button:
    idx = random.randrange(len(dat_tst))
    random_case = dat_tst[idx][0]
    st.session_state.input_text = json.dumps(random_case, indent=2)
    # rerun so the text area is redrawn showing the sampled case
    _rerun()
elif submit_button:
    try:
        # Parse the JSON input into a Python dictionary
        data_dict = json.loads(json_input)
    except json.JSONDecodeError as e:
        # Handle JSON parsing errors
        st.error(f"An error occurred: {e}")
    else:
        if not isinstance(data_dict, dict):
            st.error("An error occurred: the input must be a JSON dictionary (object).")
        else:
            data_dict = convert_dictionary(data_dict, nacc_mapping)
            try:
                pred_dict = predict_proba(data_dict)
                pred_text = json.dumps(pred_dict, indent=2)
            except Exception as e:  # UI boundary: report model/serialization failures
                st.error(f"Prediction failed: {e}")
            else:
                with layout_r:
                    st.write("Predicted probabilities:")
                    st.code(pred_text)

# section 3
st.markdown("#### Feature Table")
df_input_meta_info = pd.read_csv(META_CSV)
st.table(df_input_meta_info)