boringnose committed on
Commit
773ae62
·
verified ·
1 Parent(s): 1fec2dd

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +168 -0
  2. requirements.txt +0 -0
app.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import streamlit as st
4
+ import time
5
+
6
+ from sklearn import datasets
7
+ from sklearn.model_selection import train_test_split
8
+ from sklearn.tree import DecisionTreeClassifier
9
+ from sklearn import tree
10
+ from sklearn.metrics import accuracy_score
11
+ import matplotlib.pyplot as plt
12
+
13
# Page chrome: wide layout with the sidebar open so the controls are visible.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="expanded",
    page_title="Decision Tree Visualizer",
    page_icon=":chart_with_upwards_trend:",
)

# Load the iris dataset and hold out 20% of the samples for testing.
iris = datasets.load_iris()
x, y = iris.data, iris.target
x_train, x_test, y_train, y_test = train_test_split(
    x,
    y,
    test_size=0.2,
    random_state=42,
)

# Default values for the "advanced" hyper-parameters. The sidebar may
# overwrite these when the advanced toggle is switched on.
min_weight_fraction_leaf = 0.0
min_impurity_decrease = 0.0
max_features = None
max_leaf_nodes = None
31
+
32
+
33
+
34
+
35
# Initial view: scatter of the first two iris features, coloured by class.
fig, ax = plt.subplots()

scatter = ax.scatter(x[:, 0], x[:, 1], c=y, cmap='rainbow')
ax.set_xlabel(iris.feature_names[0], fontsize=10)
ax.set_ylabel(iris.feature_names[1], fontsize=10)
ax.set_title('Sepal Length vs Sepal Width', fontsize=15)

# One legend built from the scatter's colour classes. The previous extra
# `ax.legend()` call had no labelled artists to show, so it only produced a
# warning and replaced the axes' legend with an empty one.
legend1 = ax.legend(*scatter.legend_elements(),
                    title="Classes", loc="upper right")

# Keep the returned element handle so the plot can be cleared later.
orig = st.pyplot(fig)

# The script re-executes on every interaction; close the figure once it has
# been rendered so pyplot does not accumulate open figures across reruns.
plt.close(fig)
48
+
49
# Sidebar: hyper-parameter controls for the decision tree.
st.sidebar.header(':blue[_Decision Tree_] Algo Visualizer', divider='rainbow')

criterion = st.sidebar.selectbox(
    "Criterion",
    ("gini", "entropy", "log_loss"),
    help=("The function to measure the quality of a split. "
          "Supported criteria are “gini” for the Gini impurity and “log_loss” and “entropy” "
          "both for the Shannon information gain"),
)

# 0 is the UI sentinel for "no limit"; the classifier expects None for that.
max_depth = st.sidebar.number_input(
    "Max Depth",
    min_value=0,
    max_value=30,
    step=1,
    value=0,
    help=("The maximum depth of the tree. "
          "Enter 0 to expand nodes until all leaves are pure"),
)
if max_depth == 0:
    max_depth = None

# These two widgets are integer-only (integer bounds), so the lower bounds
# are the smallest integers scikit-learn accepts: 2 for a split, 1 for a leaf.
min_samples_split = st.sidebar.number_input(
    "Min Sample Split",
    min_value=2,
    max_value=x_train.shape[0],
    value=2,
    help="The minimum number of samples required to split an internal node (integer, at least 2).",
)
min_samples_leaf = st.sidebar.number_input(
    "Min sample Leaf",
    min_value=1,
    max_value=x_train.shape[0],
    value=1,
    help="The minimum number of samples required to be at a leaf node (integer, at least 1).",
)
random_state = st.sidebar.number_input("Random State", min_value=0, value=42)

# Advanced features: only shown (and only override the defaults) when toggled.
toggle = st.sidebar.toggle("Advance Features")

if toggle:
    # scikit-learn restricts this fraction to the range [0.0, 0.5].
    min_weight_fraction_leaf = st.sidebar.number_input(
        "Min Weight Fraction Leaf",
        min_value=0.0,
        max_value=0.5,
        value=0.0,
        help=("The minimum weighted fraction of the sum total of weights "
              "(of all the input samples) required to be at a leaf node. "),
    )
    max_features = st.sidebar.selectbox(
        "Max Features",
        (None, "sqrt", "log2", "Custom"),
        help="The number of features to consider when looking for the best split",
    )
    if max_features == "Custom":
        # Left empty, the widget yields None, i.e. "consider all features".
        max_features = st.sidebar.number_input(
            "Enter Max Features",
            min_value=1,
            value=None,
            step=1,
        )

    max_leaf_nodes = st.sidebar.number_input(
        "Max Leaf Nodes",
        min_value=0,
        help=("Grow a tree with max_leaf_nodes in best-first fashion. "
              "0 means unlimited; values below 2 are treated as unlimited."),
    )
    # 0 is the "unlimited" sentinel. 1 is not a valid leaf budget for the
    # classifier, so it is mapped to unlimited as well instead of crashing.
    if max_leaf_nodes < 2:
        max_leaf_nodes = None

    min_impurity_decrease = st.sidebar.number_input(
        "Min Impurity Decrase",
        min_value=0.0,
        help=("A node will be split if this split induces a decrease of the "
              "impurity greater than or equal to this value."),
    )

train = st.sidebar.button("Train Model", type="primary")
if st.sidebar.button("Reset"):
    # Prefer st.rerun where available; fall back to the older experimental
    # name so the app works on Streamlit releases that ship only one of them.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()
111
if train:
    # Local import: plots are rendered into in-memory buffers instead of
    # fixed file names on disk, which concurrent sessions would overwrite.
    import io

    # Remove the initial scatter plot before showing the results.
    orig.empty()

    msg = st.toast('Running', icon='🫸🏼')

    # Build the model from the sidebar selections.
    clf = DecisionTreeClassifier(
        criterion=criterion,
        max_depth=max_depth,
        min_samples_split=min_samples_split,
        min_samples_leaf=min_samples_leaf,
        min_weight_fraction_leaf=min_weight_fraction_leaf,
        max_features=max_features,
        random_state=random_state,
        max_leaf_nodes=max_leaf_nodes,
        min_impurity_decrease=min_impurity_decrease,
    )

    # Only the first two features are used so the decision regions can be
    # drawn in 2-D.
    train_features = x_train[:, :2]
    test_features = x_test[:, :2]

    try:
        clf.fit(train_features, y_train)
    except ValueError as err:
        # Invalid hyper-parameter combinations surface here; report them in
        # the page and end this run instead of showing a traceback.
        st.error(f"Could not train the model: {err}")
        st.stop()

    train_pred = clf.predict(train_features)
    test_pred = clf.predict(test_features)
    train_acc = round(accuracy_score(y_train, train_pred), 2)
    test_acc = round(accuracy_score(y_test, test_pred), 2)
    st.subheader(f"Train Accuracy {train_acc}, Test Accuracy {test_acc}")
    st.write("Total Depth: " + str(clf.tree_.max_depth))

    # Grid covering both plotted features with a margin of 1 on each side.
    x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
    y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.01),
                         np.arange(y_min, y_max, 0.01))

    # Predict a class for every grid point, then restore the grid shape.
    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)

    # Plot the decision boundaries with the samples drawn on top.
    boundary_fig = plt.figure(figsize=(8, 6))
    plt.contourf(xx, yy, Z, alpha=0.8)
    plt.scatter(x[:, 0], x[:, 1], c=y, edgecolors='k', s=20)
    plt.xlabel('Sepal length')
    plt.ylabel('Sepal width')
    plt.title('Decision Boundaries')
    plt.tight_layout()
    boundary_png = io.BytesIO()
    plt.savefig(boundary_png, format="png")
    plt.close(boundary_fig)
    boundary_png.seek(0)

    # Display decision boundary plot
    st.image(boundary_png)

    # Plot the fitted tree. The model saw only two features, so only their
    # names are passed; class names are given as a plain list of strings.
    tree_fig = plt.figure(figsize=(25, 20))
    tree.plot_tree(
        clf,
        feature_names=list(iris.feature_names[:2]),
        class_names=[str(name) for name in iris.target_names],
        filled=True,
    )
    # Doubles both axis limits of the tree plot.
    # NOTE(review): presumably to leave a margin around the tree - confirm.
    plt.xlim(plt.xlim()[0] * 2, plt.xlim()[1] * 2)
    plt.ylim(plt.ylim()[0] * 2, plt.ylim()[1] * 2)
    tree_png = io.BytesIO()
    plt.savefig(tree_png, format="png")
    plt.close(tree_fig)
    tree_png.seek(0)

    # Display decision tree plot
    st.image(tree_png)

    msg.toast('Model run successfully!', icon='😎')
167
+
168
+
requirements.txt ADDED
Binary file (188 Bytes). View file