File size: 7,042 Bytes
4e1e636
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f22a6d
4e1e636
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
import plotly.express as px
import numpy as np

from sklearn.datasets import fetch_openml
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.utils import check_random_state
import gradio as gr


# Load data from https://www.openml.org/d/554
X, y = fetch_openml(
    "mnist_784", version=1, return_X_y=True, as_frame=False, parser="pandas"
)

print("Data loaded")
random_state = check_random_state(0)
permutation = random_state.permutation(X.shape[0])
X = X[permutation]
y = y[permutation]
X = X.reshape((X.shape[0], -1))


scaler = StandardScaler()


def dataset_display(digit, count_per_digit, binary_image):
    if digit not in range(10):
        # return a figure displaying an error message
        return px.imshow(
            np.zeros((28, 28)),
            labels=dict(x="Pixel columns", y="Pixel rows"),
            title=f"Digit {digit} is not in the data",
        )

    binary_value = True if binary_image == 1 else False
    digit_idxs = np.where(y == str(digit))[0]
    random_idxs = np.random.choice(digit_idxs, size=count_per_digit, replace=False)

    fig = px.imshow(
        np.array([X[i].reshape(28, 28) for i in random_idxs]),
        labels=dict(x="Pixel columns", y="Pixel rows"),
        title=f"Examples of Digit {digit} in Data",
        facet_col=0,
        facet_col_wrap=5,
        binary_string=binary_value,
    )

    return fig


def predict(img):
    try:
        img = img.reshape(1, -1)
    except:
        return "Show Your Drawing Skills"
    try:
        img = scaler.transform(img)
        prediction = clf.predict(img)
        return prediction[0]
    except:
        return "Train the model first"


def train_model(train_sample=5000, c=0.1, tol=0.1, solver="sage", penalty="l1"):
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, train_size=train_sample, test_size=10000
    )

    penalty_dict = {
        "l2": ["lbfgs", "newton-cg", "newton-cholesky", "sag", "saga"],
        "l1": ["liblinear", "saga"],
        "elasticnet": ["saga"],
    }

    if solver not in penalty_dict[penalty]:
        return (
            "Solver not supported for the selected penalty",
            "Change the Combination",
            None,
        )

    global clf
    global scaler
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)

    clf = LogisticRegression(C=c, penalty=penalty, solver=solver, tol=tol)
    clf.fit(X_train, y_train)
    sparsity = np.mean(clf.coef_ == 0) * 100
    score = clf.score(X_test, y_test)

    coef = clf.coef_.copy()
    scale = np.abs(coef).max()

    fig = px.imshow(
        np.array([coef[i].reshape(28, 28) for i in range(10)]),
        labels=dict(x="Pixel columns", y="Pixel rows"),
        title=f"Classification vector for each digit",
        range_color=[-scale, scale],
        facet_col=0,
        facet_col_wrap=5,
        facet_col_spacing=0.01,
        color_continuous_scale="RdBu",
        zmin=-scale,
        zmax=scale,
    )

    return score, sparsity, fig


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Phân loại dữ liệu MNIST bằng mô hình logistic đa thức và chính quy hóa L1")
    gr.Markdown(
        """Mục tiêu chính của bản demo này là giới thiệu cách sử dụng hồi quy logistic trong việc phân loại các chữ số viết tay từ tập dữ liệu [MNIST](https://en.wikipedia.org/wiki/MNIST_database), một tập dữ liệu điểm chuẩn nổi tiếng trong máy tính tầm nhìn. Tập dữ liệu được tải từ [OpenML](https://www.openml.org/d/554), đây là một nền tảng mở dành cho nghiên cứu máy học giúp dễ dàng truy cập vào số lượng lớn tập dữ liệu.
Mô hình này được đào tạo bằng thư viện scikit-learn, thư viện này cung cấp nhiều công cụ cho máy học, bao gồm các thuật toán phân loại, hồi quy và phân cụm, cũng như các công cụ tiền xử lý dữ liệu và đánh giá mô hình. Bản demo tính toán điểm số và số liệu thưa thớt bằng cách sử dụng dữ liệu thử nghiệm, cung cấp thông tin chi tiết tương ứng về hiệu suất và độ thưa thớt của mô hình. Số liệu điểm cho biết mô hình đang hoạt động tốt như thế nào, trong khi số liệu thưa thớt cung cấp thông tin về số hệ số khác 0 trong mô hình, có thể hữu ích cho việc diễn giải mô hình và giảm độ phức tạp của nó.
    """
    )

    with gr.Tab("Khám phá dữ liệu"):
        gr.Markdown("## ")
        with gr.Row():
            digit = gr.Slider(0, 9, label="Lựa chọn số", value=5, step=1)
            count_per_digit = gr.Slider(
                1, 10, label="Số lượng ảnh", value=10, step=1
            )
            binary_image = gr.Slider(0, 1, label="Phân loại ảnh nhị phân", value=0, step=1)

        gen_btn = gr.Button("Hiển thị")
        gen_btn.click(
            dataset_display,
            inputs=[digit, count_per_digit, binary_image],
            outputs=gr.Plot(),
        )

    with gr.Tab("Huấn luyện mô hình"):
        gr.Markdown("# Thay đổi các tham số để xem mô hình thay đổi như thế nào")

        gr.Markdown("## Solver and penalty")
        gr.Markdown(
            """
        Penalty | Solver
        -------|---------------
        l1 | saga
        l2  | saga
        """
        )

        with gr.Row():
            train_sample = gr.Slider(
                1000, 60000, label="Số lượng dữ liệu huấn luyện", value=5000, step=1
            )

            c = gr.Slider(0.1, 1, label="C", value=0.1, step=0.1)
            tol = gr.Slider(
                0.1, 1, label="Dung sai cho tiêu chí dừng.", value=0.1, step=0.1
            )
            max_iter = gr.Slider(100, 1000, label="Số vòng huấn luyện", value=100, step=1)

            penalty = gr.Dropdown(
                ["l1", "l2",], label="Chính quy hóa", value="l1"
            )
            solver = gr.Dropdown(
                ["saga"],
                label="Thuật toán",
                value="saga",
            )

        train_btn = gr.Button("Huấn luyện")
        train_btn.click(
            train_model,
            inputs=[train_sample, c, tol, solver, penalty],
            outputs=[
                gr.Textbox(label="Độ chính xác"),
                gr.Textbox(label="Độ thưa thớt"),
                gr.Plot(),
            ],
        )

    with gr.Tab("Dự đoán số mới"):
        gr.Markdown("## Draw a digit and see the model's prediction")
        inputs = gr.Sketchpad(brush_radius=1.0)
        outputs = gr.Textbox(label="Kết quả dự đoán", lines=1)
        skecth_btn = gr.Button("Dự đoán ảnh vẽ tay")
        skecth_btn.click(predict, inputs, outputs)


demo.launch()