File size: 3,476 Bytes
a74c801
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65cbdb6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from itertools import combinations
from functools import partial

plt.rcParams['figure.dpi'] = 100

from sklearn.datasets import load_iris
from sklearn.ensemble import (
    RandomForestClassifier,
    ExtraTreesClassifier,
    AdaBoostClassifier,
)
from sklearn.tree import DecisionTreeClassifier

import gradio as gr

# ========================================

# Colours for classes 0, 1, 2 (red, yellow, blue); CMAP is used for the
# filled decision regions.
C1, C2, C3 = '#ff0000', '#ffff00', '#0000ff'
CMAP = ListedColormap([C1, C2, C3])
# NOTE(review): GRANULARITY is not referenced by create_plot, which builds its
# mesh with a fixed 0.1 step.
GRANULARITY = 0.01
SEED = 1
N_ESTIMATORS = 30

FEATURES = ["Sepal Length", "Sepal Width", "Petal Length", "Petal Width"]
LABELS = ["Setosa", "Versicolour", "Virginica"]

iris = load_iris()

# One unfitted estimator per plot panel; the ensembles start with
# N_ESTIMATORS members.
MODELS = [
    DecisionTreeClassifier(max_depth=None),
    RandomForestClassifier(n_estimators=N_ESTIMATORS, n_jobs=-1),
    ExtraTreesClassifier(n_estimators=N_ESTIMATORS, n_jobs=-1),
    AdaBoostClassifier(DecisionTreeClassifier(max_depth=3), n_estimators=N_ESTIMATORS),
]
# Panel titles, taken from the estimator class names so they always line up
# with MODELS by index.
MODEL_NAMES = [type(estimator).__name__ for estimator in MODELS]

# ========================================

def create_plot(feature_string, n_estimators, model_idx):
    """Fit one classifier on two Iris features and plot its decision surface.

    Args:
        feature_string: Two names from ``FEATURES`` separated by a comma,
            e.g. ``"Sepal Length, Sepal Width"`` (the dropdown's format).
        n_estimators: Ensemble size from the slider. Applied only to models
            that expose an ``n_estimators`` parameter; the plain decision tree
            has none and ignores it.
        model_idx: Index into ``MODELS`` / ``MODEL_NAMES``.

    Returns:
        A ``matplotlib.figure.Figure`` containing:
            - the filled decision regions;
            - the standardized training points, coloured by class;
            - the training-set accuracy (rounded to 3 decimals) in the title.
        Returns ``None`` when no feature pair is selected.

    Raises:
        ValueError: if a feature name is not in ``FEATURES``.
        IndexError: if ``feature_string`` does not contain two names.
    """
    # Function-scope imports: both packages are already dependencies of this file.
    from matplotlib.figure import Figure
    from sklearn.base import clone

    if not feature_string:
        # The dropdown can be cleared in the UI; render nothing instead of crashing.
        return None

    feature_list = [s.strip() for s in feature_string.split(',')]
    idx_x = FEATURES.index(feature_list[0])
    idx_y = FEATURES.index(feature_list[1])

    X = iris.data[:, [idx_x, idx_y]]
    y = iris.target

    # A local RNG (instead of reseeding numpy's global one) keeps concurrent
    # callbacks from interfering with each other's random streams.
    rng = np.random.RandomState(SEED)
    rnd_idx = rng.permutation(X.shape[0])
    X = X[rnd_idx]
    y = y[rnd_idx]

    # Standardize each feature to zero mean and unit variance.
    X = (X - X.mean(0)) / X.std(0)

    model_name = MODEL_NAMES[model_idx]
    # Fit a fresh copy: MODELS is shared by every callback, and fitting the
    # shared instance from several concurrent requests would race.
    model = clone(MODELS[model_idx])
    params = model.get_params(deep=False)
    overrides = {}
    if 'n_estimators' in params:
        # The slider value must actually reach the model; Gradio may deliver
        # it as a float, while scikit-learn requires an int here.
        overrides['n_estimators'] = int(n_estimators)
    if 'random_state' in params:
        # Seed the model explicitly so results do not depend on global RNG state.
        overrides['random_state'] = SEED
    model.set_params(**overrides)

    model.fit(X, y)
    score = round(model.score(X, y), 3)

    # Mesh covering the data with a one-unit margin, sampled every 0.1.
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1), np.arange(y_min, y_max, 0.1))

    Z = model.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)

    # Figure() is not registered with pyplot, so the figures created on every
    # dropdown/slider change are garbage-collected instead of accumulating.
    fig = Figure()
    ax = fig.add_subplot(111)

    # Fixed levels centred on the class ids keep class k mapped to colour k
    # even when the model predicts only a subset of the classes on the mesh.
    class_levels = np.arange(len(LABELS) + 1) - 0.5
    ax.contourf(xx, yy, Z, levels=class_levels, cmap=CMAP, alpha=0.65)

    for i, (label, color) in enumerate(zip(LABELS, (C1, C2, C3))):
        X_label = X[y == i]
        ax.scatter(X_label[:, 0], X_label[:, 1], color=color, edgecolor='k', s=40, label=label)

    ax.set_xlabel(feature_list[0])
    ax.set_ylabel(feature_list[1])
    ax.legend()
    ax.set_title(f'{model_name} | Score: {score}')

    return fig

def iter_grid(n_rows, n_cols):
    """Yield once per cell of an ``n_rows`` x ``n_cols`` Gradio layout grid.

    Each yield happens while that cell's ``gr.Row`` and ``gr.Column``
    contexts are open. Components the caller creates in its loop body are
    therefore placed in the current cell.
    """
    def _cells_in_open_row():
        # Open one column per cell of the current row and hand control back.
        for _col in range(n_cols):
            with gr.Column():
                yield

    for _row in range(n_rows):
        with gr.Row():
            yield from _cells_in_open_row()

with gr.Blocks() as demo:
    # Every unordered pair of features, formatted as "A, B" (the format
    # create_plot splits on the comma).
    selections = [f'{first}, {second}' for first, second in combinations(FEATURES, 2)]
    dd = gr.Dropdown(selections, value=selections[0], interactive=True, label="Input features")
    # Start the slider at the ensemble size the models were built with.
    slider = gr.Slider(1, 100, value=N_ESTIMATORS, step=1, label='n_estimators')

    # One plot per model in a 2x2 grid. Each plot re-renders when the feature
    # pair or the slider changes, and once when the page loads.
    for model_idx, _cell in enumerate(iter_grid(2, 2)):
        if model_idx >= len(MODELS):
            break

        plot = gr.Plot(label=MODEL_NAMES[model_idx])
        # partial binds the current index now, avoiding late-binding surprises.
        fn = partial(create_plot, model_idx=model_idx)

        for event in (dd.change, slider.change, demo.load):
            event(fn, inputs=[dd, slider], outputs=[plot])

# Launch only when run as a script, so importing the module (e.g. by Gradio's
# reload mode or tests) does not start a server.
if __name__ == "__main__":
    demo.launch()