Upload 15 files
- app.py +37 -0
- data/iris.csv +151 -0
- data/model.pkl +3 -0
- images/feature01.png +0 -0
- images/setosa.jpg +0 -0
- images/setosa.webp +0 -0
- images/versicolor.jpg +0 -0
- images/versicolor.webp +0 -0
- images/virginica.jpg +0 -0
- images/virginica.webp +0 -0
- images/下载 (1).jpeg +0 -0
- pages/001 data_intro.py +33 -0
- pages/002 data_feature.py +56 -0
- pages/003 model train.py +124 -0
- pages/004 model_sample.py +59 -0
app.py
ADDED
@@ -0,0 +1,37 @@
### Introduction
import simplestart as ss

ss.md('''
## Iris Dataset
The Iris dataset is a commonly used classification dataset compiled by Fisher in 1936. Also known as the Iris flower dataset, it is a multivariate dataset containing 150 samples divided into 3 classes of 50 samples each, with 4 attributes per sample. The four attributes (sepal length, sepal width, petal length, and petal width) can be used to predict which of the three species (Setosa, Versicolor, Virginica) an Iris flower belongs to.

### Iris Flower
Iris flowers have a rich cultural background. They are named for their petals, which resemble a bird's tail, and the genus name "Iris" comes from the Greek word for rainbow, a nod to the variety of flower colors.
''')


with ss.row(style="margin:10px 0"):
    with ss.col():
        ss.image("./images/setosa.webp", title="Silky Iris Setosa", elevation=10, width=250)

    with ss.col():
        ss.image("./images/versicolor.webp", title="Iris Versicolor", elevation=10, width=250)

    with ss.col():
        ss.image("./images/virginica.webp", title="Virginia Iris Virginica", elevation=10, width=250)

ss.md('''
### Machine Learning

This tutorial uses the scikit-learn library to build a machine learning model that classifies Iris flowers by species. The model is trained and tested on the flower measurements (petal and sepal lengths and widths), applying several classic machine learning algorithms so that it can make accurate species predictions for new Iris flowers.

''')

ss.md('''
###
References for this example:
---
[1. Basics of Machine Learning: 1.7 Iris Classification](https://blog.csdn.net/qq_47809408/article/details/124632290)
[2. Introduction to the KNN Classification Algorithm: Classifying the Iris Dataset](https://blog.csdn.net/weixin_51756038/article/details/130096706)
[3. Interactive Web App with Streamlit and Scikit-learn](https://github.com/patrickloeber/streamlit-demo)
''')
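The "Machine Learning" section above only describes the workflow; the implementation is spread over the pages below. As a condensed sketch of that workflow in plain scikit-learn (no SimpleStart UI; the KNN settings mirror pages/003 model train.py, everything else here is illustrative):

```python
# Minimal sketch of the workflow described above: load the iris data,
# split it, fit a KNN classifier, and report test accuracy.
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

clf = KNeighborsClassifier(n_neighbors=3)  # same settings as pages/003
clf.fit(X_train, y_train)
print("test accuracy:", accuracy_score(y_test, clf.predict(X_test)))
```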
data/iris.csv
ADDED
@@ -0,0 +1,151 @@
sepal_length,sepal_width,petal_length,petal_width,species
5.1,3.5,1.4,0.2,setosa
4.9,3.0,1.4,0.2,setosa
4.7,3.2,1.3,0.2,setosa
4.6,3.1,1.5,0.2,setosa
5.0,3.6,1.4,0.2,setosa
5.4,3.9,1.7,0.4,setosa
4.6,3.4,1.4,0.3,setosa
5.0,3.4,1.5,0.2,setosa
4.4,2.9,1.4,0.2,setosa
4.9,3.1,1.5,0.1,setosa
5.4,3.7,1.5,0.2,setosa
4.8,3.4,1.6,0.2,setosa
4.8,3.0,1.4,0.1,setosa
4.3,3.0,1.1,0.1,setosa
5.8,4.0,1.2,0.2,setosa
5.7,4.4,1.5,0.4,setosa
5.4,3.9,1.3,0.4,setosa
5.1,3.5,1.4,0.3,setosa
5.7,3.8,1.7,0.3,setosa
5.1,3.8,1.5,0.3,setosa
5.4,3.4,1.7,0.2,setosa
5.1,3.7,1.5,0.4,setosa
4.6,3.6,1.0,0.2,setosa
5.1,3.3,1.7,0.5,setosa
4.8,3.4,1.9,0.2,setosa
5.0,3.0,1.6,0.2,setosa
5.0,3.4,1.6,0.4,setosa
5.2,3.5,1.5,0.2,setosa
5.2,3.4,1.4,0.2,setosa
4.7,3.2,1.6,0.2,setosa
4.8,3.1,1.6,0.2,setosa
5.4,3.4,1.5,0.4,setosa
5.2,4.1,1.5,0.1,setosa
5.5,4.2,1.4,0.2,setosa
4.9,3.1,1.5,0.2,setosa
5.0,3.2,1.2,0.2,setosa
5.5,3.5,1.3,0.2,setosa
4.9,3.6,1.4,0.1,setosa
4.4,3.0,1.3,0.2,setosa
5.1,3.4,1.5,0.2,setosa
5.0,3.5,1.3,0.3,setosa
4.5,2.3,1.3,0.3,setosa
4.4,3.2,1.3,0.2,setosa
5.0,3.5,1.6,0.6,setosa
5.1,3.8,1.9,0.4,setosa
4.8,3.0,1.4,0.3,setosa
5.1,3.8,1.6,0.2,setosa
4.6,3.2,1.4,0.2,setosa
5.3,3.7,1.5,0.2,setosa
5.0,3.3,1.4,0.2,setosa
7.0,3.2,4.7,1.4,versicolor
6.4,3.2,4.5,1.5,versicolor
6.9,3.1,4.9,1.5,versicolor
5.5,2.3,4.0,1.3,versicolor
6.5,2.8,4.6,1.5,versicolor
5.7,2.8,4.5,1.3,versicolor
6.3,3.3,4.7,1.6,versicolor
4.9,2.4,3.3,1.0,versicolor
6.6,2.9,4.6,1.3,versicolor
5.2,2.7,3.9,1.4,versicolor
5.0,2.0,3.5,1.0,versicolor
5.9,3.0,4.2,1.5,versicolor
6.0,2.2,4.0,1.0,versicolor
6.1,2.9,4.7,1.4,versicolor
5.6,2.9,3.6,1.3,versicolor
6.7,3.1,4.4,1.4,versicolor
5.6,3.0,4.5,1.5,versicolor
5.8,2.7,4.1,1.0,versicolor
6.2,2.2,4.5,1.5,versicolor
5.6,2.5,3.9,1.1,versicolor
5.9,3.2,4.8,1.8,versicolor
6.1,2.8,4.0,1.3,versicolor
6.3,2.5,4.9,1.5,versicolor
6.1,2.8,4.7,1.2,versicolor
6.4,2.9,4.3,1.3,versicolor
6.6,3.0,4.4,1.4,versicolor
6.8,2.8,4.8,1.4,versicolor
6.7,3.0,5.0,1.7,versicolor
6.0,2.9,4.5,1.5,versicolor
5.7,2.6,3.5,1.0,versicolor
5.5,2.4,3.8,1.1,versicolor
5.5,2.4,3.7,1.0,versicolor
5.8,2.7,3.9,1.2,versicolor
6.0,2.7,5.1,1.6,versicolor
5.4,3.0,4.5,1.5,versicolor
6.0,3.4,4.5,1.6,versicolor
6.7,3.1,4.7,1.5,versicolor
6.3,2.3,4.4,1.3,versicolor
5.6,3.0,4.1,1.3,versicolor
5.5,2.5,4.0,1.3,versicolor
5.5,2.6,4.4,1.2,versicolor
6.1,3.0,4.6,1.4,versicolor
5.8,2.6,4.0,1.2,versicolor
5.0,2.3,3.3,1.0,versicolor
5.6,2.7,4.2,1.3,versicolor
5.7,3.0,4.2,1.2,versicolor
5.7,2.9,4.2,1.3,versicolor
6.2,2.9,4.3,1.3,versicolor
5.1,2.5,3.0,1.1,versicolor
5.7,2.8,4.1,1.3,versicolor
6.3,3.3,6.0,2.5,virginica
5.8,2.7,5.1,1.9,virginica
7.1,3.0,5.9,2.1,virginica
6.3,2.9,5.6,1.8,virginica
6.5,3.0,5.8,2.2,virginica
7.6,3.0,6.6,2.1,virginica
4.9,2.5,4.5,1.7,virginica
7.3,2.9,6.3,1.8,virginica
6.7,2.5,5.8,1.8,virginica
7.2,3.6,6.1,2.5,virginica
6.5,3.2,5.1,2.0,virginica
6.4,2.7,5.3,1.9,virginica
6.8,3.0,5.5,2.1,virginica
5.7,2.5,5.0,2.0,virginica
5.8,2.8,5.1,2.4,virginica
6.4,3.2,5.3,2.3,virginica
6.5,3.0,5.5,1.8,virginica
7.7,3.8,6.7,2.2,virginica
7.7,2.6,6.9,2.3,virginica
6.0,2.2,5.0,1.5,virginica
6.9,3.2,5.7,2.3,virginica
5.6,2.8,4.9,2.0,virginica
7.7,2.8,6.7,2.0,virginica
6.3,2.7,4.9,1.8,virginica
6.7,3.3,5.7,2.1,virginica
7.2,3.2,6.0,1.8,virginica
6.2,2.8,4.8,1.8,virginica
6.1,3.0,4.9,1.8,virginica
6.4,2.8,5.6,2.1,virginica
7.2,3.0,5.8,1.6,virginica
7.4,2.8,6.1,1.9,virginica
7.9,3.8,6.4,2.0,virginica
6.4,2.8,5.6,2.2,virginica
6.3,2.8,5.1,1.5,virginica
6.1,2.6,5.6,1.4,virginica
7.7,3.0,6.1,2.3,virginica
6.3,3.4,5.6,2.4,virginica
6.4,3.1,5.5,1.8,virginica
6.0,3.0,4.8,1.8,virginica
6.9,3.1,5.4,2.1,virginica
6.7,3.1,5.6,2.4,virginica
6.9,3.1,5.1,2.3,virginica
5.8,2.7,5.1,1.9,virginica
6.8,3.2,5.9,2.3,virginica
6.7,3.3,5.7,2.5,virginica
6.7,3.0,5.2,2.3,virginica
6.3,2.5,5.0,1.9,virginica
6.5,3.0,5.2,2.0,virginica
6.2,3.4,5.4,2.3,virginica
5.9,3.0,5.1,1.8,virginica
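A quick way to confirm the file matches the description in app.py (150 samples, 50 per species, 4 numeric attributes); a minimal pandas check, assuming it is run from the repository root:

```python
import pandas as pd

df = pd.read_csv("./data/iris.csv")
print(df.shape)                      # expected: (150, 5)
print(df["species"].value_counts())  # expected: 50 each for setosa, versicolor, virginica
```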
data/model.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e37d94112cf51c382e9639ba7d9aab490170f2aa684455aa589c0c097dbdb929
size 912
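The pickle itself lives in Git LFS, so the commit only records the pointer above and the estimator inside is not visible here. pages/004 assumes the object exposes predict_proba over the four iris features; after `git lfs pull`, a minimal sketch for checking that assumption (attribute names assume a reasonably recent scikit-learn):

```python
import pickle

# Inspect what data/model.pkl actually contains (requires the LFS object locally).
with open("./data/model.pkl", "rb") as f:
    model = pickle.load(f)

print(type(model))           # which scikit-learn estimator was pickled
print(model.n_features_in_)  # expected: 4 (sepal/petal length and width)
print(model.classes_)        # class order used by predict_proba
```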
images/feature01.png
ADDED
images/setosa.jpg
ADDED
images/setosa.webp
ADDED
images/versicolor.jpg
ADDED
images/versicolor.webp
ADDED
images/virginica.jpg
ADDED
images/virginica.webp
ADDED
images/下载 (1).jpeg
ADDED
pages/001 data_intro.py
ADDED
@@ -0,0 +1,33 @@
### Data Exploration
import simplestart as ss
import pandas as pd

ss.md('''
## Iris Dataset
The dataset contains 150 samples divided into 3 classes: Setosa, Versicolor, and Virginica. Each class has 50 samples, and each sample includes 4 attributes.
''')

ss.space()

title = "Table 1. Iris Dataset"
subtitle = "sepal_length: length of the sepal, sepal_width: width of the sepal, petal_length: length of the petal, petal_width: width of the petal"
# Set global float display precision
pd.options.display.float_format = '{:.2f}'.format
df = pd.read_csv("./data/iris.csv")

ss.table(df, index=True, title=title, subtitle=subtitle, width=400)

ss.table(df.describe(), index=True)

ss.md("---")
# Simulated data
import numpy as np
# Set random seed for reproducibility
np.random.seed(0)

num_rows = 10000
data = {
    'Column1': np.random.randint(0, 100, size=num_rows),               # Random integers
    'Column2': np.random.random(size=num_rows),                        # Random floats
    'Column3': np.random.choice(['A', 'B', 'C', 'D'], size=num_rows),  # Randomly chosen categories
}
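The simulated-data block at the end of this page builds a `data` dict but never renders it in the committed version. If the intent was to also preview a large synthetic table, a minimal continuation could look like the sketch below; it reuses `pd`, `ss`, and `data` from the page above, and the table title is purely illustrative:

```python
# Hypothetical continuation: wrap the simulated columns in a DataFrame and
# show only the first rows, since num_rows is 10000.
df_sim = pd.DataFrame(data)
ss.table(df_sim.head(20), index=True, title="Simulated Data (first 20 rows)")
```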
pages/002 data_feature.py
ADDED
@@ -0,0 +1,56 @@
### Feature Analysis

import simplestart as ss
import pandas as pd

ss.md('''
## Feature Analysis
''')

ss.space()

ss.md("#### 1. Scatter Matrix of Features")
ss.space()
ss.image("./images/feature01.png", width=600, height=500)

ss.space()

ss.md('''
This image is from:
[VuNus 【Basics of Machine Learning】1.7 Iris Flower Classification](https://blog.csdn.net/qq_47809408/article/details/124632290)
''')

ss.space()
ss.md("#### 2. Feature Exploration")
import pandas as pd
from bokeh.plotting import figure, show
from bokeh.models import ColumnDataSource
from bokeh.transform import factor_cmap
from bokeh.embed import file_html
from bokeh.resources import CDN
from bokeh.palettes import Category10

# Load dataset
data = pd.read_csv("./data/iris.csv")

# Create Bokeh chart
p = figure(title="Iris Dataset Scatter Plot", x_axis_label='Petal Length (cm)', y_axis_label='Petal Width (cm)',
           tools="pan,wheel_zoom,box_zoom,reset,hover,save", width=800, height=600)

# Create data source
source = ColumnDataSource(data)

# Set color mapping for the species column
species_list = data['species'].unique().tolist()
p.circle(x='petal_length', y='petal_width', source=source, size=10,
         color=factor_cmap('species', palette=Category10[3], factors=species_list), legend_field='species')

# Configure legend
p.legend.title = "Species"
p.legend.location = "top_left"

# Convert the Bokeh chart to HTML and display it
html_output = file_html(p, CDN, "Iris Dataset Scatter Plot")
# show(p)

ss.htmlview(html_output)
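The scatter matrix in section 1 is shipped as a static PNG taken from the linked blog post. If you prefer to regenerate a similar figure locally from data/iris.csv, here is a sketch using pandas and matplotlib (an alternative, not what this page does; the output filename is arbitrary):

```python
import pandas as pd
import matplotlib.pyplot as plt
from pandas.plotting import scatter_matrix

iris = pd.read_csv("./data/iris.csv")
# One color per species for the off-diagonal scatter panels.
colors = iris["species"].map({"setosa": "tab:blue", "versicolor": "tab:orange", "virginica": "tab:green"})

scatter_matrix(iris.drop(columns="species"), c=colors, figsize=(8, 8), diagonal="hist")
plt.savefig("./images/feature01_local.png", dpi=150)
```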
pages/003 model train.py
ADDED
@@ -0,0 +1,124 @@
### Model Training

import simplestart as ss

from sklearn import datasets
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score


ss.md('''
## Model Training
''')

# Load data and split samples
data = datasets.load_iris()
X = data.data
y = data.target

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

# Page session variables
ss.session["acc"] = ""
ss.session["code"] = 0

# Training function
def train(event):
    clf = KNeighborsClassifier(n_neighbors=3)
    clf.fit(X_train, y_train)

    y_pred = clf.predict(X_test)

    acc = accuracy_score(y_test, y_pred)
    acc = round(acc, 2)

    ss.session["acc"] = acc  # Assign the result to the page session variable; the corresponding page display updates automatically

ss.md('''
#### Main Steps of Model Training:
First, the Iris dataset is loaded (features and labels) and split into training and test sets, with 80% used for training and 20% for testing. A training function then fits a K-Nearest Neighbors (KNN) classifier, evaluates its predictive accuracy, and stores the result in a page session variable for display on the page.
###
The page provides a training button. When the user clicks it, the training function runs in the background and computes the predictive accuracy on the test set. Once training completes, the accuracy is updated on the page and shown to the user as "Accuracy = @acc", where @acc is the accuracy value computed during training.
###
Training and testing are particularly fast because the Iris dataset is very small (only 150 samples and 4 features) and K-Nearest Neighbors (KNN) is a simple, efficient algorithm that works especially well on small datasets.
###
---
''')
ss.write(f'Test set predictive accuracy Accuracy =', "@acc")

ss.button("Train", onclick=train)
# UI

ss.md("---")

def conditioner(event):
    return ss.session["code"] == 1

def checkcode(event):
    ss.session["code"] = 1

def hidecode(event):
    ss.session["code"] = 0

ss.button("View Code", onclick=checkcode)
ss.button("Hide Code", onclick=hidecode)

with ss.when(conditioner):
    ss.md('''
```python
import simplestart as ss

from sklearn import datasets
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# Load data and split samples
data = datasets.load_iris()
X = data.data
y = data.target
ss.write(X.shape, y.shape)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

# Page session variable
ss.session["acc"] = ""

# Training function
def train(event):
    clf = KNeighborsClassifier(n_neighbors=3)
    clf.fit(X_train, y_train)

    y_pred = clf.predict(X_test)

    acc = accuracy_score(y_test, y_pred)
    acc = round(acc, 2)

    ss.session["acc"] = acc  # Assign the result to the page session variable; the corresponding page display updates automatically

# Display the model's accuracy on the test set
ss.write(f'Test set predictive accuracy Accuracy =', "\@acc")

ss.button("Train", onclick=train)

```
    ''')


ss.md("---")


ss.md('''
::: tip
### Advantages of KNN:
Simple, easy to understand and implement, requires no parameter estimation and no training phase;
Suitable for classifying rare events;
Particularly suitable for multi-class problems (multi-label, objects with multiple category labels).
:::
''')

ss.md('''
For more information on KNN, please refer to
[KNN Classification Algorithm Introduction: Classifying the Iris Dataset with KNN (iris)](https://blog.csdn.net/weixin_51756038/article/details/130096706)
''')
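This page trains the KNN classifier in memory on each button click but never persists it, while pages/004 loads a pre-trained model from data/model.pkl. Whether the shipped pickle was produced this way is not recorded in the commit; if you wanted the two pages to share the same estimator, a hedged sketch of a training handler that also saves the fitted model:

```python
import pickle

def train_and_save(event):
    # Same training step as train() above, reusing X_train/X_test/y_train/y_test.
    clf = KNeighborsClassifier(n_neighbors=3)
    clf.fit(X_train, y_train)
    ss.session["acc"] = round(accuracy_score(y_test, clf.predict(X_test)), 2)

    # Persist the fitted model so pages/004 can load it and call predict_proba.
    with open("./data/model.pkl", "wb") as f:
        pickle.dump(clf, f)
```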
pages/004 model_sample.py
ADDED
@@ -0,0 +1,59 @@
### Prediction Instance
# The original source code:
# https://github.com/AzeemWaqarRao/Streamlit-Iris-Classification-App
import simplestart as ss

from sklearn.datasets import load_iris
import pandas as pd
import pickle
import numpy as np

# Data and API
species = ['setosa', 'versicolor', 'virginica']
image = ['./images/setosa.jpg', './images/versicolor.jpg', './images/virginica.jpg']
with open('./data/model.pkl', 'rb') as f:
    model = pickle.load(f)

def slidechange(event):
    predict()

def predict():
    # Get a prediction from the model
    inp = np.array([sepal_length.value, sepal_width.value, petal_length.value, petal_width.value])
    inp = np.expand_dims(inp, axis=0)
    prediction = model.predict_proba(inp)

    ## Show the results once the prediction is done
    if True:
        df = pd.DataFrame(prediction, index=['result'], columns=species).round(4)
        table_result.data = df
        ss.session["result"] = species[np.argmax(prediction)]
        image_flower.image = image[np.argmax(prediction)]

# UI
with ss.sidebar():
    ss.write("### Inputs")

    sepal_length = ss.slider("sepal length (cm)", 4.3, 7.9, 5.0, onchange=slidechange)
    sepal_width = ss.slider("sepal width (cm)", 2.0, 4.4, 3.6, onchange=slidechange)
    petal_length = ss.slider("petal length (cm)", 1.0, 6.9, 1.4, onchange=slidechange)
    petal_width = ss.slider("petal width (cm)", 0.1, 2.5, 0.2, onchange=slidechange)

ss.write("## Iris Flower Classification Prediction")
ss.write("Change the sepal and petal length and width to predict among the 3 possible categories.")

ss.write('''
# Results
The following is the probability of each class
''')

ss.space()

table_result = ss.table(show_border=True)
ss.write("**This flower belongs to the @result" + " class**")

ss.space()

image_flower = ss.image(image[0])

predict()
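The page calls predict() once at load time with the slider defaults (5.0, 3.6, 1.4, 0.2). To reproduce that first prediction outside the UI, a minimal sketch (requires the LFS copy of data/model.pkl; the printed values depend on whatever estimator is inside the pickle):

```python
import pickle
import numpy as np
import pandas as pd

with open("./data/model.pkl", "rb") as f:
    model = pickle.load(f)

species = ["setosa", "versicolor", "virginica"]
inp = np.array([[5.0, 3.6, 1.4, 0.2]])  # sidebar slider defaults

proba = model.predict_proba(inp)
print(pd.DataFrame(proba, index=["result"], columns=species).round(4))
print("predicted:", species[np.argmax(proba)])
```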