Spaces:
Runtime error
Runtime error
File size: 7,562 Bytes
773ae62 fbcb291 773ae62 6c1a6e2 773ae62 6c1a6e2 773ae62 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 |
import pandas as pd
import numpy as np
import streamlit as st
import time
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
# App-wide Streamlit page configuration (must run before any other st.* call).
st.set_page_config(
    page_title="Decision Tree Visualizer",
    page_icon=":chart_with_upwards_trend:",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Load the Iris dataset and hold out 20% of rows for testing.
iris = datasets.load_iris()
x, y = iris.data, iris.target
x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.2, random_state=42
)
# constants
min_weight_fraction_leaf=0.0
max_features = None
max_leaf_nodes = None
min_impurity_decrease=0.0
ccp_alpha = 0.0
# Initial scatter of the first two Iris features (sepal length vs. sepal
# width), colored by class, shown before any model has been trained.
fig, ax = plt.subplots()
scatter = ax.scatter(x.T[0], x.T[1], c=y, cmap='rainbow')
ax.set_xlabel(iris.feature_names[0], fontsize=10)
ax.set_ylabel(iris.feature_names[1], fontsize=10)
ax.set_title('Sepal Length vs Sepal Width', fontsize=15)
# Build the class legend from the scatter's color mapping. The original code
# also called a bare `ax.legend()` afterwards, which created a second, empty
# legend (and a "No artists with labels" warning) — removed.
legend1 = ax.legend(*scatter.legend_elements(),
                    title="Classes", loc="upper right")
ax.add_artist(legend1)
# Keep the handle so the train block can clear this placeholder plot later.
orig = st.pyplot(fig)
# --- Sidebar: core DecisionTreeClassifier hyperparameters ---
st.sidebar.header(':blue[_Decision Tree_] Algo Visualizer', divider='rainbow')
criterion = st.sidebar.selectbox("Criterion",
                                 ("gini", "entropy", "log_loss"),
                                 help="""The function to measure the quality of a split.
Supported criteria are “gini” for the Gini impurity and “log_loss” and “entropy”
both for the Shannon information gain""")
max_depth = st.sidebar.number_input("Max Depth",
                                    min_value=0,
                                    max_value=30,
                                    step=1,
                                    value=0,
                                    help="""The maximum depth of the tree. If None, then nodes are expanded until all leaves are pure""")
# 0 is the UI sentinel for "no limit": sklearn expects None for that.
if max_depth == 0:
    max_depth = None
# min_value raised 0 -> 2: sklearn requires min_samples_split >= 2, so the
# widget previously allowed values that made DecisionTreeClassifier.fit raise.
min_samples_split = st.sidebar.number_input("Min Sample Split",
                                            min_value=2,
                                            max_value=x_train.shape[0],
                                            value=2,
                                            help="""The minimum number of samples required to split an internal node.
If float, enter between 0 and 1""")
# min_value raised 0 -> 1: sklearn requires min_samples_leaf >= 1.
min_samples_leaf = st.sidebar.number_input("Min sample Leaf",
                                           min_value=1,
                                           max_value=x_train.shape[0],
                                           value=1,
                                           help="""The minimum number of samples required to be at a leaf node.
If float, enter between 0 and 1""")
random_state = st.sidebar.number_input("Random State",
                                       min_value=0,
                                       value=42)
# --- Sidebar: advanced hyperparameters, shown only when toggled on ---
toggle = st.sidebar.toggle("Advance Features")
if toggle:
    min_weight_fraction_leaf = st.sidebar.number_input("Min Weight Fraction Leaf",
                                                       min_value=0.0,
                                                       max_value=1.0,
                                                       value=0.0,
                                                       help="""The minimum weighted fraction of the sum total of weights
(of all the input samples) required to be at a leaf node. """)
    max_features = st.sidebar.selectbox("Max Features",
                                        (None, "sqrt", "log2", "Custom"),
                                        help="""The number of features to consider when looking for the best split""")
    # "Custom" swaps the selectbox value for a free-form integer input.
    if max_features == "Custom":
        max_features = st.sidebar.number_input("Enter Max Features",
                                               value=None,
                                               step=1)
    max_leaf_nodes = st.sidebar.number_input("Max Leaf Nodes",
                                             min_value=0,
                                             help="""Grow a tree with max_leaf_nodes in best-first fashion. """)
    # 0 is the UI sentinel for "no limit": sklearn expects None for that.
    if max_leaf_nodes == 0:
        max_leaf_nodes = None
    # Label typo fixed ("Decrase" -> "Decrease").
    min_impurity_decrease = st.sidebar.number_input("Min Impurity Decrease",
                                                    min_value=0.0,
                                                    help="""A node will be split if this split induces a decrease of the
impurity greater than or equal to this value.""")
    # All numeric arguments must share one type: the original mixed int
    # min/max/value with float step=0.1, which makes st.number_input raise
    # a StreamlitMixedNumericTypesError at render time.
    ccp_alpha = st.sidebar.number_input("ccp_alpha",
                                        min_value=0.0,
                                        max_value=30.0,
                                        step=0.1,
                                        value=0.0,
                                        help="""Complexity parameter used for Minimal Cost-Complexity Pruning.
The subtree with the largest cost complexity that is smaller than ccp_alpha will be chosen.
By default, no pruning is performed.""")
# --- Sidebar: action buttons ---
train = st.sidebar.button("Train Model", type="primary")
if st.sidebar.button("Reset"):
    # st.experimental_rerun() was deprecated and later removed from
    # Streamlit; st.rerun() is the supported replacement (available in the
    # same releases that provide st.toast, which this app already uses).
    st.rerun()
# Train the classifier and render the results when "Train Model" is clicked.
if train:
    orig.empty()  # clear the placeholder scatter plot
    msg = st.toast('Running', icon='🫸🏼')

    # Fit on the first two features only (sepal length & width) so the
    # decision boundary can be drawn in 2-D.
    clf = DecisionTreeClassifier(
        criterion=criterion,
        max_depth=max_depth,
        min_samples_split=min_samples_split,
        min_samples_leaf=min_samples_leaf,
        min_weight_fraction_leaf=min_weight_fraction_leaf,
        max_features=max_features,
        random_state=random_state,
        max_leaf_nodes=max_leaf_nodes,
        min_impurity_decrease=min_impurity_decrease,
        ccp_alpha=ccp_alpha,
    )
    clf.fit(x_train[:, :2], y_train)

    # Report train/test accuracy and the fitted depth.
    train_preds = clf.predict(x_train[:, :2])
    test_preds = clf.predict(x_test[:, :2])
    st.subheader(
        f"Train Accuracy {round(accuracy_score(y_train, train_preds), 2)}, "
        f"Test Accuracy {round(accuracy_score(y_test, test_preds), 2)}"
    )
    st.write(f"Total Depth: {clf.tree_.max_depth}")

    # Evaluate the classifier on a dense grid covering the data range.
    x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
    y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.01),
                         np.arange(y_min, y_max, 0.01))
    grid_preds = clf.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)

    # Decision-boundary plot with the raw points overlaid.
    boundary_fig, boundary_ax = plt.subplots(figsize=(8, 6))
    boundary_ax.contourf(xx, yy, grid_preds, alpha=0.8)
    boundary_ax.scatter(x[:, 0], x[:, 1], c=y, edgecolors='k', s=20)
    boundary_ax.set_xlabel('Sepal length')
    boundary_ax.set_ylabel('Sepal width')
    boundary_ax.set_title('Decision Boundaries')
    boundary_fig.tight_layout()
    boundary_fig.savefig('decision_boundary_plot.png')
    plt.close(boundary_fig)
    st.image("decision_boundary_plot.png")

    # Render the fitted tree itself, with doubled axis limits for headroom.
    tree_fig, tree_ax = plt.subplots(figsize=(25, 20))
    tree.plot_tree(clf, feature_names=iris.feature_names,
                   class_names=iris.target_names, filled=True, ax=tree_ax)
    x_lo, x_hi = tree_ax.get_xlim()
    tree_ax.set_xlim(x_lo * 2, x_hi * 2)
    y_lo, y_hi = tree_ax.get_ylim()
    tree_ax.set_ylim(y_lo * 2, y_hi * 2)
    tree_fig.savefig("decision_tree.png")
    plt.close(tree_fig)
    st.image("decision_tree.png")

    msg.toast('Model run successfully!', icon='😎')
|