Spaces:
Sleeping
Sleeping
import streamlit as st | |
import tensorflow as tf | |
from tensorflow import keras | |
import numpy as np | |
import matplotlib.pyplot as plt | |
def load_and_preprocess_mnist():
    """Load MNIST and convert both splits into model-ready arrays.

    For each split, pixel values are cast to float32 and divided by 255.0,
    images are reshaped to (-1, 28, 28, 1) so they carry an explicit channel
    axis, and integer labels are one-hot encoded over 10 classes.

    Returns:
        ((x_train, y_train), (x_test, y_test))
    """
    def _prepare(images, labels):
        # Scale intensities, add the channel axis, one-hot the labels.
        scaled = images.astype('float32') / 255.0
        return scaled.reshape((-1, 28, 28, 1)), keras.utils.to_categorical(labels, 10)

    train_split, test_split = keras.datasets.mnist.load_data()
    return _prepare(*train_split), _prepare(*test_split)
def create_mnist_model():
    """Build and compile the small CNN used to classify MNIST digits.

    Architecture: Conv2D(32, 3x3, ReLU) -> MaxPooling2D(2x2) ->
    Conv2D(64, 3x3, ReLU) -> MaxPooling2D(2x2) -> Flatten -> Dropout(0.5) ->
    Dense(64, ReLU) -> Dense(10, softmax).

    The model is compiled with the Adam optimizer and categorical
    cross-entropy loss, so training targets must be one-hot encoded (as
    returned by load_and_preprocess_mnist). Accuracy is tracked as a metric.

    Returns:
        A compiled ``keras.Sequential`` model whose per-sample input shape is
        (28, 28, 1).
    """
    model = keras.Sequential([
        # Declare the input with keras.Input instead of passing `input_shape`
        # to the first layer; Keras 3 deprecates the layer kwarg with a warning.
        keras.Input(shape=(28, 28, 1)),
        keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu'),
        keras.layers.MaxPooling2D(pool_size=(2, 2)),
        keras.layers.Conv2D(64, kernel_size=(3, 3), activation='relu'),
        keras.layers.MaxPooling2D(pool_size=(2, 2)),
        keras.layers.Flatten(),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(64, activation='relu'),
        keras.layers.Dense(10, activation='softmax'),
    ])
    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    return model
def train_model(model, x_train, y_train, epochs, batch_size):
    """Fit ``model`` on the training data and return the fit history.

    ``validation_split=0.1`` reserves 10% of the training samples for
    validation, so the returned history also contains ``val_*`` metrics.

    Returns:
        Whatever ``model.fit`` returns (a Keras ``History`` object for a
        Keras model).
    """
    return model.fit(
        x_train,
        y_train,
        epochs=epochs,
        batch_size=batch_size,
        validation_split=0.1,
    )
def plot_training_history(history):
    """Draw per-epoch accuracy and loss curves side by side.

    Args:
        history: The object returned by ``model.fit``; its ``history`` dict
            must contain the 'accuracy', 'val_accuracy', 'loss' and
            'val_loss' series.

    Returns:
        A Matplotlib figure with accuracy on the left axis and loss on the
        right, each showing training and validation curves.
    """
    curves = history.history
    fig, (left, right) = plt.subplots(1, 2, figsize=(12, 4))

    # (axis, train key, val key, title, y-axis label, train legend, val legend)
    panels = (
        (left, 'accuracy', 'val_accuracy', 'Model Accuracy', 'Accuracy',
         'Training Accuracy', 'Validation Accuracy'),
        (right, 'loss', 'val_loss', 'Model Loss', 'Loss',
         'Training Loss', 'Validation Loss'),
    )
    for ax, train_key, val_key, title, y_label, train_label, val_label in panels:
        ax.plot(curves[train_key], label=train_label)
        ax.plot(curves[val_key], label=val_label)
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.legend()
    return fig
def _load_data_once():
    """Return the preprocessed MNIST splits, loading them once per session.

    Streamlit re-executes the script on every widget interaction, so calling
    load_and_preprocess_mnist() directly from main() reloads and re-normalises
    the whole dataset on each rerun (e.g. every slider move). The result is
    kept in st.session_state, the same place the model is stored.
    """
    if 'mnist_data' not in st.session_state:
        st.session_state.mnist_data = load_and_preprocess_mnist()
    return st.session_state.mnist_data


def _show_figure(fig):
    """Render a Matplotlib figure in the app, then close it."""
    st.pyplot(fig)
    # pyplot keeps every figure created via plt.subplots() registered until it
    # is closed; without this, figures pile up across reruns in the server.
    plt.close(fig)


def _train_and_report(x_train, y_train, x_test, y_test, epochs, batch_size):
    """Train the session model, then show its learning curves and test accuracy.

    The model is the one held in st.session_state, so pressing "Train Model"
    again continues training from the current weights rather than starting
    from scratch.
    """
    model = st.session_state.model
    with st.spinner("Training in progress..."):
        history = train_model(model, x_train, y_train, epochs, batch_size)
    st.success("Training completed!")

    st.subheader("Training History")
    _show_figure(plot_training_history(history))

    # verbose=0 keeps Keras progress bars out of the server log.
    _, test_acc = model.evaluate(x_test, y_test, verbose=0)
    st.write(f"Test accuracy: {test_acc:.4f}")

    # Presence of this flag unlocks the random-image test.
    st.session_state.model_trained = True


def _show_random_prediction(x_test, y_test):
    """Classify one random test image and display it with the prediction."""
    if 'model_trained' not in st.session_state:
        st.warning("Please train the model first before testing.")
        return

    idx = np.random.randint(0, x_test.shape[0])
    image = x_test[idx]
    true_label = np.argmax(y_test[idx])  # y_test rows are one-hot vectors

    # Add a leading batch axis for predict(), then take the single result row.
    prediction = st.session_state.model.predict(image[np.newaxis, ...], verbose=0)[0]
    predicted_label = np.argmax(prediction)

    fig, ax = plt.subplots()
    ax.imshow(image.reshape(28, 28), cmap='gray')
    ax.axis('off')
    _show_figure(fig)

    st.write(f"True Label: {true_label}")
    st.write(f"Predicted Label: {predicted_label}")
    st.write(f"Confidence: {prediction[predicted_label]:.4f}")


def main():
    """Streamlit entry point: train a CNN on MNIST and try it on random digits.

    The sidebar sets the number of epochs and the batch size. "Train Model"
    trains the session's model and reports its learning curves and test
    accuracy. "Select Random Image" shows the model's prediction for a random
    test digit once the model has been trained.
    """
    st.title("MNIST Digit Classification with Keras and Streamlit")

    (x_train, y_train), (x_test, y_test) = _load_data_once()

    # Keep the model in session state so it survives Streamlit reruns.
    if 'model' not in st.session_state:
        st.session_state.model = create_mnist_model()

    st.sidebar.header("Training Parameters")
    epochs = st.sidebar.slider("Number of Epochs", min_value=1, max_value=50, value=10)
    batch_size = st.sidebar.selectbox("Batch Size", options=[32, 64, 128, 256], index=2)

    if st.sidebar.button("Train Model"):
        _train_and_report(x_train, y_train, x_test, y_test, epochs, batch_size)

    st.subheader("Test on Random Image")
    if st.button("Select Random Image"):
        _show_random_prediction(x_test, y_test)
if __name__ == "__main__":
    # Launch the app only when this file is executed as the main script,
    # not when it is imported as a module.
    main()