File size: 2,479 Bytes
a63b3d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.linear_model import SGDClassifier
from sklearn.inspection import DecisionBoundaryDisplay

def predict_class(x, y):
    """Predict the iris class for a single two-feature data point.

    An SGD classifier is trained on the first two columns of the iris
    dataset (standardized to zero mean and unit variance), and the query
    point is standardized with the same statistics before prediction.

    Args:
        x: Value of the first feature (``iris.feature_names[0]``), in the
            dataset's original, unscaled units.
        y: Value of the second feature (``iris.feature_names[1]``), in the
            dataset's original, unscaled units.

    Returns:
        The entry of ``iris.target_names`` for the predicted class.
    """
    iris = datasets.load_iris()
    features = iris.data[:, :2]
    # Kept under a separate name so the ``y`` argument is not overwritten
    # by the label array before it is used for the prediction.
    targets = iris.target

    # Shuffle the samples with a fixed seed so every call trains on the
    # same ordering.
    idx = np.arange(features.shape[0])
    np.random.seed(13)
    np.random.shuffle(idx)
    features = features[idx]
    targets = targets[idx]

    # Standardize the training features; the same statistics are reused
    # below for the query point.
    mean = features.mean(axis=0)
    std = features.std(axis=0)
    features = (features - mean) / std

    clf = SGDClassifier(alpha=0.001, max_iter=100).fit(features, targets)

    # The classifier was fitted on standardized data, so the raw query
    # point has to go through the identical transformation.
    query = (np.array([[x, y]], dtype=float) - mean) / std
    predicted_class = clf.predict(query)[0]
    return iris.target_names[predicted_class]

def decision_boundary(x_min, x_max, y_min, y_max):
    """Plot the SGD classifier's decision regions for two iris features.

    The classifier is trained on the first two columns of the iris dataset
    after standardizing them, so the axes of the returned plot (and hence
    the limits passed in) are in standardized units, not the dataset's
    original units.

    Args:
        x_min: Lower limit of the horizontal axis.
        x_max: Upper limit of the horizontal axis.
        y_min: Lower limit of the vertical axis.
        y_max: Upper limit of the vertical axis.

    Returns:
        A new 5x4 inch Matplotlib figure containing the decision regions.
    """
    # Local import: a Figure built directly is not registered with pyplot's
    # global state, so calls neither draw over each other nor pile up
    # open figures.
    from matplotlib.figure import Figure

    iris = datasets.load_iris()
    features = iris.data[:, :2]
    targets = iris.target

    # Shuffle the samples with a fixed seed so every call trains on the
    # same ordering.
    idx = np.arange(features.shape[0])
    np.random.seed(13)
    np.random.shuffle(idx)
    features = features[idx]
    targets = targets[idx]

    # Standardize the features before fitting.
    mean = features.mean(axis=0)
    std = features.std(axis=0)
    features = (features - mean) / std

    clf = SGDClassifier(alpha=0.001, max_iter=100).fit(features, targets)

    fig = Figure(figsize=(5, 4))
    ax = fig.subplots()
    DecisionBoundaryDisplay.from_estimator(
        clf,
        features,
        cmap=plt.cm.Paired,
        ax=ax,
        response_method="predict",
        xlabel=iris.feature_names[0],
        ylabel=iris.feature_names[1],
    )
    ax.axis([x_min, x_max, y_min, y_max])
    ax.tick_params(axis="both", labelsize=8)
    return fig

# Loaded at module level so the slider labels can reuse the dataset's
# feature names.
iris = datasets.load_iris()

# One slider per feature fed to predict_class; step=0.1 gives one decimal
# place of resolution.
inputs = [
    gr.inputs.Slider(0, 8, step=0.1, default=5.8, label=iris.feature_names[0]),
    gr.inputs.Slider(0, 8, step=0.1, default=3.5, label=iris.feature_names[1]),
]

output = gr.outputs.Label(num_top_classes=1)

title = "Iris Dataset - Decision Boundary"
description = "Predict the class of the given data point and show the decision boundary of the SGD classifier."
article = "<p><a href='https://scikit-learn.org/stable/auto_examples/linear_model/plot_sgd_iris.html'>More about the dataset and the example</a></p>"

# Each example is a [first feature, second feature] pair matching `inputs`.
examples = [
    [5.8, 3.5],
    [7.2, 3.2],
    [5.1, 2.5],
    [4.9, 3.1],
]

# NOTE(review): decision_boundary is not wired into this interface. It takes
# four axis limits, which do not correspond to the two sliders above, so it
# needs its own inputs and a plot output before it can be exposed here.
gr.Interface(
    predict_class,
    inputs,
    output,
    title=title,
    description=description,
    examples=examples,
    article=article,
    layout="vertical",
    allow_flagging=False,
    live=True,
).launch()