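"""Streamlit app: interactively tune a scikit-learn DecisionTreeClassifier on the
Iris dataset and visualize the resulting decision boundary and fitted tree."""
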
import numpy as np
import streamlit as st

from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt

st.set_page_config(
    page_title="Decision Tree Visualizer",
    page_icon=":chart_with_upwards_trend:",
    layout="wide",
    initial_sidebar_state="expanded")

# Load the Iris dataset and hold out 20% of it for testing
iris = datasets.load_iris()
x = iris.data
y = iris.target
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)


# Advanced hyperparameters: start from scikit-learn's defaults; these are
# overridden below when the "Advanced Features" toggle is on
min_weight_fraction_leaf = 0.0
max_features = None
max_leaf_nodes = None
min_impurity_decrease = 0.0
ccp_alpha = 0.0



# Initial scatter plot of the first two features, coloured by class
fig, ax = plt.subplots()
scatter = ax.scatter(x[:, 0], x[:, 1], c=y, cmap='rainbow')
ax.set_xlabel(iris.feature_names[0], fontsize=10)
ax.set_ylabel(iris.feature_names[1], fontsize=10)
ax.set_title('Sepal Length vs Sepal Width', fontsize=15)
ax.legend(*scatter.legend_elements(), title="Classes", loc="upper right")
orig = st.pyplot(fig)  # keep the handle so this placeholder can be emptied after training

# sidebar elements 
st.sidebar.header(':blue[_Decision Tree_] Algo Visualizer', divider='rainbow')

criterion = st.sidebar.selectbox("Criterion",
                                ("gini", "entropy", "log_loss"),
                                help="""The function to measure the quality of a split.
                                Supported criteria are “gini” for the Gini impurity and “log_loss” and “entropy” 
                                both for the Shannon information gain""")
max_depth = st.sidebar.number_input("Max Depth",
                            min_value=0,
                            max_value=30,
                            step=1,
                            value=0,
                            help="""The maximum depth of the tree. If None, then nodes are expanded until all leaves are pure""")
if max_depth == 0:
    max_depth=None
min_samples_split = st.sidebar.number_input("Min Samples Split",
                                   min_value=2,
                                   max_value=x_train.shape[0],
                                   value=2,
                                   help="""The minimum number of samples required to split an internal node""")
min_samples_leaf = st.sidebar.number_input("Min Samples Leaf",
                                  min_value=1,
                                  max_value=x_train.shape[0],
                                  value=1,
                                  help="""The minimum number of samples required to be at a leaf node""")
random_state = st.sidebar.number_input("Random State",
                               min_value=0,
                              value=42)
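# Fixing random_state keeps the fitted tree reproducible across runs with the same settings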

# Advanced hyperparameters, hidden behind a toggle
toggle = st.sidebar.toggle("Advanced Features")

if toggle:
    min_weight_fraction_leaf = st.sidebar.number_input("Min Weight Fraction Leaf",
                                              min_value=0.0,
                                              max_value=1.0,
                                              value=0.0,
                                              help="""The minimum weighted fraction of the sum total of weights 
                                              (of all the input samples) required to be at a leaf node. """)
    max_features = st.sidebar.selectbox("Max Features",
                                       (None, "sqrt", "log2", "Custom"),
                                       help="""The number of features to consider when looking for the best split""")
    if max_features == "Custom":
        # The model is fit on the first two features only, so cap the custom value at 2
        max_features = st.sidebar.number_input("Enter Max Features",
                                      min_value=1,
                                      max_value=2,
                                      value=None,
                                      step=1)
    
    max_leaf_nodes = st.sidebar.number_input("Max Leaf Nodes",
                                            min_value=0,
                                            help="""Grow a tree with max_leaf_nodes in best-first fashion""")
    if max_leaf_nodes < 2:
        max_leaf_nodes = None  # scikit-learn requires max_leaf_nodes >= 2, so treat 0 and 1 as "unlimited"
    min_impurity_decrease = st.sidebar.number_input("Min Impurity Decrease",
                                                   min_value=0.0,
                                                   help="""A node will be split if this split induces a decrease of the
                                                   impurity greater than or equal to this value.""")
    # All numeric arguments must share one type, or st.number_input raises an exception
    ccp_alpha = st.sidebar.number_input("ccp_alpha",
                            min_value=0.0,
                            max_value=30.0,
                            step=0.1,
                            value=0.0,
                            help="""Complexity parameter used for Minimal Cost-Complexity Pruning.
                            The subtree with the largest cost complexity that is smaller than ccp_alpha will be chosen.
                            By default, no pruning is performed.""")
train = st.sidebar.button("Train Model", type="primary")
if st.sidebar.button("Reset"):
    st.rerun()  # st.experimental_rerun() is deprecated; st.rerun() is the current API
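# Note: a rerun by itself does not clear widget values, since Streamlit preserves
# widget state across reruns; "Reset" mainly re-renders the page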
if train:
    orig.empty()

    msg = st.toast('Running', icon='🫸🏼')  # keep the handle so the toast can be updated when training finishes
    # Build the classifier from the sidebar hyperparameter selections
    clf = DecisionTreeClassifier(criterion=criterion,max_depth=max_depth,
                                min_samples_split=min_samples_split,
                                min_samples_leaf=min_samples_leaf,
                                min_weight_fraction_leaf=min_weight_fraction_leaf,
                                max_features=max_features,
                                random_state=random_state,
                                max_leaf_nodes=max_leaf_nodes,
                                min_impurity_decrease=min_impurity_decrease,
                                ccp_alpha = ccp_alpha)
    # Fit on the first two features only, so the 2-D decision boundary below matches the model
    clf.fit(x_train[:, :2], y_train)
    y_train_pred = clf.predict(x_train[:, :2])
    y_test_pred = clf.predict(x_test[:, :2])
    st.subheader(f"Train Accuracy {round(accuracy_score(y_train, y_train_pred), 2)},  Test Accuracy {round(accuracy_score(y_test, y_test_pred), 2)}")
    st.write(f"Total Depth: {clf.tree_.max_depth}")
    

    # Define ranges for the meshgrid used to draw the decision boundary
    x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
    y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
    
    xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.01),
                         np.arange(y_min, y_max, 0.01))
    
    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
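    # Z now holds one predicted class per grid point, reshaped to the grid for contourf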
    
    # Plot the decision boundaries
    plt.figure(figsize=(8, 6))
    plt.contourf(xx, yy, Z, alpha=0.8)
    plt.scatter(x[:, 0], x[:, 1], c=y, edgecolors='k', s=20)
    plt.xlabel('Sepal length')
    plt.ylabel('Sepal width')
    plt.title('Decision Boundaries')
    plt.tight_layout()
    plt.savefig('decision_boundary_plot.png')
    plt.close()
    
    # Display decision boundary plot
    st.image("decision_boundary_plot.png")
    
    # Plot the fitted tree; name only the two features the model was trained on
    plt.figure(figsize=(25, 20))
    tree.plot_tree(clf, feature_names=iris.feature_names[:2], class_names=iris.target_names, filled=True)
    plt.xlim(plt.xlim()[0] * 2, plt.xlim()[1] * 2)  # double the axis limits to zoom the rendered tree out
    plt.ylim(plt.ylim()[0] * 2, plt.ylim()[1] * 2)
    plt.savefig("decision_tree.png")
    plt.close()
    
    # Display decision tree plot
    st.image("decision_tree.png")

    msg.toast('Model trained successfully!', icon='😎')