|
import streamlit as st |
|
import pandas as pd |
|
import numpy as np |
|
from oocsi_source import OOCSI |
|
from uuid import uuid4 |
|
from streamlit_extras.switch_page_button import switch_page |
|
import random |
|
import dtreeviz |
|
import xgboost as xgb |
|
from dtreeviz.trees import dtreeviz |
|
from sklearn.tree import DecisionTreeClassifier |
|
import graphviz as graphviz |
|
from sklearn.datasets import make_moons |
|
import base64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# First-visit bookkeeping for this page: executed only while the 'profile2'
# marker is absent from the session state.
if 'profile2' not in st.session_state:
    # Take this page out of the list of pages that can still be visited.
    remaining_pages = st.session_state.pages
    remaining_pages.remove("DecisionTree")
    st.session_state.profile2 = 'deleted'

    if remaining_pages:
        # Draw at random the index of the page to open once the evaluation
        # form of this page has been submitted.
        st.session_state.nextPage2 = random.randrange(len(remaining_pages))
        st.session_state.lastQuestion = 'no'
    else:
        # No page left: flag this page as the last one.
        st.session_state.lastQuestion = 'yes'
|
|
|
|
|
# Position inside st.session_state.profileIndices; starts at the first entry.
if 'index2' not in st.session_state:
    st.session_state['index2'] = 0

# Row index of the profile that is currently displayed.
if 'profileIndex' not in st.session_state:
    st.session_state['profileIndex'] = st.session_state.profileIndices[
        st.session_state['index2']
    ]

# Value of the "Name" column for the current profile (used as the sub-heading).
name = st.session_state.X_test_names.loc[st.session_state.profileIndex, "Name"]
|
|
|
|
|
|
|
# Page layout: six stacked rows, each split into three columns with relative
# widths 1:2:1. The rows are created top to bottom in the order listed below;
# the page content is written into the middle column of each row.
(
    (header1, header2, header3),
    (characteristics1, characteristics2, characteristics3),
    (prediction1, prediction2, prediction3),
    (explanation1, explanation2, explanation3),
    (footer1, footer2, footer3),
    (evaluation1, evaluation2, evaluation3),
) = [st.columns([1, 2, 1]) for _ in range(6)]
|
|
|
|
|
@st.cache_data
def loadData():
    """Load the training and test tables from the ``assets`` folder.

    The result is cached with ``st.cache_data`` rather than
    ``st.cache_resource``: the function returns plain pandas objects, and
    ``st.cache_data`` hands every caller its own copy, so a mutation made in
    one session cannot leak into the data seen by another session.

    Returns:
        A tuple ``(X_train, Y_train, X_test)`` where ``X_train`` is the
        training table without its "Survived" column, ``Y_train`` is that
        "Survived" column, and ``X_test`` is the test table without its
        "PassengerId" column.
    """
    train_df = pd.read_csv('assets/train_df.csv')
    test_df = pd.read_csv('assets/test_df.csv')

    X_train = train_df.drop("Survived", axis=1)
    Y_train = train_df["Survived"]
    X_test = test_df.drop("PassengerId", axis=1).copy()
    return X_train, Y_train, X_test
|
|
|
@st.cache_resource
def trainModel(X_train, Y_train):
    """Fit an XGBoost classifier with default settings and return it.

    The function is wrapped in ``st.cache_resource``, so the fitted model is
    kept by Streamlit's resource cache and reused for identical arguments.

    Args:
        X_train: Feature table used for fitting.
        Y_train: Target values matching ``X_train``.

    Returns:
        The fitted ``xgb.XGBClassifier``.
    """
    classifier = xgb.XGBClassifier()
    classifier.fit(X_train, Y_train)
    return classifier
|
|
|
|
|
def createTree(_model, X_train, Y_train, X_test):
    """Draw the decision path of the current profile and save it as SVG.

    Builds a dtreeviz visualisation of tree 0 of ``_model``, restricted to
    the path followed by the row of ``X_test`` selected through
    ``st.session_state.profileIndex``, and writes it to
    ``/assets/images/prediction_path.svg``.

    Args:
        _model: Fitted model handed to dtreeviz.
            NOTE(review): the leading underscore looks like Streamlit's
            "do not hash this argument" convention, but this function is not
            cached at present - confirm whether that is intended.
        X_train: Training features; its columns supply the feature names.
        Y_train: Training target values.
        X_test: Table from which the row to explain is taken.

    Returns:
        The visualisation object returned by ``dtreeviz``.
    """
    column_names = list(X_train.columns)
    explained_row = X_test.iloc[st.session_state.profileIndex]

    viz_model = dtreeviz(
        _model,
        X_train,
        Y_train,
        tree_index=0,
        feature_names=column_names,
        target_name='Survived',
        class_names=['Dead', 'Alive'],
        X=explained_row,
        show_just_path=True,
    )

    # The page reads this file back in order to display the picture.
    viz_model.save("/assets/images/prediction_path.svg")
    return viz_model
|
|
|
def render_svg(svg):
    """Show an SVG document on the page as an inline image.

    The SVG text is UTF-8 encoded, base64 encoded and embedded in the
    ``src`` data URI of an ``<img>`` tag, which is written to the page with
    raw HTML enabled.

    Args:
        svg: The SVG markup as a string.
    """
    svg_bytes = svg.encode('utf-8')
    encoded = base64.b64encode(svg_bytes).decode("utf-8")
    st.write(
        f'<img src="data:image/svg+xml;base64,{encoded}"/>',
        unsafe_allow_html=True,
    )
|
|
|
|
|
# Introductory text of the page. The wording below is shown to the study
# participants, so the spelling and grammar mistakes of the previous text
# ("kinf", "At typical decision example", "lcearly", "vsisble", "varaibales",
# "predicrtions", "numerical of categoriacl") are corrected here.
with header2:
    st.header("Explanation - Decision Tree")
    st.markdown(
        "Decision Tree models are a non-parametric supervised learning method "
        "commonly used for classification and regression. "
        "They are constructed using two kinds of elements: nodes and branches. "
        "At each node (intersection), one of the data features is evaluated "
        "to split the observations into different paths."
        "\n\n"
        "A typical decision tree example is shown in the graph below."
    )

    st.image('assets/Decision_tree.jpg', caption='Example of a decision tree')

    st.markdown(
        "The Root Node starts the graph. It is usually the variable that "
        "splits the data most clearly. "
        "Then, intermediate nodes are visible, where different variables are "
        "evaluated but no final prediction is made yet. "
        "Finally, leaf nodes are present, where the predictions (numerical "
        "or categorical) are made."
        "\n\n"
        "For the Titanic dataset, the prediction will be whether the studied "
        "person survived the shipwreck."
    )

    # Heading for the profile under study, followed by the model that the
    # rest of the page uses for the prediction and the tree drawing.
    st.subheader(name)
    XGBmodel = trainModel(st.session_state.X_train, st.session_state.Y_train)
|
|
|
|
|
|
|
# Feature values of the current profile, displayed as a one-row table.
with characteristics2:
    features = st.session_state.X_test
    row_values = features.iloc[st.session_state.profileIndex].values

    # reshape(1, -1) turns the flat row into a single-row 2-D array.
    data = row_values.reshape(1, -1)
    df = pd.DataFrame(data, columns=features.columns)
    st.dataframe(df)
|
|
|
|
|
# Model verdict for the current profile, with the probability of the
# predicted class.
with prediction2:
    profile_row = st.session_state.X_test.iloc[
        st.session_state.profileIndex
    ].values.reshape(1, -1)

    prediction = XGBmodel.predict(profile_row)
    probability = XGBmodel.predict_proba(profile_row)

    # Class 0 is reported as "not survive", anything else as "survive"; the
    # probability is read from the matching column of predict_proba.
    if prediction == 0:
        class_column = 0
        message = "The model predicts with {}% probability that {} will :red[**not survive**]"
    else:
        class_column = 1
        message = "The model predicts with {}% probability that {} will :green[**survive**]"

    prob = round((probability[0][class_column] * 100), 2)
    st.markdown(message.format(prob, name))
|
|
|
# Decision-path picture for the current profile.
with explanation2:
    st.subheader("Visualization - Decision Tree")

    with st.spinner("Please be patient, we are generating a new explanation"):
        # createTree saves the drawing to /assets/images/prediction_path.svg.
        viz_model = createTree(
            XGBmodel,
            st.session_state.X_train,
            st.session_state.Y_train,
            st.session_state.X_test,
        )

    # Read the file back as UTF-8 explicitly: render_svg re-encodes the text
    # as UTF-8, so relying on the platform's default encoding here could
    # corrupt non-ASCII characters or fail to decode the file.
    with open("/assets/images/prediction_path.svg", "r", encoding="utf-8") as f:
        svg = f.read()
    render_svg(svg)

    # Empty text element used as vertical spacing below the picture.
    st.text("")
|
|
|
# Bottom of the page: either a button to move on to the next profile or,
# once the last profile is on display, the evaluation questionnaire.
with footer2:
    more_profiles = (
        st.session_state.index2 < len(st.session_state.profileIndices) - 1
    )

    if more_profiles:
        if st.button("New profile"):
            st.session_state.index2 += 1
            st.session_state.profileIndex = st.session_state.profileIndices[
                st.session_state.index2
            ]
            st.experimental_rerun()
    else:
        def is_user_active():
            """Return True once the 'user_active2' flag is set to a truthy value."""
            return (
                'user_active2' in st.session_state
                and bool(st.session_state['user_active2'])
            )

        if is_user_active():
            with st.form("my_form2", clear_on_submit=True):
                st.subheader("Evaluation")
                st.write("These questions only ask for your opinion about this specific explanation")

                # Answer scale shared by all statements.
                scale = ['totally disagree', 'disagree', 'neutral', 'agree', 'totally agree']
                statements = [
                    '**1**- From the explanation, I **understand** how the algorithm works:',
                    '**2**- This explanation of how the algorithm works is **satisfying**:',
                    '**3**- This explanation of how the algorithm works has **sufficient detail**:',
                    '**4**- This explanation of how the algorithm works seems **complete**:',
                    '**5**- This explanation of how the algorithm works **tells me how to use it**:',
                    '**6**- This explanation of how the algorithm works is **useful to my goals**:',
                    '**7**- This explanation of the algorithm shows me how **accurate** the algorithm is:',
                    '**8**- This explanation lets me judge when I should **trust and not trust** the algorithm:',
                ]

                # One slider per statement, created in order; the answers are
                # stored under the keys 'q1' ... 'q8'.
                answers = {}
                for number, statement in enumerate(statements, start=1):
                    answers[f'q{number}'] = st.select_slider(statement, options=scale)

                submitted = st.form_submit_button("Submit")
                if submitted:
                    # Send the answers to the 'EngD_HAII' OOCSI channel.
                    st.session_state.oocsi.send('EngD_HAII', {
                        'participant_ID': st.session_state.participantID,
                        'type of explanation': 'Decision tree',
                        **answers,
                    })

                    if st.session_state.lastQuestion == 'yes':
                        switch_page('finalPage')
                    else:
                        # Reset the displayed profile to the first one before
                        # moving on to the page drawn on the first visit.
                        st.session_state.profileIndex = st.session_state.profileIndices[0]
                        switch_page(st.session_state.pages[st.session_state.nextPage2])
        elif st.button('Continue to evaluation'):
            st.session_state['user_active2'] = True
            st.experimental_rerun()
|
|