# Scraped file-view header (revision 9506f43, 465 bytes); not part of the program.
import torch
from transformers import ViTFeatureExtractor
from config import UNTRAINED
# Shared image preprocessor, created once when this module is imported, from
# the name/path held in config.UNTRAINED (presumably a Hub model id or a local
# checkpoint directory; confirm). `predict` uses it to build model inputs.
# NOTE(review): recent transformers releases deprecate ViTFeatureExtractor in
# favour of ViTImageProcessor; confirm against the pinned transformers version.
feature_extractor = ViTFeatureExtractor.from_pretrained(UNTRAINED)
def predict(model, image, extractor=None):
    """Classify a single image and return the label of the top-scoring class.

    Args:
        model: A classification model, presumably a Hugging Face
            ``*ForImageClassification`` model. It must accept the
            extractor's outputs as keyword arguments, return an object with
            a ``.logits`` tensor, and expose ``config.id2label``.
        image: One image in a form the extractor accepts (e.g. a PIL image
            or array; confirm against callers).
        extractor: Optional preprocessor to use instead of the module-level
            ``feature_extractor``. It is called as
            ``extractor(image, return_tensors="pt")``.

    Returns:
        The entry of ``model.config.id2label`` for the arg-max class index
        (normally the human-readable label string).

    Raises:
        KeyError: If the predicted index is missing from ``id2label`` under
            both its int and its string form.

    Notes:
        Only one image is supported: the ``.item()`` call raises if the
        logits contain more than one prediction. This function does not call
        ``model.eval()``. Callers are presumably expected to have put the
        model in inference mode already.
    """
    if extractor is None:
        extractor = feature_extractor
    inputs = extractor(image, return_tensors="pt")

    # Move the inputs to the model's device (e.g. CUDA) so the forward pass
    # does not fail with a device mismatch. Models or input containers
    # without this API are left untouched.
    device = getattr(model, "device", None)
    if device is not None and hasattr(inputs, "to"):
        inputs = inputs.to(device)

    with torch.no_grad():
        logits = model(**inputs).logits

    # Index of the highest-scoring class along the last (class) dimension.
    predicted_label = logits.argmax(-1).item()

    id2label = model.config.id2label
    # transformers' PretrainedConfig converts id2label keys to int, so try the
    # int index first. Fall back to the string key for mappings that kept
    # JSON-style string keys.
    if predicted_label in id2label:
        return id2label[predicted_label]
    return id2label[str(predicted_label)]