"""Streamlit app: draw or upload a digit and classify it with a saved Keras model."""

import io

import numpy as np
import streamlit as st
import tensorflow as tf
from PIL import Image
from streamlit_drawable_canvas import st_canvas
from tensorflow import keras

# Set page config
st.set_page_config(page_title="MNIST Digit Recognition", page_icon="✏️", layout="centered")


# Load the saved model
@st.cache_resource()
def load_model():
    """Load the Keras model stored in ``mnist_model.h5``.

    The model is loaded with ``compile=False``, i.e. without restoring any
    training configuration; this script only calls ``predict`` on it.
    Decorated with ``st.cache_resource`` so Streamlit can reuse the loaded
    object instead of re-reading the file on each script run.

    Returns:
        The model object returned by ``tf.keras.models.load_model``.
    """
    return tf.keras.models.load_model(
        'mnist_model.h5',
        # NOTE(review): presumably a workaround for InputLayer deserialization
        # differences between Keras versions -- confirm it is still needed.
        custom_objects={'InputLayer': keras.layers.InputLayer},
        compile=False,
    )


def _to_model_input(image):
    """Turn a PIL image into the array passed to ``model.predict``.

    The image is converted to 8-bit grayscale, resized to 28x28, scaled by
    1/255 and reshaped to ``(1, 28, 28, 1)`` as ``float32``.
    """
    grayscale = image.convert('L').resize((28, 28))
    scaled = np.array(grayscale) / 255.0  # Normalize
    return scaled.reshape(1, 28, 28, 1).astype('float32')


model = load_model()

# Custom CSS
st.markdown(""" """, unsafe_allow_html=True)

# Create a Streamlit app
# NOTE(review): the HTML markup around this heading appears to be missing even
# though unsafe_allow_html is set -- confirm the intended markup.
st.markdown("\n\nMNIST Digit Recognition\n\n", unsafe_allow_html=True)
st.write("Draw a digit or upload an image to see the model's prediction!")

# Add option to choose between drawing and uploading
option = st.radio("Choose input method:", ('Draw', 'Upload'))

image = None
if option == 'Draw':
    # Create a canvas component: white strokes on a black background.
    canvas_result = st_canvas(
        fill_color="rgba(255, 255, 255, 0.3)",
        stroke_width=20,
        stroke_color="#FFFFFF",
        background_color="#000000",
        height=280,
        width=280,
        drawing_mode="freedraw",
        key="canvas",
    )
    canvas_pixels = canvas_result.image_data
    # An untouched canvas still yields pixel data, so "is not None" alone would
    # let Predict run on an empty drawing. Strokes are white on black, so any
    # non-zero colour channel means something was actually drawn.
    if canvas_pixels is not None and canvas_pixels[..., :3].any():
        image = Image.fromarray(canvas_pixels.astype('uint8'))
else:
    # Upload option
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        st.write("Please upload an image.")
    else:
        try:
            image = Image.open(io.BytesIO(uploaded_file.read()))
            # Image.open is lazy; decode now so a corrupt file fails here
            # rather than later during preprocessing.
            image.load()
        except OSError:
            # Covers PIL's UnidentifiedImageError (an OSError subclass) as well
            # as truncated or otherwise undecodable files.
            st.error("Could not read the uploaded file as an image.")
            image = None
        else:
            # NOTE(review): recent Streamlit releases deprecate
            # use_column_width -- confirm against the pinned version.
            st.image(image, caption='Uploaded Image.', use_column_width=True)

# Add a button to make a prediction
if st.button('Predict', key='predict_button'):
    if image is None:
        st.warning("Please draw or upload an image before predicting.")
    else:
        # Preprocess the image
        image_array = _to_model_input(image)

        # Make a prediction
        with st.spinner('Predicting...'):
            prediction = model.predict(image_array)
        predicted_digit = np.argmax(prediction)

        # Display the results
        # NOTE(review): as with the heading, the surrounding HTML markup
        # appears to be missing -- confirm the intended markup.
        st.markdown(f"\n\nPredicted Digit: {predicted_digit}\n\n", unsafe_allow_html=True)

st.markdown("---")
st.markdown("", unsafe_allow_html=True)