nielsr HF staff commited on
Commit
8140d7d
1 Parent(s): 5abdadb

Add FLAX code example

Browse files
Files changed (1) hide show
  1. README.md +21 -2
README.md CHANGED
@@ -28,22 +28,41 @@ fine-tuned versions on a task that interests you.
28
 
29
  ### How to use
30
 
31
- Here is how to use this model:
32
 
33
  ```python
34
  from transformers import ViTFeatureExtractor, ViTModel
35
  from PIL import Image
36
  import requests
 
37
  url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
38
  image = Image.open(requests.get(url, stream=True).raw)
 
39
  feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
40
  model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
41
  inputs = feature_extractor(images=image, return_tensors="pt")
 
42
  outputs = model(**inputs)
43
  last_hidden_states = outputs.last_hidden_state
44
  ```
45
 
46
- Currently, both the feature extractor and model support PyTorch. Tensorflow and JAX/FLAX are coming soon, and the API of ViTFeatureExtractor might change.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
  ## Training data
49
 
 
28
 
29
  ### How to use
30
 
31
+ Here is how to use this model in PyTorch:
32
 
33
  ```python
34
  from transformers import ViTFeatureExtractor, ViTModel
35
  from PIL import Image
36
  import requests
37
+
38
  url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
39
  image = Image.open(requests.get(url, stream=True).raw)
40
+
41
  feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
42
  model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
43
  inputs = feature_extractor(images=image, return_tensors="pt")
44
+
45
  outputs = model(**inputs)
46
  last_hidden_states = outputs.last_hidden_state
47
  ```
48
 
49
+ Here is how to use this model in JAX/Flax:
50
+
51
+ ```python
52
+ from transformers import ViTFeatureExtractor, FlaxViTModel
53
+ from PIL import Image
54
+ import requests
55
+
56
+ url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
57
+ image = Image.open(requests.get(url, stream=True).raw)
58
+
59
+ feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
60
+ model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
61
+
62
+ inputs = feature_extractor(images=image, return_tensors="np")
63
+ outputs = model(**inputs)
64
+ last_hidden_states = outputs.last_hidden_state
65
+ ```
66
 
67
  ## Training data
68