|
import streamlit as st |
|
import pandas as pd |
|
import numpy as np |
|
import copy |
|
from oocsi_source import OOCSI |
|
from uuid import uuid4 |
|
from streamlit_extras.switch_page_button import switch_page |
|
import random |
|
|
|
import dice_ml |
|
from dice_ml.utils import helpers |
|
import xgboost as xgb |
|
import matplotlib.pyplot as plt |
|
from sklearn.ensemble import RandomForestClassifier |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# One-time bookkeeping on the first visit to this page in a session:
# take "counterfactual" off the list of explanation pages still to show and
# decide right away which page follows (or that this was the last one).
if 'profile3' not in st.session_state:
    st.session_state.pages.remove("counterfactual")
    st.session_state.profile3 = 'deleted'
    pages_left = len(st.session_state.pages)
    if not pages_left:
        # Nothing left to show: the evaluation form will lead to the final page.
        st.session_state.lastQuestion = 'yes'
    else:
        # Index into st.session_state.pages of the page to switch to after evaluation.
        st.session_state.nextPage3 = random.randint(0, pages_left - 1)
        st.session_state.lastQuestion = 'no'
|
|
|
|
|
# Session defaults for this page, initialised lazily and in order.
# index3 is this page's position within profileIndices. profileIndex is the
# test-set row currently displayed; it is only derived from index3 when no
# earlier page has already set it.
_page_state_defaults = (
    ('index3', lambda: 0),
    ('profileIndex', lambda: st.session_state.profileIndices[st.session_state.index3]),
)
for _state_key, _make_default in _page_state_defaults:
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
|
|
|
# Page grid: each section is its own row of three columns. The page writes
# into the middle column; the explanation row gets a much wider middle column.
_side_main_side = [1, 2, 1]
header1, header2, header3 = st.columns(_side_main_side)
characteristics1, characteristics2, characteristics3 = st.columns(_side_main_side)
prediction1, prediction2, prediction3 = st.columns(_side_main_side)
explanation1, explanation2, explanation3 = st.columns([1, 10, 1])
footer1, footer2, footer3 = st.columns(_side_main_side)
evaluation1, evaluation2, evaluation3 = st.columns(_side_main_side)

# Display name of the current profile.
# NOTE(review): this is a label-based lookup (.loc[profileIndex]), whereas the
# feature row and prediction on this page select positionally with
# .iloc[profileIndex]. They agree only if the frames are indexed 0..n-1 —
# confirm X_test / X_test_names had their index reset.
name = st.session_state.X_test_names.loc[st.session_state.profileIndex, "Name"]
|
|
|
|
|
@st.cache_resource
def trainModel(X_train, Y_train):
    """Fit a RandomForestClassifier with default hyper-parameters and return it.

    Decorated with ``st.cache_resource`` so Streamlit reruns reuse the fitted
    model instead of refitting on every interaction.

    NOTE(review): no ``random_state`` is passed, so a fresh fit (after a cache
    clear or server restart) may yield a different model and predictions.
    """
    classifier = RandomForestClassifier()
    classifier.fit(X_train, Y_train)
    return classifier
|
|
|
|
|
@st.cache_resource
def getcounterfactual_values(_model, X_prediction, X_train):
    """Build a DiCE explainer (random sampling method) around ``_model``.

    The DiCE data description is read from ``assets/train_df.csv`` with
    ``Survived`` as the outcome column.

    ``_model`` is underscore-prefixed so ``st.cache_resource`` skips hashing it.
    ``X_prediction`` and ``X_train`` are not used in the body, but they are
    still hashed as part of the cache key.
    """
    # NOTE(review): 'Title' is declared continuous; if it is a label-encoded
    # category, DiCE may vary it as a number — confirm the encoding.
    continuous_features = ["Age", 'Fare', 'Siblings_spouses', 'Title', 'Parents_children', 'relatives']
    training_frame = pd.read_csv('assets/train_df.csv')
    # Keyword arguments are evaluated left to right: Data first, then Model.
    return dice_ml.Dice(
        dice_ml.Data(
            dataframe=training_frame,
            continuous_features=continuous_features,
            outcome_name='Survived',
        ),
        dice_ml.Model(model=_model, backend="sklearn"),
        method="random",
    )
|
|
|
|
|
|
|
def Counterfactualsplot(X_test, explainer): |
|
e1 = explainer.generate_counterfactuals( |
|
X_test[1:2],total_CFs=4, desired_class="opposite", |
|
features_to_vary = ['Age','Pclass', 'Sex','Siblings_spouses', 'Parents_children', 'Embarked', 'relatives', 'Title'] ) |
|
e1.cf_examples_list[0].final_cfs_df.to_csv(path_or_buf=rf'assets\counterfactuals_{name}.csv', index=False) |
|
counter_csv = pd.read_csv(f'assets\counterfactuals_{name}.csv') |
|
return st.dataframe(counter_csv, width=10000) |
|
|
|
# Introductory text shown to participants, followed by the profile name.
# The copy below fixes typos and grammar in the participant-facing text.
with header2:
    st.header("Explanation - Counterfactuals")
    st.markdown('''A counterfactual explanation describes a situation where, if a specific event had not occurred, the conclusion would have been different
and a specific outcome would not have occurred. In machine learning, counterfactuals are used to explain the predictions for individual instances. The prediction
of the model is analysed, and the conditions/features that produced this prediction are modified to obtain a different outcome from the model.''')

    st.markdown('''As displayed in the graph below, the relation between the inputs and the prediction is modified by the feature values that create a simple causal
relationship between inputs and predictions.
''')

    st.image('assets/counterfactual.jpg', caption = 'Causal relation between inputs and predictions', use_column_width = 'always' )

    st.markdown('''A counterfactual explanation of a prediction then describes the smallest change that needs to be made to turn the output
prediction into a predefined one.''')
    st.subheader(name, anchor='top')
|
|
|
|
|
# Fitted model; trainModel is cached with st.cache_resource.
random_forest = trainModel(st.session_state.X_train, st.session_state.Y_train)

with characteristics2:
    # Current profile's feature values as a 1 x n_features array, shown as a
    # one-row table labelled with the test-set column names.
    data = st.session_state.X_test.iloc[st.session_state.profileIndex].values.reshape(1, -1)
    df = pd.DataFrame(data, columns=st.session_state.X_test.columns)
    st.dataframe(df)

with prediction2:
    # `data` is the same single-row array, so it is reused for both calls.
    prediction = random_forest.predict(data)
    # Whole-test-set predictions; later handed to getcounterfactual_values.
    prediction_all = random_forest.predict(st.session_state.X_test.values)
    probability = random_forest.predict_proba(data)

    # Choose the predict_proba column and the message for the predicted class.
    # Assumes labels are 0/1, so column 0 <-> class 0 — TODO confirm.
    if prediction == 0:
        class_column = 0
        message = "The model predicts with {}% probability that {} will :red[**not survive**]"
    else:
        class_column = 1
        message = "The model predicts with {}% probability that {} will :green[**survive**]"
    prob = round((probability[0][class_column] * 100), 2)
    st.markdown(message.format(prob, name))
|
|
|
with explanation2:
    st.subheader("Explanation")
    # TODO: placeholder copy shown to participants.
    st.markdown("counterfactual, more text here")

    # Cached DiCE explainer; the second and third arguments only affect the cache key.
    explainer = getcounterfactual_values(random_forest, prediction_all, st.session_state.X_test)
    st.set_option('deprecation.showPyplotGlobalUse', False)
    # Renders the counterfactual table; the returned element is not used further.
    e1 = Counterfactualsplot(st.session_state.X_test, explainer)

    # Lookup tables (ports, titles, genders) placed side by side, each
    # re-indexed from 0 so their rows line up in the concatenation.
    lookup_frames = (st.session_state.ports_df, st.session_state.title_df, st.session_state.gender_df)
    data_indices = pd.concat([frame.reset_index(drop=True) for frame in lookup_frames], axis=1)
    st.dataframe(data_indices)
|
|
|
with footer2:
    if st.session_state.index3 < len(st.session_state.profileIndices) - 1:
        # More profiles remain for this explanation type: offer to advance.
        if st.button("New profile"):
            st.session_state.index3 += 1
            st.session_state.profileIndex = st.session_state.profileIndices[st.session_state.index3]
            st.experimental_rerun()
    else:
        def is_user_active():
            """Return True once the participant has clicked 'Continue to evaluation'."""
            return 'user_active3' in st.session_state and bool(st.session_state['user_active3'])

        if not is_user_active():
            # Last profile reached: gate the questionnaire behind one click.
            if st.button('Continue to evaluation'):
                st.session_state['user_active3'] = True
                st.experimental_rerun()
        else:
            with st.form("my_form3", clear_on_submit=True):
                st.subheader("Evaluation")
                st.write("These questions only ask for your opinion about this specific explanation")

                # Five-point Likert scale shared by every question.
                likert_scale = ['totally disagree', 'disagree', 'neutral', 'agree', 'totally agree']
                # (payload key, slider label) pairs, rendered in this order.
                questionnaire = (
                    ('q1', '**1**- From the explanation, I **understand** how the algorithm works:'),
                    ('q2', '**2**- This explanation of how the algorithm works is **satisfying**:'),
                    ('q3', '**3**- This explanation of how the algorithm works has **sufficient detail**:'),
                    ('q4', '**4**- This explanation of how the algorithm works seems **complete**:'),
                    ('q5', '**5**- This explanation of how the algorithm works **tells me how to use it**:'),
                    ('q6', '**6**- This explanation of how the algorithm works is **useful to my goals**:'),
                    ('q7', '**7**- This explanation of the algorithm shows me how **accurate** the algorithm is:'),
                    ('q8', '**8**- This explanation lets me judge when I should **trust and not trust** the algorithm:'),
                )
                answers = {}
                for answer_key, label in questionnaire:
                    answers[answer_key] = st.select_slider(label, options=likert_scale)

                submitted = st.form_submit_button("Submit")
                if submitted:
                    # Send the participant's answers over OOCSI, then move on.
                    payload = {
                        'participant_ID': st.session_state.participantID,
                        'type of explanation': 'counterfactual',
                    }
                    payload.update(answers)
                    st.session_state.oocsi.send('EngD_HAII', payload)

                    if st.session_state.lastQuestion == 'yes':
                        switch_page('finalPage')
                    else:
                        # Next page starts again from the first profile.
                        st.session_state.profileIndex = st.session_state.profileIndices[0]
                        switch_page(st.session_state.pages[st.session_state.nextPage3])
|
|
|
|