TensorFlowClass / pages /13_TransferLearning.py
eaglelandsonce's picture
Update pages/13_TransferLearning.py
6d89f79 verified
raw
history blame
2.6 kB
import streamlit as st
import tensorflow as tf
from tensorflow.keras import layers, models, applications
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
# Load the dataset: the first 80% of the "train" split is used for training
# and the remaining 20% for validation.
dataset_name = "cats_vs_dogs"
train_val_splits = ['train[:80%]', 'train[80%:]']
(ds_train, ds_val), ds_info = tfds.load(
    dataset_name,
    split=train_val_splits,
    with_info=True,
    as_supervised=True,
)
# Preprocess the dataset
def preprocess_image(image, label):
    """Resize ``image`` to 150x150 and divide its values by 255.

    The label is returned unchanged so the function can be passed directly
    to ``Dataset.map`` on ``(image, label)`` pairs.
    """
    resized = tf.image.resize(image, (150, 150))
    return resized / 255.0, label
# Streamlit app
st.title("Transfer Learning with VGG16 for Image Classification")

# Input parameters
batch_size = st.slider("Batch Size", 16, 128, 32, 16)
epochs = st.slider("Epochs", 5, 50, 10, 5)

# Build the input pipelines only after the sliders have been read so that the
# "Batch Size" control actually determines the batch size (it was previously
# ignored in favour of a hard-coded 32).
ds_train = (
    ds_train
    .map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)
ds_val = (
    ds_val
    .map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)
# Pre-trained VGG16 convolutional base (ImageNet weights, no classifier head).
# It is frozen so that only the layers added below are trained.
base_model = applications.VGG16(
    weights='imagenet',
    include_top=False,
    input_shape=(150, 150, 3),
)
base_model.trainable = False

# Custom classification head stacked on top of the frozen base.
classifier_head = [
    layers.Flatten(),
    layers.Dense(256, activation='relu'),
    layers.Dropout(0.5),
    # Change the output layer based on the number of classes
    layers.Dense(1, activation='sigmoid'),
]
model = models.Sequential([base_model, *classifier_head])
model.summary()

# Compile the model.
# Change the loss function based on the number of classes.
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy'],
)
# Train the model
if st.button("Train Model"):
    with st.spinner("Training the model..."):
        history = model.fit(
            ds_train,
            epochs=epochs,
            validation_data=ds_val,
        )
    st.success("Model training completed!")
    # Streamlit re-executes this script from the top on every widget
    # interaction. Keep the fitted model and its metrics in session state so
    # they survive the rerun triggered by a later button click; otherwise
    # "Evaluate Model" would score a freshly built, untrained model.
    st.session_state["trained_model"] = model
    st.session_state["training_history"] = history.history

# Display training curves (redrawn on every rerun once training has happened)
if "training_history" in st.session_state:
    training_history = st.session_state["training_history"]
    # (metric label, training key, validation key) for each chart.
    curve_specs = (
        ("Accuracy", "accuracy", "val_accuracy"),
        ("Loss", "loss", "val_loss"),
    )
    for metric_label, train_key, val_key in curve_specs:
        st.subheader(f"Training and Validation {metric_label}")
        fig, ax = plt.subplots()
        ax.plot(training_history[train_key], label=f"Training {metric_label}")
        ax.plot(training_history[val_key], label=f"Validation {metric_label}")
        ax.set_xlabel('Epoch')
        ax.set_ylabel(metric_label)
        ax.legend()
        st.pyplot(fig)
        # Close explicitly: pyplot keeps every figure alive until closed, and
        # this code runs again on each rerun.
        plt.close(fig)

# Evaluate the model
if st.button("Evaluate Model"):
    if "trained_model" not in st.session_state:
        st.warning("Train the model before evaluating it.")
    else:
        trained_model = st.session_state["trained_model"]
        _, test_acc = trained_model.evaluate(ds_val, verbose=2)
        st.write(f"Validation accuracy: {test_acc}")