Spaces:
Sleeping
Sleeping
eaglelandsonce
commited on
Commit
•
ef623be
1
Parent(s):
516b01e
Update pages/15_TransferLearning_HF.py
Browse files
pages/15_TransferLearning_HF.py
CHANGED
@@ -29,14 +29,20 @@ model_name = "google/vit-base-patch16-224-in21k"
|
|
29 |
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
|
30 |
base_model = TFAutoModelForImageClassification.from_pretrained(model_name, num_labels=2) # Cats vs Dogs has 2 classes
|
31 |
|
32 |
-
# Freeze the
|
33 |
base_model.trainable = False
|
34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
# Add custom layers on top
|
36 |
inputs = tf.keras.Input(shape=(224, 224, 3))
|
37 |
-
features =
|
38 |
-
|
39 |
-
x = tf.keras.layers.Flatten()(base_output)
|
40 |
x = tf.keras.layers.Dense(256, activation='relu')(x)
|
41 |
x = tf.keras.layers.Dropout(0.5)(x)
|
42 |
outputs = tf.keras.layers.Dense(1, activation='sigmoid')(x)
|
@@ -83,4 +89,3 @@ if st.button("Train Model"):
|
|
83 |
if st.button("Evaluate Model"):
|
84 |
test_loss, test_acc = model.evaluate(ds_val, verbose=2)
|
85 |
st.write(f"Validation accuracy: {test_acc}")
|
86 |
-
|
|
|
29 |
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
|
30 |
base_model = TFAutoModelForImageClassification.from_pretrained(model_name, num_labels=2) # Cats vs Dogs has 2 classes
|
31 |
|
32 |
+
# Freeze the base model
|
33 |
base_model.trainable = False
|
34 |
|
35 |
+
# Function to extract features using the feature extractor
|
36 |
+
def extract_features(images):
|
37 |
+
# Convert images to the expected format for the feature extractor
|
38 |
+
images = [tf.image.convert_image_dtype(image, tf.float32) for image in images]
|
39 |
+
inputs = feature_extractor(images, return_tensors="tf")
|
40 |
+
return inputs["pixel_values"]
|
41 |
+
|
42 |
# Add custom layers on top
|
43 |
inputs = tf.keras.Input(shape=(224, 224, 3))
|
44 |
+
features = extract_features([inputs])
|
45 |
+
x = base_model.vit(inputs).last_hidden_state[:, 0]
|
|
|
46 |
x = tf.keras.layers.Dense(256, activation='relu')(x)
|
47 |
x = tf.keras.layers.Dropout(0.5)(x)
|
48 |
outputs = tf.keras.layers.Dense(1, activation='sigmoid')(x)
|
|
|
89 |
if st.button("Evaluate Model"):
|
90 |
test_loss, test_acc = model.evaluate(ds_val, verbose=2)
|
91 |
st.write(f"Validation accuracy: {test_acc}")
|
|