eaglelandsonce commited on
Commit
ef623be
1 Parent(s): 516b01e

Update pages/15_TransferLearning_HF.py

Browse files
Files changed (1) hide show
  1. pages/15_TransferLearning_HF.py +10 -5
pages/15_TransferLearning_HF.py CHANGED
@@ -29,14 +29,20 @@ model_name = "google/vit-base-patch16-224-in21k"
29
  feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
30
  base_model = TFAutoModelForImageClassification.from_pretrained(model_name, num_labels=2) # Cats vs Dogs has 2 classes
31
 
32
- # Freeze the convolutional base
33
  base_model.trainable = False
34
 
 
 
 
 
 
 
 
35
  # Add custom layers on top
36
  inputs = tf.keras.Input(shape=(224, 224, 3))
37
- features = feature_extractor(inputs, return_tensors="tf")["pixel_values"]
38
- base_output = base_model(features)[0] # Extract base model output
39
- x = tf.keras.layers.Flatten()(base_output)
40
  x = tf.keras.layers.Dense(256, activation='relu')(x)
41
  x = tf.keras.layers.Dropout(0.5)(x)
42
  outputs = tf.keras.layers.Dense(1, activation='sigmoid')(x)
@@ -83,4 +89,3 @@ if st.button("Train Model"):
83
  if st.button("Evaluate Model"):
84
  test_loss, test_acc = model.evaluate(ds_val, verbose=2)
85
  st.write(f"Validation accuracy: {test_acc}")
86
-
 
29
  feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
30
  base_model = TFAutoModelForImageClassification.from_pretrained(model_name, num_labels=2) # Cats vs Dogs has 2 classes
31
 
32
+ # Freeze the base model
33
  base_model.trainable = False
34
 
35
+ # Function to extract features using the feature extractor
36
+ def extract_features(images):
37
+ # Convert images to the expected format for the feature extractor
38
+ images = [tf.image.convert_image_dtype(image, tf.float32) for image in images]
39
+ inputs = feature_extractor(images, return_tensors="tf")
40
+ return inputs["pixel_values"]
41
+
42
  # Add custom layers on top
43
  inputs = tf.keras.Input(shape=(224, 224, 3))
44
+ features = extract_features([inputs])
45
+ x = base_model.vit(inputs).last_hidden_state[:, 0]
 
46
  x = tf.keras.layers.Dense(256, activation='relu')(x)
47
  x = tf.keras.layers.Dropout(0.5)(x)
48
  outputs = tf.keras.layers.Dense(1, activation='sigmoid')(x)
 
89
  if st.button("Evaluate Model"):
90
  test_loss, test_acc = model.evaluate(ds_val, verbose=2)
91
  st.write(f"Validation accuracy: {test_acc}")