"""Streamlit front-end for the Smart AML demo backed by a Deeploy deployment.

The sidebar collects the Deeploy host and deployment token. Once a token is
entered, the user can fetch a random suspicious transaction from the
preprocessed dataset. The transaction is sent to the deployment's explain
endpoint, and the page renders the transaction next to a certainty value and
a table of the most important suspicious factors.
"""
import logging

import pandas as pd
import streamlit as st
from deeploy import Client

from utils import get_request_body, get_fake_certainty, get_model_url, get_random_suspicious_transaction
from utils import get_explainability_texts, get_explainability_values, send_evaluation, get_comment_explanation
from utils import COL_NAMES, feature_texts
from utils import create_data_input_table, create_table, ChangeButtonColour, get_weights

logging.basicConfig(level=logging.INFO)

st.set_page_config(layout="wide")
st.title("Smart AML:tm:")
st.divider()

# NOTE(review): read_pickle must only ever be pointed at trusted, bundled data.
data = pd.read_pickle("data/preprocessed_data.pkl")

# Streamlit reruns this whole script on every interaction, so per-session
# flags are only initialised when they are not yet present.
if 'predict_button_clicked' not in st.session_state:
    st.session_state.predict_button_clicked = False
if "submitted_disabled" not in st.session_state:
    st.session_state.submitted_disabled = False
if "disabled" not in st.session_state:
    st.session_state.disabled = False


def disabled():
    """Button callback: set the ``disabled`` flag used by the fetch button."""
    st.session_state.disabled = True


def rerun():
    """Callback: flag a new prediction request and clear ``submitted_disabled``."""
    st.session_state.predict_button_clicked = True
    st.session_state.submitted_disabled = False


def submitted_disabled():
    """Callback: set the ``submitted_disabled`` flag."""
    st.session_state.submitted_disabled = True


# NOTE(review): this markdown payload is blank here; it presumably carried
# custom HTML/CSS that was lost — confirm against the original file.
st.markdown(""" """, unsafe_allow_html=True)

with st.sidebar:
    # Add deeploy logo
    st.image("deeploy_logo.png", width=270)

    # Ask for model URL and token
    host = st.text_input("Host (changing is optional)", "app.deeploy.ml")
    model_url, workspace_id, deployment_id = get_model_url()
    deployment_token = st.text_input("Deeploy Model Token", "my-secret-token")

    # The default value is a placeholder, meaning no real token was entered.
    if deployment_token == "my-secret-token":
        st.warning(
            "Please enter Deeploy API token."
        )
    else:
        # NOTE(review): ``disabled`` sets st.session_state.disabled on click and
        # nothing in this script resets it, so the button works once per
        # session — confirm this is intended.
        st.button(
            "Get suspicious transaction",
            key="predict_button",
            help="Click to get a suspicious transaction",
            use_container_width=True,
            on_click=disabled,
            disabled=st.session_state.disabled,
        )
        ChangeButtonColour("Get suspicious transaction", '#FFFFFF', "#00052D")

        # Define client options and instantiate the client.
        client_options = {
            "host": host,
            "deployment_token": deployment_token,
            "workspace_id": workspace_id,
        }
        client = Client(**client_options)

# The button (and therefore its session-state key) only exists once a real
# token has been entered, so seed the key for the check below.
if 'predict_button' not in st.session_state:
    st.session_state.predict_button = False

if st.session_state.predict_button:
    st.session_state.predict_button_clicked = True

if 'got_explanation' not in st.session_state:
    st.session_state.got_explanation = False

if st.session_state.predict_button_clicked:
    try:
        with st.spinner("Loading..."):
            datapoint_pd = get_random_suspicious_transaction(data)
            request_body = get_request_body(datapoint_pd)
            # Call the explain endpoint as it also includes the prediction.
            # NOTE(review): ``client`` is only bound when a real token was
            # entered; otherwise this raises NameError, which the handler
            # below reports as a generic retrieval failure.
            exp = client.explain(request_body=request_body, deployment_id=deployment_id)
            st.session_state.shap_values = exp['explanations'][0]['shap_values']
            st.session_state.request_log_id = exp["requestLogId"]
            st.session_state.prediction_log_id = exp["predictionLogIds"][0]
            st.session_state.datapoint_pd = datapoint_pd
            # NOTE(review): presumably a placeholder value rather than a model
            # output — confirm against utils.get_fake_certainty.
            certainty = get_fake_certainty()
            st.session_state.certainty = certainty
            st.session_state.got_explanation = True
            st.session_state.predict_button_clicked = False
    except Exception:
        # UI boundary: keep the traceback in the logs and show a friendly
        # message. predict_button_clicked stays True on failure, so the next
        # rerun (e.g. after correcting the token) retries the request.
        logging.exception("Failed to retrieve the prediction or explanation")
        st.error(
            "Failed to retrieve the prediction or explanation. "
            + "Check whether you are using the right model URL and Token. "
            + "Contact Deeploy if the problem persists."
        )

if not st.session_state.got_explanation:
    st.info(
        "Fill in left hand side and click on button to observe a potential fraudulent transaction"
    )

if st.session_state.got_explanation:
    # Restore the last successful result from session state so it survives
    # reruns triggered by other widgets.
    shap_values = st.session_state.shap_values
    request_log_id = st.session_state.request_log_id
    prediction_log_id = st.session_state.prediction_log_id
    datapoint_pd = st.session_state.datapoint_pd
    certainty = st.session_state.certainty

    col1, col2 = st.columns(2)
    with col1:
        create_data_input_table(datapoint_pd, COL_NAMES)
    with col2:
        st.subheader('AML Model Hit')
        st.metric(label='Model Certainty', value=certainty)
        explainability_texts, sorted_indices = get_explainability_texts(shap_values, feature_texts)
        weights = get_weights(shap_values, sorted_indices)
        explainability_values = get_explainability_values(sorted_indices, datapoint_pd)
        create_table(explainability_texts, explainability_values, weights, 'Important Suspicious Factors')
        st.subheader("")  # st.markdown("