# Spaces: Sleeping
import matplotlib.pyplot as plt
import streamlit as st
import tensorflow as tf
import tensorflow_datasets as tfds
from transformers import ViTFeatureExtractor, TFAutoModelForImageClassification
# Load the dataset: the first 80% of the TFDS "train" split is used for
# training and the remaining 20% is held out for validation. With
# as_supervised=True each element is an (image, label) pair.
dataset_name = "cats_vs_dogs"
(ds_train, ds_val), ds_info = tfds.load(
    dataset_name,
    as_supervised=True,
    with_info=True,
    split=['train[:80%]', 'train[80%:]'],
)
# Preprocess the dataset | |
def preprocess_image(image, label, mean=0.5, std=0.5):
    """Resize an image to 224x224 and normalize it for the ViT backbone.

    Pixels are scaled from [0, 255] to [0, 1] and then standardized as
    ``(x - mean) / std``. With the defaults this maps pixels to [-1, 1].

    NOTE(review): the defaults are meant to match the per-channel mean/std
    the ViT feature extractor applies for this checkpoint (0.5 / 0.5);
    confirm against ``feature_extractor.image_mean`` / ``image_std``.

    Args:
        image: An image tensor (height, width, channels) from the dataset.
        label: The class label; passed through unchanged.
        mean: Value subtracted after rescaling to [0, 1].
        std: Divisor applied after mean subtraction.

    Returns:
        ``(image, label)`` with ``image`` as float32 of shape (224, 224, C).
    """
    image = tf.image.resize(image, (224, 224))  # ViT requires 224x224 images
    image = tf.cast(image, tf.float32) / 255.0
    # Dividing by 255 alone would leave inputs in [0, 1], which does not
    # match the normalization the pretrained weights were trained with.
    image = (image - mean) / std
    return image, label
# Streamlit app
st.title("Transfer Learning with Vision Transformer for Image Classification")
# Input parameters. Read them BEFORE building the input pipeline so that
# the "Batch Size" slider actually controls how the datasets are batched
# (previously they were batched with a hard-coded 32 and the slider value
# was never used).
batch_size = st.slider("Batch Size", 16, 128, 32, 16)
epochs = st.slider("Epochs", 5, 50, 10, 5)
ds_train = (
    ds_train.map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)
ds_val = (
    ds_val.map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)
# Load the pre-trained Vision Transformer model
model_name = "google/vit-base-patch16-224-in21k"


@st.cache_resource
def _load_pretrained(checkpoint):
    """Load the feature extractor and a frozen ViT model for *checkpoint*.

    Streamlit re-executes this whole script on every widget interaction;
    caching the loaded objects means the checkpoint is loaded once per
    process instead of on every slider move or button click.

    Args:
        checkpoint: Hugging Face model identifier to load.

    Returns:
        ``(feature_extractor, base_model)`` where ``base_model`` has
        ``trainable`` set to False.
    """
    extractor = ViTFeatureExtractor.from_pretrained(checkpoint)
    backbone = TFAutoModelForImageClassification.from_pretrained(
        checkpoint, num_labels=2  # Cats vs Dogs has 2 classes
    )
    # Freeze the base model so that only the custom head is trained.
    backbone.trainable = False
    return extractor, backbone


feature_extractor, base_model = _load_pretrained(model_name)
# Function to extract features using the feature extractor | |
def extract_features(images):
    """Return the feature extractor's ``pixel_values`` for *images*.

    Each element of *images* is converted to float32 with
    ``tf.image.convert_image_dtype``. The resulting list is then passed to
    the module-level ``feature_extractor`` with ``return_tensors="tf"``.

    NOTE(review): the extractor processes concrete image data, so this
    presumably cannot be applied to symbolic Keras tensors inside a model
    graph. If the inputs are already float images in [0, 1], the extractor
    may rescale them a second time; confirm before use.
    """
    float_batch = [
        tf.image.convert_image_dtype(img, tf.float32)
        for img in images
    ]
    encoded = feature_extractor(float_batch, return_tensors="tf")
    pixel_values = encoded["pixel_values"]
    return pixel_values
# Add custom layers on top of the frozen ViT encoder.
def _build_classifier(backbone):
    """Build and compile a binary classifier head on top of ``backbone.vit``.

    The first token of the encoder's last hidden state feeds a
    Dense(256, relu) -> Dropout(0.5) -> Dense(1, sigmoid) head.

    Args:
        backbone: The pretrained model loaded above; only its ``.vit``
            encoder is used.

    Returns:
        A compiled ``tf.keras.Model`` that takes (224, 224, 3) images and
        outputs one sigmoid probability.
    """
    inputs = tf.keras.Input(shape=(224, 224, 3))
    # The Hugging Face TF ViT takes channels-first pixel_values
    # (batch, channels, height, width), while the tf.data pipeline yields
    # channels-last images, so permute (H, W, C) -> (C, H, W).
    pixel_values = tf.keras.layers.Permute((3, 1, 2))(inputs)
    x = backbone.vit(pixel_values).last_hidden_state[:, 0]
    x = tf.keras.layers.Dense(256, activation='relu')(x)
    x = tf.keras.layers.Dropout(0.5)(x)
    outputs = tf.keras.layers.Dense(1, activation='sigmoid')(x)
    classifier = tf.keras.Model(inputs, outputs)
    classifier.summary()
    # A single sigmoid output with 0/1 labels pairs with binary
    # cross-entropy; change the loss if the number of classes changes.
    classifier.compile(optimizer='adam',
                       loss='binary_crossentropy',
                       metrics=['accuracy'])
    return classifier


# Streamlit re-runs the script on every interaction. A model created at
# module level would be rebuilt, discarding trained weights, on each click,
# so "Evaluate Model" would always score an untrained model. Keep one
# model per user session instead.
if "model" not in st.session_state:
    st.session_state["model"] = _build_classifier(base_model)
model = st.session_state["model"]
# Train the model
def _plot_metric(history, metric, title):
    """Plot training vs. validation curves for one Keras history metric.

    Args:
        history: The ``History`` object returned by ``model.fit``.
        metric: Key in ``history.history``; its ``val_`` counterpart is
            plotted alongside it.
        title: Human-readable metric name used in the subheader, legend
            labels and y-axis label.
    """
    st.subheader(f"Training and Validation {title}")
    fig, ax = plt.subplots()
    ax.plot(history.history[metric], label=f'Training {title}')
    ax.plot(history.history[f'val_{metric}'], label=f'Validation {title}')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(title)
    ax.legend()
    st.pyplot(fig)
    # pyplot keeps every figure open until it is closed explicitly. Without
    # this, figures accumulate across Streamlit reruns.
    plt.close(fig)


if st.button("Train Model"):
    with st.spinner("Training the model..."):
        history = model.fit(
            ds_train,
            epochs=epochs,
            validation_data=ds_val
        )
    st.success("Model training completed!")
    # Display training curves
    _plot_metric(history, 'accuracy', 'Accuracy')
    _plot_metric(history, 'loss', 'Loss')
# Evaluate the model on the held-out validation split
if st.button("Evaluate Model"):
    # Evaluation runs the full ViT over every validation batch. Show
    # progress feedback, as the training button does.
    with st.spinner("Evaluating the model..."):
        test_loss, test_acc = model.evaluate(ds_val, verbose=2)
    st.write(f"Validation accuracy: {test_acc}")