import torch
from transformers import ViTFeatureExtractor

from config import UNTRAINED

# Preprocessor loaded once at import time from the checkpoint named by
# config.UNTRAINED, and shared by every predict() call.
feature_extractor = ViTFeatureExtractor.from_pretrained(UNTRAINED)


def predict(model, image):
    """Classify a single image and return its human-readable label.

    Args:
        model: A transformers image-classification model whose forward pass
            returns an object with a ``logits`` attribute and whose
            ``config`` carries an ``id2label`` mapping.
        image: An image accepted by ``ViTFeatureExtractor`` (presumably a
            PIL image / numpy array / tensor -- confirm against callers).

    Returns:
        The label string from ``model.config.id2label`` for the
        highest-scoring class.

    Raises:
        KeyError: if the predicted class index has no entry in
            ``model.config.id2label``.
    """
    inputs = feature_extractor(image, return_tensors="pt")

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits

    # NOTE(review): the label set depends on the checkpoint behind
    # config.UNTRAINED (possibly the 1000 ImageNet classes) -- confirm.
    # .item() requires a single prediction, i.e. exactly one input image.
    predicted_label = logits.argmax(-1).item()

    id2label = model.config.id2label
    # PretrainedConfig normalizes id2label keys to int, so indexing with
    # str(predicted_label) raised KeyError. Prefer the int key and keep the
    # str key as a fallback for configs that still carry string keys.
    if predicted_label in id2label:
        return id2label[predicted_label]
    return id2label[str(predicted_label)]