# Hugging Face page header captured with the file (not part of the program):
# huabdul's picture
# Update app.py
# 4432d63
# raw
# history blame
# 4.52 kB
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from itertools import combinations
from functools import partial
plt.rcParams['figure.dpi'] = 100
from sklearn.datasets import load_iris
from sklearn.ensemble import (
RandomForestClassifier,
ExtraTreesClassifier,
AdaBoostClassifier,
)
from sklearn.tree import DecisionTreeClassifier
import gradio as gr
# ========================================
# Plot colours (red, yellow, blue) -- one per Iris class -- and the colormap
# built from them for the decision-surface fill.
C1 = '#ff0000'
C2 = '#ffff00'
C3 = '#0000ff'
CMAP = ListedColormap([C1, C2, C3])

# NOTE(review): GRANULARITY is not referenced by any code visible in this file;
# create_plot hard-codes a mesh step of 0.1 instead.
GRANULARITY = 0.01
# Seed passed to np.random.seed at the start of every create_plot call.
SEED = 1
# Initial ensemble size for the models built below.
N_ESTIMATORS = 30

# Display names; their order matches the column order of iris.data that
# create_plot indexes into (FEATURES) and the integer targets (LABELS).
FEATURES = [
    "Sepal Length",
    "Sepal Width",
    "Petal Length",
    "Petal Width",
]
LABELS = [
    "Setosa",
    "Versicolour",
    "Virginica",
]

# MODEL_NAMES[i] is the title used for MODELS[i]; keep both lists in the same order.
MODEL_NAMES = [
    'DecisionTreeClassifier',
    'RandomForestClassifier',
    'ExtraTreesClassifier',
    'AdaBoostClassifier',
]

iris = load_iris()

MODELS = [
    DecisionTreeClassifier(max_depth=None),
    RandomForestClassifier(n_estimators=N_ESTIMATORS, n_jobs=-1),
    ExtraTreesClassifier(n_estimators=N_ESTIMATORS, n_jobs=-1),
    AdaBoostClassifier(
        DecisionTreeClassifier(max_depth=3),
        n_estimators=N_ESTIMATORS,
    ),
]
# ========================================
def create_plot(feature_string, n_estimators, max_depth, model_idx):
    """Fit one model on two Iris features and draw its decision surface.

    Parameters
    ----------
    feature_string : str
        Two names from ``FEATURES`` separated by a comma, e.g.
        ``"Sepal Length, Petal Width"``. Surrounding whitespace is ignored.
    n_estimators : int
        Ensemble size. Ignored for the single decision tree (``model_idx == 0``).
    max_depth : int
        Maximum tree depth. For AdaBoost (``model_idx == 3``) it is applied to
        the wrapped decision tree rather than to the ensemble itself.
    model_idx : int
        Index into ``MODELS`` / ``MODEL_NAMES``.

    Returns
    -------
    matplotlib.figure.Figure
        Filled contour of the predicted class over the (standardised) feature
        plane with the training points on top; the title carries the model
        name and its accuracy on the training data.

    Raises
    ------
    ValueError
        If a feature name is not in ``FEATURES``.
    IndexError
        If fewer than two names are given or ``model_idx`` is out of range.
    """
    # Local imports keep this block self-contained; both packages are already
    # dependencies of this file.
    from matplotlib.figure import Figure
    from sklearn.base import clone

    # Re-seed on every call so the shuffle (and any estimator that falls back
    # to NumPy's global random state) is reproducible for a given input.
    np.random.seed(SEED)

    feature_list = [name.strip() for name in feature_string.split(',')]
    idx_x = FEATURES.index(feature_list[0])
    idx_y = FEATURES.index(feature_list[1])

    X = iris.data[:, [idx_x, idx_y]]
    y = iris.target

    rnd_idx = np.random.permutation(X.shape[0])
    X = X[rnd_idx]
    y = y[rnd_idx]

    # Standardise each of the two features to zero mean / unit variance.
    X = (X - X.mean(0)) / X.std(0)

    # UI sliders may deliver floats; the estimators expect integers.
    n_estimators = int(n_estimators)
    max_depth = int(max_depth)

    model_name = MODEL_NAMES[model_idx]
    # Work on a fresh copy: the entries of MODELS are shared by every request,
    # so configuring and fitting them in place would let overlapping callbacks
    # interfere with each other.
    model = clone(MODELS[model_idx])
    if model_idx != 0:
        model.set_params(n_estimators=n_estimators)
    if model_idx == 3:
        # The wrapped-estimator parameter is called ``estimator`` in recent
        # scikit-learn releases and ``base_estimator`` in older ones.
        top_level = model.get_params(deep=False)
        wrapped = 'estimator' if 'estimator' in top_level else 'base_estimator'
        model.set_params(**{f'{wrapped}__max_depth': max_depth})
    else:
        model.set_params(max_depth=max_depth)

    model.fit(X, y)
    score = round(model.score(X, y), 3)

    # Evaluate the classifier on a regular mesh that extends one unit beyond
    # the data in every direction.
    mesh_step = 0.1
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx, yy = np.meshgrid(
        np.arange(x_min, x_max, mesh_step),
        np.arange(y_min, y_max, mesh_step),
    )
    Z = model.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)

    # Build the figure without pyplot: figures made with plt.figure() stay
    # registered in pyplot's global state until explicitly closed, which
    # accumulates memory in a long-running app.
    fig = Figure()
    ax = fig.add_subplot(111)
    ax.contourf(xx, yy, Z, cmap=CMAP, alpha=0.65)

    for class_id, (label, colour) in enumerate(zip(LABELS, (C1, C2, C3))):
        members = y == class_id
        ax.scatter(
            X[members, 0],
            X[members, 1],
            color=colour,
            edgecolor='k',
            s=40,
            label=label,
        )

    ax.set_xlabel(feature_list[0])
    ax.set_ylabel(feature_list[1])
    ax.legend()
    ax.set_title(f'{model_name} | Score: {score}')

    return fig
def iter_grid(n_rows, n_cols):
    """Yield once per cell of an ``n_rows`` x ``n_cols`` Gradio layout grid.

    Each ``yield`` happens while a ``gr.Row()`` and, inside it, a
    ``gr.Column()`` context are open, so whatever the consumer creates in its
    loop body is created inside that cell. The yielded value is ``None``.
    """

    def _cells_of_current_row():
        # One column context per cell; control is handed to the consumer
        # while that column is open.
        for _col in range(n_cols):
            with gr.Column():
                yield None

    for _row in range(n_rows):
        with gr.Row():
            yield from _cells_of_current_row()
# Markdown introduction for the page; it is passed to gr.Markdown when the UI is built.
info = '''
## Plot the decision surfaces of ensembles of trees on the Iris dataset
This plot compares the decision surfaces learned by a decision tree classifier, a random forest classifier, an extra-trees classifier, and by an AdaBoost classifier.
There are in total four features in the Iris dataset. In this example you can select two features at a time for visualization purposes using the dropdown box below.
You can also vary the number of estimators in the ensembles and the max depth of the trees using the sliders.
'''
# Assemble the page: intro text, three shared controls, then a 2x2 grid holding
# one plot per entry of MODELS.
with gr.Blocks() as demo:
    gr.Markdown(info)

    # Every unordered pair of features, formatted the way create_plot parses it.
    selections = [f'{first}, {second}' for first, second in combinations(FEATURES, 2)]

    dd = gr.Dropdown(selections, value=selections[0], interactive=True, label="Input features")
    slider_estimators = gr.Slider(1, 100, value=30, step=1, label='n_estimators')
    slider_max_depth = gr.Slider(1, 50, value=10, step=1, label='max_depth')

    for counter, _ in enumerate(iter_grid(2, 2)):
        # Stop if the grid has more cells than there are models.
        if counter >= len(MODELS):
            break

        plot = gr.Plot(label=MODEL_NAMES[counter])
        # Bind the model index now so each plot keeps its own model.
        fn = partial(create_plot, model_idx=counter)

        # Redraw this plot when any control changes and once on page load.
        for register in (
            dd.change,
            slider_estimators.change,
            slider_max_depth.change,
            demo.load,
        ):
            register(fn, inputs=[dd, slider_estimators, slider_max_depth], outputs=[plot])

demo.launch(share=True, debug=True)