File size: 465 Bytes
9506f43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch
from transformers import ViTFeatureExtractor
from config import UNTRAINED

# Image preprocessor built once at import time from the model id/path stored in
# config.UNTRAINED; every predict() call reuses this single instance.
feature_extractor = ViTFeatureExtractor.from_pretrained(
    pretrained_model_name_or_path=UNTRAINED,
)

def predict(model, image):
    """Classify one image and return the label name of the top-scoring class.

    Args:
        model: An image-classification model whose forward pass returns an
            object with a ``logits`` attribute and whose ``config.id2label``
            maps class indices to label names.
        image: A single image in a form accepted by the module-level
            ``feature_extractor`` (presumably a PIL image or array -- TODO
            confirm against callers).

    Returns:
        The ``id2label`` entry for the highest-scoring class.

    Raises:
        KeyError: If the predicted index is missing from ``id2label`` under
            both its int and str form.

    Note:
        ``.item()`` below only works when the logits hold a single
        prediction, so pass exactly one image.
    """
    inputs = feature_extractor(image, return_tensors="pt")

    # Place the input tensors on the model's device so GPU-resident models
    # work; models without a ``device`` attribute receive the tensors as-is.
    device = getattr(model, "device", None)
    if device is not None:
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    # no_grad only skips autograd bookkeeping; it does not disable dropout,
    # so callers should put the model in eval mode beforehand.
    with torch.no_grad():
        logits = model(**inputs).logits

    # The label set is whatever model.config.id2label defines (not
    # necessarily ImageNet's 1000 classes).
    predicted_label = logits.argmax(-1).item()

    # transformers' PretrainedConfig stores id2label with int keys; fall back
    # to the str key for configs whose mapping was assigned with string keys.
    id2label = model.config.id2label
    if predicted_label in id2label:
        return id2label[predicted_label]
    return id2label[str(predicted_label)]