Spaces:
Sleeping
Sleeping
File size: 7,042 Bytes
4e1e636 1f22a6d 4e1e636 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 |
import plotly.express as px
import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.utils import check_random_state
import gradio as gr
# Load data from https://www.openml.org/d/554
X, y = fetch_openml(
"mnist_784", version=1, return_X_y=True, as_frame=False, parser="pandas"
)
print("Data loaded")
random_state = check_random_state(0)
permutation = random_state.permutation(X.shape[0])
X = X[permutation]
y = y[permutation]
X = X.reshape((X.shape[0], -1))
scaler = StandardScaler()
def dataset_display(digit, count_per_digit, binary_image):
if digit not in range(10):
# return a figure displaying an error message
return px.imshow(
np.zeros((28, 28)),
labels=dict(x="Pixel columns", y="Pixel rows"),
title=f"Digit {digit} is not in the data",
)
binary_value = True if binary_image == 1 else False
digit_idxs = np.where(y == str(digit))[0]
random_idxs = np.random.choice(digit_idxs, size=count_per_digit, replace=False)
fig = px.imshow(
np.array([X[i].reshape(28, 28) for i in random_idxs]),
labels=dict(x="Pixel columns", y="Pixel rows"),
title=f"Examples of Digit {digit} in Data",
facet_col=0,
facet_col_wrap=5,
binary_string=binary_value,
)
return fig
def predict(img):
try:
img = img.reshape(1, -1)
except:
return "Show Your Drawing Skills"
try:
img = scaler.transform(img)
prediction = clf.predict(img)
return prediction[0]
except:
return "Train the model first"
def train_model(train_sample=5000, c=0.1, tol=0.1, solver="sage", penalty="l1"):
X_train, X_test, y_train, y_test = train_test_split(
X, y, train_size=train_sample, test_size=10000
)
penalty_dict = {
"l2": ["lbfgs", "newton-cg", "newton-cholesky", "sag", "saga"],
"l1": ["liblinear", "saga"],
"elasticnet": ["saga"],
}
if solver not in penalty_dict[penalty]:
return (
"Solver not supported for the selected penalty",
"Change the Combination",
None,
)
global clf
global scaler
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
clf = LogisticRegression(C=c, penalty=penalty, solver=solver, tol=tol)
clf.fit(X_train, y_train)
sparsity = np.mean(clf.coef_ == 0) * 100
score = clf.score(X_test, y_test)
coef = clf.coef_.copy()
scale = np.abs(coef).max()
fig = px.imshow(
np.array([coef[i].reshape(28, 28) for i in range(10)]),
labels=dict(x="Pixel columns", y="Pixel rows"),
title=f"Classification vector for each digit",
range_color=[-scale, scale],
facet_col=0,
facet_col_wrap=5,
facet_col_spacing=0.01,
color_continuous_scale="RdBu",
zmin=-scale,
zmax=scale,
)
return score, sparsity, fig
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("# Phân loại dữ liệu MNIST bằng mô hình logistic đa thức và chính quy hóa L1")
gr.Markdown(
"""Mục tiêu chính của bản demo này là giới thiệu cách sử dụng hồi quy logistic trong việc phân loại các chữ số viết tay từ tập dữ liệu [MNIST](https://en.wikipedia.org/wiki/MNIST_database), một tập dữ liệu điểm chuẩn nổi tiếng trong máy tính tầm nhìn. Tập dữ liệu được tải từ [OpenML](https://www.openml.org/d/554), đây là một nền tảng mở dành cho nghiên cứu máy học giúp dễ dàng truy cập vào số lượng lớn tập dữ liệu.
Mô hình này được đào tạo bằng thư viện scikit-learn, thư viện này cung cấp nhiều công cụ cho máy học, bao gồm các thuật toán phân loại, hồi quy và phân cụm, cũng như các công cụ tiền xử lý dữ liệu và đánh giá mô hình. Bản demo tính toán điểm số và số liệu thưa thớt bằng cách sử dụng dữ liệu thử nghiệm, cung cấp thông tin chi tiết tương ứng về hiệu suất và độ thưa thớt của mô hình. Số liệu điểm cho biết mô hình đang hoạt động tốt như thế nào, trong khi số liệu thưa thớt cung cấp thông tin về số hệ số khác 0 trong mô hình, có thể hữu ích cho việc diễn giải mô hình và giảm độ phức tạp của nó.
"""
)
with gr.Tab("Khám phá dữ liệu"):
gr.Markdown("## ")
with gr.Row():
digit = gr.Slider(0, 9, label="Lựa chọn số", value=5, step=1)
count_per_digit = gr.Slider(
1, 10, label="Số lượng ảnh", value=10, step=1
)
binary_image = gr.Slider(0, 1, label="Phân loại ảnh nhị phân", value=0, step=1)
gen_btn = gr.Button("Hiển thị")
gen_btn.click(
dataset_display,
inputs=[digit, count_per_digit, binary_image],
outputs=gr.Plot(),
)
with gr.Tab("Huấn luyện mô hình"):
gr.Markdown("# Thay đổi các tham số để xem mô hình thay đổi như thế nào")
gr.Markdown("## Solver and penalty")
gr.Markdown(
"""
Penalty | Solver
-------|---------------
l1 | saga
l2 | saga
"""
)
with gr.Row():
train_sample = gr.Slider(
1000, 60000, label="Số lượng dữ liệu huấn luyện", value=5000, step=1
)
c = gr.Slider(0.1, 1, label="C", value=0.1, step=0.1)
tol = gr.Slider(
0.1, 1, label="Dung sai cho tiêu chí dừng.", value=0.1, step=0.1
)
max_iter = gr.Slider(100, 1000, label="Số vòng huấn luyện", value=100, step=1)
penalty = gr.Dropdown(
["l1", "l2",], label="Chính quy hóa", value="l1"
)
solver = gr.Dropdown(
["saga"],
label="Thuật toán",
value="saga",
)
train_btn = gr.Button("Huấn luyện")
train_btn.click(
train_model,
inputs=[train_sample, c, tol, solver, penalty],
outputs=[
gr.Textbox(label="Độ chính xác"),
gr.Textbox(label="Độ thưa thớt"),
gr.Plot(),
],
)
with gr.Tab("Dự đoán số mới"):
gr.Markdown("## Draw a digit and see the model's prediction")
inputs = gr.Sketchpad(brush_radius=1.0)
outputs = gr.Textbox(label="Kết quả dự đoán", lines=1)
skecth_btn = gr.Button("Dự đoán ảnh vẽ tay")
skecth_btn.click(predict, inputs, outputs)
demo.launch() |