File size: 4,382 Bytes
2f37879
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import streamlit as st
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt

def load_and_preprocess_mnist():
    """Load the MNIST dataset and convert it into model-ready arrays.

    Images are cast to float32, divided by 255 and reshaped to
    ``(N, 28, 28, 1)`` (a trailing single-channel axis, as the Conv2D
    input expects).  Integer labels are one-hot encoded over 10 classes.

    Returns:
        ``((x_train, y_train), (x_test, y_test))``.
    """
    def _prepare_images(raw):
        # Scale pixels and append the grayscale channel dimension.
        scaled = raw.astype('float32') / 255.0
        return scaled.reshape((-1, 28, 28, 1))

    def _one_hot(labels):
        # categorical_crossentropy (used by the model) wants one-hot targets.
        return keras.utils.to_categorical(labels, 10)

    train_split, test_split = keras.datasets.mnist.load_data()
    return tuple(
        (_prepare_images(images), _one_hot(labels))
        for images, labels in (train_split, test_split)
    )

def create_mnist_model():
    """Build and compile a small CNN classifier for 28x28x1 images.

    Architecture:
        * two stages of Conv2D(3x3, ReLU) + MaxPooling2D(2x2), with 32 then
          64 filters;
        * flatten, dropout(0.5), a 64-unit ReLU dense layer;
        * a 10-unit softmax output.

    The model is compiled with the Adam optimizer and categorical
    cross-entropy loss (so targets must be one-hot), tracking accuracy.

    Returns:
        The compiled ``keras.Sequential`` model.
    """
    nn = keras.layers

    feature_extractor = [
        nn.Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)),
        nn.MaxPooling2D(pool_size=(2, 2)),
        nn.Conv2D(64, kernel_size=(3, 3), activation='relu'),
        nn.MaxPooling2D(pool_size=(2, 2)),
    ]
    classifier_head = [
        nn.Flatten(),
        nn.Dropout(0.5),
        nn.Dense(64, activation='relu'),
        nn.Dense(10, activation='softmax'),
    ]

    model = keras.Sequential(feature_extractor + classifier_head)
    model.compile(
        optimizer='adam',
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model

def train_model(model, x_train, y_train, epochs, batch_size):
    """Fit ``model`` on the training data, holding out 10% for validation.

    Args:
        model: A compiled Keras model.
        x_train: Training inputs.
        y_train: Training targets.
        epochs: Number of passes over the training data.
        batch_size: Number of samples per gradient update.

    Returns:
        Whatever ``model.fit`` returns (a Keras ``History``; its
        ``.history`` dict is what ``plot_training_history`` reads).
    """
    holdout_fraction = 0.1  # last 10% of x_train/y_train used for val_* metrics
    return model.fit(
        x_train,
        y_train,
        batch_size=batch_size,
        epochs=epochs,
        validation_split=holdout_fraction,
    )

def plot_training_history(history):
    """Draw side-by-side accuracy and loss curves from a fit history.

    Args:
        history: Object whose ``history`` dict holds per-epoch sequences
            under ``'accuracy'``, ``'val_accuracy'``, ``'loss'`` and
            ``'val_loss'``.

    Returns:
        A matplotlib ``Figure`` with one row of two axes (accuracy left,
        loss right).  The figure is not closed here; the caller decides
        when to release it.
    """
    curves = history.history
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))

    # (train key, train label, val key, val label, title, y-axis label)
    panels = (
        ('accuracy', 'Training Accuracy', 'val_accuracy', 'Validation Accuracy',
         'Model Accuracy', 'Accuracy'),
        ('loss', 'Training Loss', 'val_loss', 'Validation Loss',
         'Model Loss', 'Loss'),
    )
    for ax, (train_key, train_label, val_key, val_label, title, ylabel) in zip(axes, panels):
        ax.plot(curves[train_key], label=train_label)
        ax.plot(curves[val_key], label=val_label)
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.legend()

    return fig

def main():
    """Render the MNIST training / inference Streamlit app.

    Streamlit re-executes the script on every widget interaction, so the
    expensive objects (the preprocessed dataset and the Keras model) are
    stored in ``st.session_state`` and built only on a session's first run.

    UI:
        * Sidebar: epoch count, batch size and a "Train Model" button that
          fits the session's model, plots its history and reports test
          accuracy.
        * Main pane: "Select Random Image" predicts one random test digit,
          once the model has been trained in this session.
    """
    st.title("MNIST Digit Classification with Keras and Streamlit")

    # Load and preprocess data once per session; previously the full
    # dataset was re-read and re-normalised on every rerun (every click).
    if 'mnist_data' not in st.session_state:
        st.session_state.mnist_data = load_and_preprocess_mnist()
    (x_train, y_train), (x_test, y_test) = st.session_state.mnist_data

    # Create the model once per session; later "Train Model" clicks keep
    # training these same weights rather than starting from scratch.
    if 'model' not in st.session_state:
        st.session_state.model = create_mnist_model()

    # Sidebar for training parameters
    st.sidebar.header("Training Parameters")
    epochs = st.sidebar.slider("Number of Epochs", min_value=1, max_value=50, value=10)
    batch_size = st.sidebar.selectbox("Batch Size", options=[32, 64, 128, 256], index=2)

    # Train model button
    if st.sidebar.button("Train Model"):
        with st.spinner("Training in progress..."):
            history = train_model(st.session_state.model, x_train, y_train, epochs, batch_size)
        st.success("Training completed!")

        # Plot training history
        st.subheader("Training History")
        fig = plot_training_history(history)
        st.pyplot(fig)
        # pyplot keeps every figure alive in its global registry until it is
        # closed; without this each training run leaks a figure server-side.
        plt.close(fig)

        # Evaluate model (the loss value is not displayed)
        _, test_acc = st.session_state.model.evaluate(x_test, y_test)
        st.write(f"Test accuracy: {test_acc:.4f}")

        # Set a flag to indicate the model has been trained
        st.session_state.model_trained = True

    # Test on random image
    st.subheader("Test on Random Image")
    if st.button("Select Random Image"):
        if not st.session_state.get('model_trained', False):
            st.warning("Please train the model first before testing.")
        else:
            # Select a random image from the test set
            idx = np.random.randint(0, x_test.shape[0])
            image = x_test[idx]
            true_label = np.argmax(y_test[idx])

            # Make prediction; add a leading batch axis of size 1.
            prediction = st.session_state.model.predict(image[np.newaxis, ...])[0]
            predicted_label = np.argmax(prediction)

            # Display image and prediction
            fig, ax = plt.subplots()
            ax.imshow(image.reshape(28, 28), cmap='gray')
            ax.axis('off')
            st.pyplot(fig)
            # Release the figure; st.pyplot has already rendered it.
            plt.close(fig)

            st.write(f"True Label: {true_label}")
            st.write(f"Predicted Label: {predicted_label}")
            st.write(f"Confidence: {prediction[predicted_label]:.4f}")

# Entry point: build the app when this file is executed as the main script
# (presumably launched via `streamlit run` — confirm deployment command).
if __name__ == "__main__":
    main()