import streamlit as st
import tensorflow as tf
from transformers import ViTFeatureExtractor, TFAutoModelForImageClassification
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt

# Load the dataset (80/20 train/validation split of the TFDS "cats_vs_dogs" train split)
dataset_name = "cats_vs_dogs"
(ds_train, ds_val), ds_info = tfds.load(
    dataset_name,
    split=['train[:80%]', 'train[80%:]'],
    with_info=True,
    as_supervised=True,
)

# Preprocess the dataset
def preprocess_image(image, label):
    image = tf.image.resize(image, (224, 224))  # ViT expects 224x224 inputs
    image = image / 255.0
    image = (image - 0.5) / 0.5  # match the ViT checkpoint's normalization (mean=0.5, std=0.5)
    return image, label

# Streamlit app
st.title("Transfer Learning with Vision Transformer for Image Classification")

# Input parameters
batch_size = st.slider("Batch Size", 16, 128, 32, 16)
epochs = st.slider("Epochs", 5, 50, 10, 5)

# Batch and prefetch with the batch size selected in the slider above
ds_train = ds_train.map(preprocess_image).batch(batch_size).prefetch(tf.data.AUTOTUNE)
ds_val = ds_val.map(preprocess_image).batch(batch_size).prefetch(tf.data.AUTOTUNE)

# Load the pre-trained Vision Transformer model
model_name = "google/vit-base-patch16-224-in21k"
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)  # ViTImageProcessor in newer transformers releases
base_model = TFAutoModelForImageClassification.from_pretrained(model_name, num_labels=2)  # Cats vs Dogs has 2 classes

# Freeze the base model
base_model.trainable = False

# Helper: preprocess raw (eager) images with the Hugging Face feature extractor,
# e.g. for one-off predictions. It resizes, rescales and normalizes by itself,
# so it takes raw uint8 images; it cannot be traced on symbolic Keras tensors,
# which is why the tf.data pipeline above does its own preprocessing.
def extract_features(images):
    inputs = feature_extractor(images, return_tensors="tf")
    return inputs["pixel_values"]  # channels-first: (batch, 3, 224, 224)

# Add custom layers on top of the frozen ViT backbone
inputs = tf.keras.Input(shape=(224, 224, 3))
# Hugging Face TF ViT layers expect channels-first pixel values (batch, 3, 224, 224)
vit_inputs = tf.keras.layers.Permute((3, 1, 2))(inputs)
x = base_model.vit(vit_inputs).last_hidden_state[:, 0]  # [CLS] token embedding
x = tf.keras.layers.Dense(256, activation='relu')(x)
x = tf.keras.layers.Dropout(0.5)(x)
outputs = tf.keras.layers.Dense(1, activation='sigmoid')(x)
model = tf.keras.Model(inputs, outputs)
model.summary()

# Compile the model
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',  # change the loss (and output layer) for more than two classes
    metrics=['accuracy'],
)

# Train the model
if st.button("Train Model"):
    with st.spinner("Training the model..."):
        history = model.fit(
            ds_train,
            epochs=epochs,
            validation_data=ds_val,
        )
    # Streamlit reruns the whole script on every interaction, so keep the
    # fitted model in session state for the evaluation button below
    st.session_state["trained_model"] = model
    st.success("Model training completed!")

    # Display training curves
    st.subheader("Training and Validation Accuracy")
    fig, ax = plt.subplots()
    ax.plot(history.history['accuracy'], label='Training Accuracy')
    ax.plot(history.history['val_accuracy'], label='Validation Accuracy')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Accuracy')
    ax.legend()
    st.pyplot(fig)

    st.subheader("Training and Validation Loss")
    fig, ax = plt.subplots()
    ax.plot(history.history['loss'], label='Training Loss')
    ax.plot(history.history['val_loss'], label='Validation Loss')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.legend()
    st.pyplot(fig)

# Evaluate the model
if st.button("Evaluate Model"):
    # Use the model trained in this session if available; otherwise the freshly
    # initialized (untrained) model is evaluated
    eval_model = st.session_state.get("trained_model", model)
    test_loss, test_acc = eval_model.evaluate(ds_val, verbose=2)
    st.write(f"Validation accuracy: {test_acc:.4f}")
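
# Usage note (a minimal sketch; "app.py" below is an assumed filename, not one
# given in the original script). Streamlit apps are started from the command line:
#
#   streamlit run app.py
#
# The first run downloads the cats_vs_dogs dataset and the ViT checkpoint, and
# training on the full split is slow without a GPU, so a small epoch count is a
# sensible first experiment.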