File size: 3,237 Bytes
f72b662
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef623be
f72b662
 
ef623be
 
 
 
 
 
 
f72b662
 
ef623be
 
f72b662
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import streamlit as st
import tensorflow as tf
from transformers import ViTFeatureExtractor, TFAutoModelForImageClassification
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt

# Load cats_vs_dogs as (image, label) pairs and carve its "train" split into
# an 80% training slice and a 20% validation slice; ds_info keeps the metadata.
dataset_name = "cats_vs_dogs"
(ds_train, ds_val), ds_info = tfds.load(
    dataset_name,
    split=['train[:80%]', 'train[80%:]'],
    as_supervised=True,
    with_info=True,
)

# Preprocess the dataset
def preprocess_image(image, label):
    """Prepare a single (image, label) example for the ViT backbone.

    The image is resized to 224x224 (the input size the ViT checkpoint is
    built for) and its pixel values are divided by 255, so 8-bit pixels end
    up in [0, 1]. The label is passed through unchanged.
    """
    resized = tf.image.resize(image, (224, 224))
    scaled = resized / 255.0
    return scaled, label

# Streamlit app
st.title("Transfer Learning with Vision Transformer for Image Classification")

# Input parameters. These are read before the input pipeline is built so the
# selected batch size is the one the datasets are batched with.
batch_size = st.slider("Batch Size", 16, 128, 32, 16)
epochs = st.slider("Epochs", 5, 50, 10, 5)

# Resize/rescale every example (in parallel; map keeps element order by
# default), batch with the user-selected size, and prefetch so input
# preparation overlaps with training.
ds_train = (
    ds_train.map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)
ds_val = (
    ds_val.map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

# Load the pre-trained Vision Transformer model
model_name = "google/vit-base-patch16-224-in21k"
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
base_model = TFAutoModelForImageClassification.from_pretrained(model_name, num_labels=2)  # Cats vs Dogs has 2 classes

# Freeze the base model so only the layers added on top are trained
base_model.trainable = False

# Function to extract features using the feature extractor
def extract_features(images):
    """Run the Hugging Face ViT feature extractor over a batch of images.

    Args:
        images: An iterable of eager image tensors or arrays. Integer images
            are scaled to [0, 1] by ``tf.image.convert_image_dtype``; float
            images are passed through and are assumed to already be in [0, 1].

    Returns:
        The ``pixel_values`` tensor produced by ``feature_extractor``.

    Note:
        This works eagerly (it calls ``.numpy()`` and the Hugging Face
        extractor), so it cannot be applied to symbolic ``tf.keras.Input``
        tensors inside a functional model.
    """
    # Bring every image to float32 in [0, 1] and hand NumPy arrays to the
    # extractor.
    images = [tf.image.convert_image_dtype(image, tf.float32).numpy() for image in images]
    # Pixels are already in [0, 1]; disable the extractor's own 1/255
    # rescaling so it is not applied a second time.
    inputs = feature_extractor(images, do_rescale=False, return_tensors="tf")
    return inputs["pixel_values"]

# Add custom layers on top.
# The Hugging Face feature extractor cannot run on symbolic Keras tensors, so
# its mean/std normalization is reproduced in-graph with the extractor's own
# statistics. Inputs are images already scaled to [0, 1] by preprocess_image.
inputs = tf.keras.Input(shape=(224, 224, 3))
x = tf.keras.layers.Normalization(
    mean=feature_extractor.image_mean,
    variance=[std ** 2 for std in feature_extractor.image_std],
)(inputs)
# TF ViT models take pixel_values channels-first: (batch, channels, height, width).
x = tf.keras.layers.Permute((3, 1, 2))(x)
x = base_model.vit(x).last_hidden_state[:, 0]  # first ([CLS]) token of the final hidden state
x = tf.keras.layers.Dense(256, activation='relu')(x)
x = tf.keras.layers.Dropout(0.5)(x)
outputs = tf.keras.layers.Dense(1, activation='sigmoid')(x)
model = tf.keras.Model(inputs, outputs)

model.summary()

# Compile the model
model.compile(optimizer='adam',
              loss='binary_crossentropy',  # single sigmoid output; use a categorical loss for more than 2 classes
              metrics=['accuracy'])


def _plot_metric(history, metric, title, label):
    """Render the training and validation curves of one Keras metric in Streamlit.

    Args:
        history: The ``History`` object returned by ``model.fit``.
        metric: Key in ``history.history``; ``"val_" + metric`` must also exist.
        title: Subheader shown above the chart.
        label: Metric name used in the legend and the y-axis label.
    """
    st.subheader(title)
    fig, ax = plt.subplots()
    ax.plot(history.history[metric], label=f'Training {label}')
    ax.plot(history.history[f'val_{metric}'], label=f'Validation {label}')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(label)
    ax.legend()
    st.pyplot(fig)
    # Streamlit re-executes the script on every interaction; close the figure
    # so pyplot does not accumulate open figures across reruns.
    plt.close(fig)


# Train the model
if st.button("Train Model"):
    with st.spinner("Training the model..."):
        history = model.fit(
            ds_train,
            epochs=epochs,
            validation_data=ds_val
        )
        # Every button click reruns this script and rebuilds `model` with a
        # fresh head; keep the trained one so "Evaluate Model" can use it.
        st.session_state["trained_model"] = model

        st.success("Model training completed!")

    # Display training curves
    _plot_metric(history, 'accuracy', "Training and Validation Accuracy", 'Accuracy')
    _plot_metric(history, 'loss', "Training and Validation Loss", 'Loss')

# Evaluate the model
if st.button("Evaluate Model"):
    trained_model = st.session_state.get("trained_model")
    if trained_model is None:
        st.warning("Train the model first, then evaluate it.")
    else:
        val_loss, val_acc = trained_model.evaluate(ds_val, verbose=2)
        st.write(f"Validation accuracy: {val_acc}")