skshmjn commited on
Commit
35c2b13
·
verified ·
1 Parent(s): a0320ae

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +3 -3
README.md CHANGED
@@ -7,18 +7,18 @@ A fine-tuned version of ViT-base on a collected set of Pokémon images. You can
7
  # Using the model
8
 
9
  ```python
10
- from transformers import ViTForImageClassification, ViTFeatureExtractor
11
  from PIL import Image
12
  import torch
13
 
14
  # Loading in Model
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
  model = ViTForImageClassification.from_pretrained( "imjeffhi/pokemon_classifier").to(device)
17
- feature_extractor = ViTFeatureExtractor.from_pretrained('imjeffhi/pokemon_classifier')
18
 
19
  # Caling the model on a test image
20
  img = Image.open('test.jpg')
21
- extracted = feature_extractor(images=img, return_tensors='pt').to(device)
22
  predicted_id = model(**extracted).logits.argmax(-1).item()
23
  predicted_pokemon = model.config.id2label[predicted_id]
24
  ```
 
7
  # Using the model
8
 
9
  ```python
10
+ from transformers import ViTForImageClassification, ViTImageProcessor
11
  from PIL import Image
12
  import torch
13
 
14
  # Loading in Model
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
  model = ViTForImageClassification.from_pretrained( "imjeffhi/pokemon_classifier").to(device)
17
+ image_processor = ViTImageProcessor.from_pretrained('imjeffhi/pokemon_classifier')
18
 
19
  # Caling the model on a test image
20
  img = Image.open('test.jpg')
21
+ extracted = image_processor(images=img, return_tensors='pt').to(device)
22
  predicted_id = model(**extracted).logits.argmax(-1).item()
23
  predicted_pokemon = model.config.id2label[predicted_id]
24
  ```