# Source: TensorFlowClass / pages/9_Cifar_10.py
# Committed by eaglelandsonce: "Create 9_Cifar_10.py" (commit 4c05b1a, verified; 3.39 kB)
import streamlit as st
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
# Define the CNN model
def create_cnn_model():
    """Build (without compiling) the CNN classifier used by this page.

    Feature extractor: three 3x3 ReLU convolutions with 32, 64 and 64 filters,
    each of the first two followed by 2x2 max pooling. Classifier head: flatten,
    a 64-unit ReLU dense layer, 50% dropout, and a 10-way softmax output.
    The first layer declares a per-sample input shape of (32, 32, 3).

    Returns:
        An uncompiled ``tf.keras.Sequential`` model.
    """
    layer_stack = [
        # Convolutional feature extractor
        layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        # Dense classifier head
        layers.Flatten(),
        layers.Dense(64, activation='relu'),
        layers.Dropout(0.5),
        layers.Dense(10, activation='softmax'),
    ]
    return models.Sequential(layer_stack)
# Streamlit app
# Streamlit re-executes this whole script on every widget interaction (slider
# change, button click, file upload). Anything that must survive between
# reruns -- here, the trained model -- is therefore kept in st.session_state.

# Session-state key for the trained model. Namespaced because session state is
# shared by every page of a multipage app.
_MODEL_KEY = 'cifar10_model'

# Human-readable labels for the CIFAR-10 class indices 0-9, in dataset order.
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']


def _plot_history(history_dict, metric, label):
    """Render training vs. validation curves for one Keras metric.

    Args:
        history_dict: the ``History.history`` dict returned by ``model.fit``;
            must contain both ``metric`` and ``'val_' + metric`` keys.
        metric: metric key, e.g. ``'accuracy'`` or ``'loss'``.
        label: display name used in the legend entries and the y-axis label.
    """
    fig, ax = plt.subplots()
    ax.plot(history_dict[metric], label=f'Training {label}')
    ax.plot(history_dict[f'val_{metric}'], label=f'Validation {label}')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(label)
    ax.legend()
    st.pyplot(fig)
    plt.close(fig)  # figures are recreated on every rerun; don't let them accumulate


st.title("CIFAR-10 Image Classification with CNN")

# Load CIFAR-10 data and divide by 255 so 8-bit pixel values fall in [0, 1].
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()
train_images, test_images = train_images / 255.0, test_images / 255.0

# Display sample images
st.subheader("Sample Training Images")
fig, ax = plt.subplots(1, 5, figsize=(15, 3))
for i in range(5):
    ax[i].imshow(train_images[i])
    ax[i].axis('off')
st.pyplot(fig)
plt.close(fig)

# Data augmentation: random shifts and horizontal flips applied on the fly.
# datagen.fit() is intentionally not called: it is only needed for
# featurewise_center / featurewise_std_normalization / zca_whitening, none of
# which are used, and it would process the full training set on every rerun.
datagen = ImageDataGenerator(width_shift_range=0.1, height_shift_range=0.1, horizontal_flip=True)

# Training parameters
batch_size = st.slider("Batch Size", 32, 128, 64, 32)
epochs = st.slider("Epochs", 10, 50, 20, 10)

# Train button: each click trains a fresh model from scratch.
if st.button("Train Model"):
    with st.spinner("Training the model..."):
        model = create_cnn_model()
        model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
        train_flow = datagen.flow(train_images, train_labels, batch_size=batch_size)
        history = model.fit(train_flow,
                            # len(train_flow) is the integer number of batches per
                            # epoch (including a final partial batch); Keras expects
                            # an int here, not len(train_images) / batch_size.
                            steps_per_epoch=len(train_flow),
                            epochs=epochs,
                            validation_data=(test_images, test_labels))
    # Persist the trained model so the "Predict" button (which triggers a new
    # rerun) can use it instead of an untrained one.
    st.session_state[_MODEL_KEY] = model
    st.success("Model training completed!")

    # Display training curves
    st.subheader("Training and Validation Accuracy")
    _plot_history(history.history, 'accuracy', 'Accuracy')
    st.subheader("Training and Validation Loss")
    _plot_history(history.history, 'loss', 'Loss')

# Prediction on uploaded image
st.subheader("Make Predictions")
uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    # Force 3-channel RGB: PNGs may be RGBA or palette-based and some JPEGs are
    # grayscale, none of which match the model's (32, 32, 3) input.
    image = Image.open(uploaded_file).convert('RGB')
    image = image.resize((32, 32))
    image_array = np.array(image) / 255.0
    st.image(image, caption='Uploaded Image', use_column_width=True)
    if st.button("Predict"):
        trained_model = st.session_state.get(_MODEL_KEY)
        if trained_model is None:
            st.warning('No trained model yet - click "Train Model" first.')
        else:
            prediction = trained_model.predict(np.expand_dims(image_array, axis=0))
            predicted_class = int(np.argmax(prediction))
            st.write(f"Predicted Class: {predicted_class} ({class_names[predicted_class]})")