"""Streamlit app: transfer learning with a frozen VGG16 base for image classification.

Images are read with ``ImageDataGenerator.flow_from_directory`` from
``data/train`` and ``data/validation``, one sub-directory per class.

Streamlit re-executes this whole script on every widget interaction, so the
trained model and its training history are kept in ``st.session_state``;
without that, clicking "Evaluate Model" would evaluate a freshly built,
untrained model.
"""
import os

import streamlit as st
import tensorflow as tf
from tensorflow.keras import layers, models, applications
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
import numpy as np

# Set dataset paths
train_dir = 'data/train'
validation_dir = 'data/validation'

# Every image is resized to this resolution before being fed to VGG16.
IMG_SIZE = (150, 150)


def _count_class_dirs(directory):
    """Return the number of immediate sub-directories (one per class) of *directory*."""
    with os.scandir(directory) as entries:
        return sum(1 for entry in entries if entry.is_dir())


def _build_model(num_classes):
    """Return a compiled model: frozen ImageNet VGG16 base plus a small trainable head.

    With two (or fewer) classes the head is a single sigmoid unit trained with
    binary cross-entropy (the original setup); with more classes it is a
    softmax layer trained with categorical cross-entropy. This must agree with
    the ``class_mode`` used for the data generators.
    """
    base_model = applications.VGG16(weights='imagenet', include_top=False,
                                    input_shape=IMG_SIZE + (3,))
    # Freeze the convolutional base so only the new head is trained.
    base_model.trainable = False

    if num_classes <= 2:
        output_layer = layers.Dense(1, activation='sigmoid')
        loss = 'binary_crossentropy'
    else:
        output_layer = layers.Dense(num_classes, activation='softmax')
        loss = 'categorical_crossentropy'

    model = models.Sequential([
        base_model,
        layers.Flatten(),
        layers.Dense(256, activation='relu'),
        layers.Dropout(0.5),
        output_layer,
    ])
    model.compile(optimizer='adam', loss=loss, metrics=['accuracy'])
    return model


def _plot_history(history, metric, label):
    """Plot training vs. validation *metric* per epoch from a Keras ``history`` dict."""
    fig, ax = plt.subplots()
    ax.plot(history[metric], label=f'Training {label}')
    val_key = f'val_{metric}'
    # Guard: the val_ series is absent if validation did not run.
    if val_key in history:
        ax.plot(history[val_key], label=f'Validation {label}')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(label)
    ax.legend()
    st.pyplot(fig)
    # Close explicitly: the script reruns on every interaction and pyplot
    # would otherwise keep every figure ever created alive.
    plt.close(fig)


# Streamlit app
st.title("Transfer Learning with VGG16 for Image Classification")

# Input parameters
batch_size = st.slider("Batch Size", 16, 128, 32, 16)
epochs = st.slider("Epochs", 5, 50, 10, 5)

for dataset_dir in (train_dir, validation_dir):
    if not os.path.isdir(dataset_dir):
        st.error(f"Dataset directory not found: {dataset_dir}")
        st.stop()

# flow_from_directory treats each sub-directory as one class.
num_classes = _count_class_dirs(train_dir)
class_mode = 'binary' if num_classes <= 2 else 'categorical'

# Data augmentation and preprocessing
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
)
validation_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=IMG_SIZE,
    batch_size=batch_size,
    class_mode=class_mode,
)
validation_generator = validation_datagen.flow_from_directory(
    validation_dir,
    target_size=IMG_SIZE,
    batch_size=batch_size,
    class_mode=class_mode,
)

if train_generator.samples == 0 or validation_generator.samples == 0:
    st.error("No images found in the training or validation directory.")
    st.stop()

# Train the model
if st.button("Train Model"):
    # Built only on demand: loading VGG16 on every rerun (every slider move)
    # is expensive.
    model = _build_model(num_classes)

    # model.summary() prints to the server's stdout; capture it for the app.
    # Extra args are accepted because some Keras versions pass keywords.
    summary_lines = []
    model.summary(print_fn=lambda line, *_args, **_kwargs: summary_lines.append(line))
    st.text("\n".join(summary_lines))

    with st.spinner("Training the model..."):
        # steps_per_epoch / validation_steps are omitted so Keras uses
        # len(generator) (ceil division). The previous floor division could
        # yield 0 steps when a split had fewer images than one batch.
        history = model.fit(
            train_generator,
            epochs=epochs,
            validation_data=validation_generator,
        )

    # Persist across reruns so evaluation and the curves use this model.
    st.session_state['model'] = model
    st.session_state['history'] = history.history
    st.success("Model training completed!")

# Display training curves (kept visible after later interactions)
trained_history = st.session_state.get('history')
if trained_history is not None:
    st.subheader("Training and Validation Accuracy")
    _plot_history(trained_history, 'accuracy', 'Accuracy')

    st.subheader("Training and Validation Loss")
    _plot_history(trained_history, 'loss', 'Loss')

# Evaluate the model
if st.button("Evaluate Model"):
    trained_model = st.session_state.get('model')
    if trained_model is None:
        st.warning("Train the model first.")
    else:
        val_loss, val_acc = trained_model.evaluate(validation_generator, verbose=2)
        st.write(f"Validation accuracy: {val_acc}")