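"""Streamlit page: train a small CNN on the MNIST digits and try it on random test images.

The sidebar exposes the number of epochs and the batch size; after training, the page
plots accuracy/loss curves, reports test accuracy, and can classify a randomly chosen
test image.
"""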
import streamlit as st
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
# Cache the processed arrays so Streamlit reruns do not reload MNIST on every interaction.
@st.cache_data
def load_and_preprocess_mnist():
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
    # Scale pixel values to [0, 1] and add a channel dimension for the Conv2D layers.
    x_train = x_train.astype('float32') / 255.0
    x_test = x_test.astype('float32') / 255.0
    x_train = x_train.reshape((-1, 28, 28, 1))
    x_test = x_test.reshape((-1, 28, 28, 1))
    # One-hot encode the labels to match the 10-way softmax output.
    y_train = keras.utils.to_categorical(y_train, 10)
    y_test = keras.utils.to_categorical(y_test, 10)
    return (x_train, y_train), (x_test, y_test)
def create_mnist_model():
    # Small CNN: two conv/pool blocks followed by a dropout-regularized dense head.
    model = keras.Sequential([
        keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)),
        keras.layers.MaxPooling2D(pool_size=(2, 2)),
        keras.layers.Conv2D(64, kernel_size=(3, 3), activation='relu'),
        keras.layers.MaxPooling2D(pool_size=(2, 2)),
        keras.layers.Flatten(),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(64, activation='relu'),
        keras.layers.Dense(10, activation='softmax')
    ])
    # Categorical cross-entropy matches the one-hot labels produced by the loader above.
    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    return model
def train_model(model, x_train, y_train, epochs, batch_size):
    # Hold out 10% of the training data so the history includes validation curves.
    history = model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1)
    return history
def plot_training_history(history):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
    ax1.plot(history.history['accuracy'], label='Training Accuracy')
    ax1.plot(history.history['val_accuracy'], label='Validation Accuracy')
    ax1.set_title('Model Accuracy')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Accuracy')
    ax1.legend()
    ax2.plot(history.history['loss'], label='Training Loss')
    ax2.plot(history.history['val_loss'], label='Validation Loss')
    ax2.set_title('Model Loss')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Loss')
    ax2.legend()
    return fig
def main():
    st.title("MNIST Digit Classification with Keras and Streamlit")

    # Load and preprocess data
    (x_train, y_train), (x_test, y_test) = load_and_preprocess_mnist()

    # Keep the model in session state so it survives Streamlit reruns.
    if 'model' not in st.session_state:
        st.session_state.model = create_mnist_model()

    # Sidebar for training parameters
    st.sidebar.header("Training Parameters")
    epochs = st.sidebar.slider("Number of Epochs", min_value=1, max_value=50, value=10)
    batch_size = st.sidebar.selectbox("Batch Size", options=[32, 64, 128, 256], index=2)

    # Train model button
    if st.sidebar.button("Train Model"):
        with st.spinner("Training in progress..."):
            history = train_model(st.session_state.model, x_train, y_train, epochs, batch_size)
        st.success("Training completed!")

        # Plot training history
        st.subheader("Training History")
        fig = plot_training_history(history)
        st.pyplot(fig)

        # Evaluate model on the held-out test set
        test_loss, test_acc = st.session_state.model.evaluate(x_test, y_test, verbose=0)
        st.write(f"Test accuracy: {test_acc:.4f}")

        # Set a flag to indicate the model has been trained
        st.session_state.model_trained = True

    # Test on random image
    st.subheader("Test on Random Image")
    if st.button("Select Random Image"):
        if 'model_trained' not in st.session_state:
            st.warning("Please train the model first before testing.")
        else:
            # Select a random image from the test set
            idx = np.random.randint(0, x_test.shape[0])
            image = x_test[idx]
            true_label = np.argmax(y_test[idx])

            # Make a prediction on the single image (add a batch dimension first)
            prediction = st.session_state.model.predict(image[np.newaxis, ...], verbose=0)[0]
            predicted_label = np.argmax(prediction)

            # Display image and prediction
            fig, ax = plt.subplots()
            ax.imshow(image.reshape(28, 28), cmap='gray')
            ax.axis('off')
            st.pyplot(fig)

            st.write(f"True Label: {true_label}")
            st.write(f"Predicted Label: {predicted_label}")
            st.write(f"Confidence: {prediction[predicted_label]:.4f}")
if __name__ == "__main__":
    main()
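# Note: this file sits under pages/, so it is presumably one page of a multipage Streamlit
# app and appears in the sidebar when the app's main script is run. It can also be launched
# on its own with, e.g.:
#   streamlit run pages/7_mnist.py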