import streamlit as st
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt

def load_and_preprocess_mnist():
    """Load MNIST and prepare it for the CNN built by ``create_mnist_model``.

    For each split, pixel values are converted to float32 and divided by 255,
    images are reshaped to (N, 28, 28, 1) to add a channel axis, and integer
    labels are one-hot encoded over 10 classes.

    Returns:
        ((x_train, y_train), (x_test, y_test)) with the preprocessing above
        applied to both splits.
    """
    num_classes = 10
    image_shape = (-1, 28, 28, 1)

    def _prepare(images, labels):
        # Scale, add the channel axis, and one-hot encode a single split.
        scaled = (images.astype('float32') / 255.0).reshape(image_shape)
        encoded = keras.utils.to_categorical(labels, num_classes)
        return scaled, encoded

    train_split, test_split = keras.datasets.mnist.load_data()
    return _prepare(*train_split), _prepare(*test_split)

def create_mnist_model():
    """Build and compile the convolutional network used for MNIST digits.

    Architecture:
    two Conv2D(3x3, ReLU) + MaxPooling2D(2x2) stages (32, then 64 filters),
    Flatten, Dropout(0.5), Dense(64, ReLU), and a 10-way softmax output.

    Returns:
        A ``keras.Sequential`` model compiled with Adam, categorical
        cross-entropy (so targets must be one-hot), and an accuracy metric.
        Inputs have shape (batch, 28, 28, 1).
    """
    model = keras.Sequential([
        # Declare the input with keras.Input instead of passing input_shape=
        # to the first layer; the latter is deprecated for Sequential models
        # in Keras 3 and emits a warning. The resulting network is identical.
        keras.Input(shape=(28, 28, 1)),
        keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu'),
        keras.layers.MaxPooling2D(pool_size=(2, 2)),
        keras.layers.Conv2D(64, kernel_size=(3, 3), activation='relu'),
        keras.layers.MaxPooling2D(pool_size=(2, 2)),
        keras.layers.Flatten(),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(64, activation='relu'),
        keras.layers.Dense(10, activation='softmax'),
    ])

    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    return model

def train_model(model, x_train, y_train, epochs, batch_size):
    """Fit ``model`` on the training data, holding out 10% for validation.

    Args:
        model: A compiled Keras model (e.g. from ``create_mnist_model``).
        x_train: Training inputs.
        y_train: Training targets.
        epochs: Number of passes over the training data.
        batch_size: Number of samples per gradient update.

    Returns:
        Whatever ``model.fit`` returns (for Keras, a ``History`` object whose
        ``history`` dict holds per-epoch metrics).
    """
    held_out_fraction = 0.1
    return model.fit(
        x=x_train,
        y=y_train,
        epochs=epochs,
        batch_size=batch_size,
        validation_split=held_out_fraction,
    )

def plot_training_history(history):
    """Plot accuracy and loss curves from a training run.

    Args:
        history: The object returned by ``model.fit``. Its ``history`` dict
            must contain 'accuracy' and 'loss'. 'val_accuracy' and 'val_loss'
            are plotted when present.

    Returns:
        A matplotlib Figure with two side-by-side axes: accuracy on the left
        and loss on the right. The caller is responsible for closing it.
    """
    metrics = history.history
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

    panels = (
        (ax1, 'accuracy', 'Accuracy', 'Model Accuracy'),
        (ax2, 'loss', 'Loss', 'Model Loss'),
    )
    for ax, key, label, title in panels:
        # Number epochs from 1 so the "Epoch" axis matches the epoch count.
        # Markers keep a one-epoch run visible: a line through a single
        # point draws nothing.
        epoch_numbers = range(1, len(metrics[key]) + 1)
        ax.plot(epoch_numbers, metrics[key], marker='o', label=f'Training {label}')

        # Validation series exist only when fit() was given validation data.
        val_key = f'val_{key}'
        if val_key in metrics:
            ax.plot(epoch_numbers, metrics[val_key], marker='o',
                    label=f'Validation {label}')

        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(label)
        ax.legend()

    return fig

def main():
    """Render the MNIST classification demo as a Streamlit page.

    Sidebar controls choose the epoch count and batch size. "Train Model"
    fits the session's model and shows training curves plus test accuracy.
    "Select Random Image" classifies a random test digit, but only after the
    model has been trained in this session.
    """
    st.title("MNIST Digit Classification with Keras and Streamlit")

    # Load and preprocess data once per session. Streamlit re-executes the
    # whole script on every widget interaction, so without this MNIST would
    # be reloaded and re-normalised on every click.
    if 'mnist_data' not in st.session_state:
        st.session_state.mnist_data = load_and_preprocess_mnist()
    (x_train, y_train), (x_test, y_test) = st.session_state.mnist_data

    # Create model once; keeping it in session state lets its trained
    # weights survive reruns.
    if 'model' not in st.session_state:
        st.session_state.model = create_mnist_model()

    # Sidebar for training parameters
    st.sidebar.header("Training Parameters")
    epochs = st.sidebar.slider("Number of Epochs", min_value=1, max_value=50, value=10)
    batch_size = st.sidebar.selectbox("Batch Size", options=[32, 64, 128, 256], index=2)

    # Train model button
    if st.sidebar.button("Train Model"):
        with st.spinner("Training in progress..."):
            history = train_model(st.session_state.model, x_train, y_train, epochs, batch_size)
        st.success("Training completed!")

        # Plot training history
        st.subheader("Training History")
        fig = plot_training_history(history)
        st.pyplot(fig)
        # pyplot keeps every figure it creates alive until it is closed.
        # Close it after rendering so figures do not pile up across reruns.
        plt.close(fig)

        # Evaluate model on the test split (loss is not displayed)
        _, test_acc = st.session_state.model.evaluate(x_test, y_test)
        st.write(f"Test accuracy: {test_acc:.4f}")

        # Set a flag to indicate the model has been trained
        st.session_state.model_trained = True

    # Test on random image
    st.subheader("Test on Random Image")
    if st.button("Select Random Image"):
        if not st.session_state.get('model_trained', False):
            st.warning("Please train the model first before testing.")
        else:
            # Select a random image from the test set
            idx = np.random.randint(0, x_test.shape[0])
            image = x_test[idx]
            true_label = np.argmax(y_test[idx])

            # Make prediction on a batch of one image
            prediction = st.session_state.model.predict(image[np.newaxis, ...])[0]
            predicted_label = np.argmax(prediction)

            # Display image and prediction
            fig, ax = plt.subplots()
            ax.imshow(image.reshape(28, 28), cmap='gray')
            ax.axis('off')
            st.pyplot(fig)
            plt.close(fig)

            st.write(f"True Label: {true_label}")
            st.write(f"Predicted Label: {predicted_label}")
            st.write(f"Confidence: {prediction[predicted_label]:.4f}")

# Entry point: render the page when this file is executed as a script.
if __name__ == "__main__":
    main()