eaglelandsonce commited on
Commit
4c05b1a
·
verified ·
1 Parent(s): 3c86a58

Create 9_Cifar_10.py

Browse files
Files changed (1) hide show
  1. pages/9_Cifar_10.py +96 -0
pages/9_Cifar_10.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import tensorflow as tf
3
+ from tensorflow.keras import datasets, layers, models
4
+ from tensorflow.keras.preprocessing.image import ImageDataGenerator
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ from PIL import Image
8
+
9
+ # Define the CNN model
10
+ def create_cnn_model():
11
+ model = models.Sequential()
12
+ model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)))
13
+ model.add(layers.MaxPooling2D((2, 2)))
14
+ model.add(layers.Conv2D(64, (3, 3), activation='relu'))
15
+ model.add(layers.MaxPooling2D((2, 2)))
16
+ model.add(layers.Conv2D(64, (3, 3), activation='relu'))
17
+ model.add(layers.Flatten())
18
+ model.add(layers.Dense(64, activation='relu'))
19
+ model.add(layers.Dropout(0.5))
20
+ model.add(layers.Dense(10, activation='softmax'))
21
+ return model
22
+
23
+ # Streamlit app
24
+ st.title("CIFAR-10 Image Classification with CNN")
25
+
26
+ # Load CIFAR-10 data
27
+ (train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()
28
+ train_images, test_images = train_images / 255.0, test_images / 255.0
29
+
30
+ # Display sample images
31
+ st.subheader("Sample Training Images")
32
+ fig, ax = plt.subplots(1, 5, figsize=(15, 3))
33
+ for i in range(5):
34
+ ax[i].imshow(train_images[i])
35
+ ax[i].axis('off')
36
+ st.pyplot(fig)
37
+
38
+ # Model creation
39
+ model = create_cnn_model()
40
+
41
+ # Compile the model
42
+ model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
43
+
44
+ # Data augmentation
45
+ datagen = ImageDataGenerator(width_shift_range=0.1, height_shift_range=0.1, horizontal_flip=True)
46
+ datagen.fit(train_images)
47
+
48
+ # Training parameters
49
+ batch_size = st.slider("Batch Size", 32, 128, 64, 32)
50
+ epochs = st.slider("Epochs", 10, 50, 20, 10)
51
+
52
+ # Train button
53
+ if st.button("Train Model"):
54
+ with st.spinner("Training the model..."):
55
+ history = model.fit(datagen.flow(train_images, train_labels, batch_size=batch_size),
56
+ steps_per_epoch=len(train_images) / batch_size,
57
+ epochs=epochs,
58
+ validation_data=(test_images, test_labels))
59
+
60
+ st.success("Model training completed!")
61
+
62
+ # Display training curves
63
+ st.subheader("Training and Validation Accuracy")
64
+ fig, ax = plt.subplots()
65
+ ax.plot(history.history['accuracy'], label='Training Accuracy')
66
+ ax.plot(history.history['val_accuracy'], label='Validation Accuracy')
67
+ ax.set_xlabel('Epoch')
68
+ ax.set_ylabel('Accuracy')
69
+ ax.legend()
70
+ st.pyplot(fig)
71
+
72
+ st.subheader("Training and Validation Loss")
73
+ fig, ax = plt.subplots()
74
+ ax.plot(history.history['loss'], label='Training Loss')
75
+ ax.plot(history.history['val_loss'], label='Validation Loss')
76
+ ax.set_xlabel('Epoch')
77
+ ax.set_ylabel('Loss')
78
+ ax.legend()
79
+ st.pyplot(fig)
80
+
81
+ # Prediction on uploaded image
82
+ st.subheader("Make Predictions")
83
+ uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
84
+
85
+ if uploaded_file is not None:
86
+ # Preprocess the uploaded image
87
+ image = Image.open(uploaded_file)
88
+ image = image.resize((32, 32))
89
+ image_array = np.array(image) / 255.0
90
+
91
+ st.image(image, caption='Uploaded Image', use_column_width=True)
92
+
93
+ if st.button("Predict"):
94
+ prediction = model.predict(np.expand_dims(image_array, axis=0))
95
+ predicted_class = np.argmax(prediction)
96
+ st.write(f"Predicted Class: {predicted_class} ({class_names[predicted_class]})")