Spaces:
Runtime error
Runtime error
File size: 2,479 Bytes
a63b3d6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.linear_model import SGDClassifier
from sklearn.inspection import DecisionBoundaryDisplay
def predict_class(x, y):
    """Classify one iris sample from its two sepal measurements.

    An SGDClassifier is trained on the first two iris features after a
    seeded shuffle (seed 13) and z-score standardization. The query point is
    standardized with the same mean/std before prediction. This is the same
    training procedure that ``decision_boundary`` uses.

    Args:
        x: Value of the first iris feature (``iris.feature_names[0]``).
        y: Value of the second iris feature (``iris.feature_names[1]``).

    Returns:
        The predicted class name, taken from ``iris.target_names``.
    """
    iris = datasets.load_iris()
    features = iris.data[:, :2]
    # Named ``targets`` (not ``y``) so the caller's ``y`` argument is not
    # overwritten by the label array.
    targets = iris.target

    # Seed the global NumPy RNG before shuffling. With random_state=None,
    # SGDClassifier also draws from that global RNG, so training is
    # reproducible across calls.
    order = np.arange(features.shape[0])
    np.random.seed(13)
    np.random.shuffle(order)
    features = features[order]
    targets = targets[order]

    mean = features.mean(axis=0)
    std = features.std(axis=0)
    clf = SGDClassifier(alpha=0.001, max_iter=100).fit(
        (features - mean) / std, targets
    )

    # The model was fit on standardized features, so the raw query must be
    # transformed with the same statistics before predicting.
    query = (np.array([[x, y]], dtype=float) - mean) / std
    predicted_class = clf.predict(query)[0]
    return iris.target_names[predicted_class]
def decision_boundary(x_min, x_max, y_min, y_max):
    """Plot the SGD classifier's decision regions over the iris training data.

    The classifier is trained on the first two iris features after a seeded
    shuffle (seed 13) and z-score standardization, the same procedure that
    ``predict_class`` uses. The plot is drawn in the standardized
    coordinates, so the axis limits are interpreted in those (z-score) units.

    Args:
        x_min: Lower limit of the horizontal (first feature) axis.
        x_max: Upper limit of the horizontal (first feature) axis.
        y_min: Lower limit of the vertical (second feature) axis.
        y_max: Upper limit of the vertical (second feature) axis.

    Returns:
        A new 5 x 4 inch ``matplotlib.figure.Figure`` on every call.
    """
    from matplotlib.figure import Figure

    iris = datasets.load_iris()
    features = iris.data[:, :2]
    targets = iris.target
    colors = "bry"

    # Seeded shuffle of the global NumPy RNG. SGDClassifier (random_state=None)
    # also draws from it, so the fitted model is reproducible.
    order = np.arange(features.shape[0])
    np.random.seed(13)
    np.random.shuffle(order)
    features = features[order]
    targets = targets[order]
    features = (features - features.mean(axis=0)) / features.std(axis=0)
    clf = SGDClassifier(alpha=0.001, max_iter=100).fit(features, targets)

    # Build a standalone Figure instead of drawing on pyplot's implicit
    # "current" axes. Otherwise repeated calls would stack contours onto one
    # shared global figure, and pyplot's global state is not thread-safe.
    fig = Figure(figsize=(5, 4))
    ax = fig.add_subplot()
    DecisionBoundaryDisplay.from_estimator(
        clf,
        features,
        cmap=plt.cm.Paired,
        ax=ax,
        response_method="predict",
        xlabel=iris.feature_names[0],
        ylabel=iris.feature_names[1],
    )

    # Overlay the (standardized) training points, one colour per class.
    for label, color in zip(clf.classes_, colors):
        mask = targets == label
        ax.scatter(
            features[mask, 0],
            features[mask, 1],
            c=color,
            label=iris.target_names[label],
            edgecolor="black",
            s=20,
        )
    ax.legend(fontsize=8)

    ax.axis([x_min, x_max, y_min, y_max])
    ax.tick_params(labelsize=8)
    return fig
# Module-level copy of the dataset; used here only for the slider labels.
iris = datasets.load_iris()

# decision_boundary plots standardized (z-score) features, so its axis limits
# are z-scores, not raw feature units. These bounds are chosen to frame the
# standardized sepal data; widen them if the plot looks cropped.
PLOT_LIMITS = (-2.5, 3.0, -3.0, 3.5)

# The decision boundary does not depend on the slider values, so render it
# once instead of retraining and redrawing on every live update.
boundary_figure = decision_boundary(*PLOT_LIMITS)


def predict_and_plot(sepal_length, sepal_width):
    """Return the predicted class label and the decision-boundary figure.

    Args:
        sepal_length: Value of the first slider (``iris.feature_names[0]``).
        sepal_width: Value of the second slider (``iris.feature_names[1]``).

    Returns:
        A ``(label, figure)`` pair matching the ``gr.Label`` and ``gr.Plot``
        outputs, in that order.
    """
    return predict_class(sepal_length, sepal_width), boundary_figure


inputs = [
    gr.Slider(0, 8, value=5.8, step=0.1, label=iris.feature_names[0]),
    gr.Slider(0, 8, value=3.5, step=0.1, label=iris.feature_names[1]),
]
outputs = [gr.Label(num_top_classes=1), gr.Plot()]
title = "Iris Dataset - Decision Boundary"
description = "Predict the class of the given data point and show the decision boundary of the SGD classifier."
article = "<p><a href='https://scikit-learn.org/stable/auto_examples/linear_model/plot_sgd_iris.html'>More about the dataset and the example</a></p>"
examples = [
    [5.8, 3.5],
    [7.2, 3.2],
    [5.1, 2.5],
    [4.9, 3.1],
]

demo = gr.Interface(
    fn=predict_and_plot,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    article=article,
    examples=examples,
    allow_flagging="never",
    live=True,
)

if __name__ == "__main__":
    demo.launch()
|