Image Classification
TF-Keras
ariG23498 committed
Commit e86f0df · 1 Parent(s): dd0a668

feat: add the process of using the pre-trained model


referring to https://github.com/huggingface/huggingface_hub/issues/595

Files changed (1)
  1. README.md +50 -0
README.md CHANGED
@@ -20,3 +20,53 @@ The main ideas are:
 
 - Shifted Patch Tokenization
 - Locality Self Attention
+
+# Use the pre-trained model
+
+The model is pre-trained on the CIFAR100 dataset with the following hyperparameters:
+```python
+# DATA
+NUM_CLASSES = 100
+INPUT_SHAPE = (32, 32, 3)
+BUFFER_SIZE = 512
+BATCH_SIZE = 256
+
+# AUGMENTATION
+IMAGE_SIZE = 72
+PATCH_SIZE = 6
+NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2
+
+# OPTIMIZER
+LEARNING_RATE = 0.001
+WEIGHT_DECAY = 0.0001
+
+# TRAINING
+EPOCHS = 50
+
+# ARCHITECTURE
+LAYER_NORM_EPS = 1e-6
+TRANSFORMER_LAYERS = 8
+PROJECTION_DIM = 64
+NUM_HEADS = 4
+TRANSFORMER_UNITS = [
+    PROJECTION_DIM * 2,
+    PROJECTION_DIM,
+]
+MLP_HEAD_UNITS = [
+    2048,
+    1024,
+]
+```
+I have used the `AdamW` optimizer with a cosine decay learning rate schedule. You can find the entire implementation in the Keras blog post.
+
+To use the pre-trained model:
+```python
+from huggingface_hub import from_pretrained_keras
+
+loaded_model = from_pretrained_keras("keras-io/vit-small-ds")
+_, accuracy, top_5_accuracy = loaded_model.evaluate(test_ds)
+print(f"Test accuracy: {round(accuracy * 100, 2)}%")
+print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")
+```
+
+For an in-depth understanding of the model uploading and downloading process, refer to this [colab notebook](https://colab.research.google.com/drive/1nCMhefqySzG2p8wyXhmeAX5urddQXt49?usp=sharing).
+
+Important: The data augmentation pipeline is excluded from the model. TensorFlow `2.7` has a serialization issue with the augmentation pipeline. You can follow [this GitHub issue](https://github.com/huggingface/huggingface_hub/issues/593) for more updates. To send images through the model, use the `tf.data` `map` API to apply the augmentation.
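Since the augmentation pipeline is not bundled with the saved model, it has to be re-applied to the input pipeline before calling `evaluate`. The sketch below shows the general `tf.data` `map` pattern; the `Resizing`/`Rescaling` stack is only a hypothetical stand-in, as the actual augmentation layers used during training are defined in the Keras blog post:

```python
import tensorflow as tf

IMAGE_SIZE = 72  # matches the training hyperparameters above

# Hypothetical stand-in for the excluded augmentation pipeline; at inference
# time a deterministic resize + rescale is usually all that is needed.
preprocess = tf.keras.Sequential([
    tf.keras.layers.Resizing(IMAGE_SIZE, IMAGE_SIZE),
    tf.keras.layers.Rescaling(1.0 / 255),
])

# Dummy CIFAR-shaped batch standing in for the real test split.
images = tf.zeros((8, 32, 32, 3), dtype=tf.uint8)
labels = tf.zeros((8,), dtype=tf.int64)

# Map the preprocessing over (image, label) pairs with the tf.data API.
ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(4)
ds = ds.map(lambda x, y: (preprocess(x), y))

batch_x, batch_y = next(iter(ds))
print(batch_x.shape)  # (4, 72, 72, 3)
```

The resulting dataset can then be passed to `loaded_model.evaluate(ds)` in place of the raw test split.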