File size: 6,773 Bytes
746d998
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
"""
A small Streamlit app that lets the user draw a digit on a canvas and get the digit predicted by a model trained on the MNIST dataset (either a Keras model or a pre-trained ONNX model, chosen in the UI).
"""

import streamlit as st
from PIL import Image
from streamlit_drawable_canvas import st_canvas
import os
import numpy as np
from keras import models
import keras.datasets.mnist as mnist
import matplotlib.pyplot as plt
import pandas as pd
import time
import onnx
import onnxruntime
from scipy.special import softmax


@st.cache_resource
def load_picture():
    """
    Show a 3x3 grid of the first 9 MNIST training images in Streamlit.

    The grid is read from ``img/show.png`` in this file's directory. That is the
    same base directory the Keras model paths are resolved against. If the PNG
    does not exist yet, it is generated once from the MNIST training set and
    saved there, so later runs skip loading the dataset entirely.
    """
    image_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "img", "show.png"
    )

    if not os.path.exists(image_path):
        # Only the training images are needed for the preview grid.
        (x_train, _), _ = mnist.load_data()

        # Use an explicit figure instead of pyplot's implicit global one so
        # nothing leaks into other plots, and close it once it is saved.
        fig = plt.figure()
        for i in range(9):
            ax = fig.add_subplot(3, 3, i + 1)
            ax.imshow(x_train[i] / 255.0, cmap="gray")

        os.makedirs(os.path.dirname(image_path), exist_ok=True)
        fig.savefig(image_path)
        plt.close(fig)

    st.image(image_path, width=250, caption="First 9 images from the MNIST dataset")


def keras_prediction(final, model_path):
    """
    Load a saved Keras model and classify one image with it.

    Parameters
    ----------
    final : numpy.ndarray
        The preprocessed image. A leading batch axis is added before
        prediction. In this app it is the 28x28 canvas image scaled to [0, 1].
    model_path : str
        Path to the saved model, resolved relative to this file's directory.
        Absolute paths are used unchanged.

    Returns
    -------
    tuple
        ``(prediction, prediction_seconds, load_seconds)``. ``prediction`` is
        the raw output of ``model.predict`` for the one-image batch.
    """
    resolved_path = os.path.abspath(
        os.path.join(os.path.dirname(__file__), model_path)
    )

    # perf_counter is monotonic and high-resolution, unlike time.time(), so it
    # is the appropriate clock for measuring short durations.
    load_start = time.perf_counter()
    model = models.load_model(resolved_path)
    load_seconds = time.perf_counter() - load_start

    predict_start = time.perf_counter()
    # verbose=0 suppresses the progress bar Keras would otherwise print to the
    # server log on every Streamlit rerun.
    prediction = model.predict(final[None, ...], verbose=0)
    prediction_seconds = time.perf_counter() - predict_start

    return prediction, prediction_seconds, load_seconds


def onnx_prediction(final, model_path):
    """
    Load an ONNX model and classify one image with it.

    Parameters
    ----------
    final : numpy.ndarray
        The preprocessed 2-D image (the 28x28 canvas image in this app).
        It is cast to float32 and given leading batch and channel axes.
    model_path : str
        Path to the ``.onnx`` file, resolved relative to this file's directory
        in the same way as ``keras_prediction``. Absolute paths are used
        unchanged.

    Returns
    -------
    tuple
        ``(prediction, prediction_seconds, load_seconds)``. ``prediction`` is
        the softmax of the model's first output, flattened to 1-D.
    """
    # (H, W) -> (1, 1, H, W): batch axis first, then channel axis (NCHW).
    model_input = np.asarray(final, dtype=np.float32)[np.newaxis, np.newaxis, ...]

    # Resolving against this file (not the working directory) keeps the ONNX
    # option working when the app is launched from a different directory.
    resolved_path = os.path.abspath(
        os.path.join(os.path.dirname(__file__), model_path)
    )

    load_start = time.perf_counter()
    session = onnxruntime.InferenceSession(resolved_path)
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name
    load_seconds = time.perf_counter() - load_start

    predict_start = time.perf_counter()
    # run() returns one array per requested output; we request exactly one.
    (raw_output,) = session.run([output_name], {input_name: model_input})
    # Softmax is applied to the raw output, which is assumed to be
    # unnormalized scores, so it can be shown as probabilities.
    prediction = softmax(np.ravel(raw_output), axis=0)
    prediction_seconds = time.perf_counter() - predict_start

    return prediction, prediction_seconds, load_seconds


def main():
    """
    The main function/primary entry point of the app.

    Renders the description and MNIST sample grid, then the drawing canvas
    with its stroke-width and model controls. Once something has been drawn,
    it also shows the chosen model's prediction, its load and prediction
    timings, and the per-digit probabilities.
    """
    # Selectbox label -> model file. A single mapping drives both the options
    # shown and the path lookup, so a label can never be left without a path.
    # Insertion order is the display order; "Onnx" first keeps it the default.
    model_paths = {
        "Onnx": "models/mnist_12.onnx",
        "Autokeras": "models/autokeras_model.keras",
        "Basic": "models/mnist_model.keras",
    }

    st.title("MNIST Digit Recognizer")

    col1, col2 = st.columns([0.8, 0.2], gap="small")
    with col1:
        st.markdown(
            """
            This Streamlit app loads a Keras neural network trained on the MNIST dataset to predict handwritten digits. Draw a digit in the canvas below and see the model's prediction. You can: 
            - Change the stroke width of the digit using the slider
            - Choose what model you use for predictions
                - Onnx: The mnist-12 Onnx model from <a href="https://xethub.com/XetHub/onnx-models/src/branch/main/vision/classification/mnist">Onnx's pre-trained MNIST models</a>
                - Autokeras: A model generated using the <a href="https://autokeras.com/image_classifier/">Autokeras image classifier class</a>
                - Basic: A simple two layer neural net where each layer has 300 nodes
            
            Like any machine learning model, this model is a function of the data it was fed during training. As you can see in the picture, the numbers in the images have a specific shape, location, and size. By playing around with the stroke width and where you draw the digit, you can see how the model's prediction changes.""",
            unsafe_allow_html=True,
        )
    with col2:
        # Load the first 9 images from the MNIST dataset and show them
        load_picture()

    col3, col4 = st.columns(2, gap="small")

    # The controls column is filled first because the canvas needs stroke_width.
    with col4:
        # Starts at 10 because that's reasonably close to the width of the MNIST digits
        stroke_width = st.slider("Stroke width: ", 1, 25, 10)
        model_choice = st.selectbox(
            "Choose what model to use for predictions:", tuple(model_paths)
        )
        model_path = model_paths[model_choice]

    with col3:
        canvas_result = st_canvas(
            stroke_width=stroke_width,
            stroke_color="#FFF",
            fill_color="#000",
            background_color="#000",
            background_image=None,
            update_streamlit=True,
            height=200,
            width=200,
            drawing_mode="freedraw",
            point_display_radius=0,
            key="canvas",
        )

    if canvas_result is None or canvas_result.image_data is None:
        return

    # Convert to grayscale and resize to 28x28, the size of the MNIST images.
    im = Image.fromarray(canvas_result.image_data.astype("uint8")).convert("L")
    im = im.resize((28, 28))

    # Normalize pixel values to [0, 1].
    final = np.array(im, dtype=np.float32) / 255.0

    # An all-zero (black) image means nothing has been drawn yet.
    if np.all(final == 0):
        return

    predict = onnx_prediction if model_choice == "Onnx" else keras_prediction
    prediction, pred_time, load_time = predict(final, model_path)

    st.header(f"Using model: {model_choice}")
    st.write(f"Prediction: {np.argmax(prediction)}")
    st.write(f"Load time (in ms): {(load_time) * 1000:.2f}")
    st.write(f"Prediction time (in ms): {(pred_time) * 1000:.2f}")

    # One row per digit with its predicted probability.
    data = pd.DataFrame(
        {"Digit": list(range(10)), "Probability": np.ravel(prediction)}
    )

    col1, col2 = st.columns([0.8, 0.2], gap="small")
    with col1:
        st.bar_chart(data, x="Digit", y="Probability", height=500)

    # The numeric table shows the probabilities formatted as percentages.
    with col2:
        data["Probability"] = data["Probability"].apply(lambda x: f"{x:.2%}")
        st.dataframe(data, hide_index=True)


if __name__ == "__main__":
    main()