Plai / Simple Image Classification
Queue-Tip's picture
Update Simple Image Classification
ff070ed verified
raw
history blame contribute delete
961 Bytes
# Import necessary libraries
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image
import requests
import torch
import matplotlib.pyplot as plt
# Classify one image fetched over HTTP with a pre-trained Vision Transformer
# and print the predicted label.

# Single source of truth so the extractor and the model always come from the
# same checkpoint.
MODEL_NAME = 'google/vit-base-patch16-224'

# Load pre-trained feature extractor and model
feature_extractor = ViTFeatureExtractor.from_pretrained(MODEL_NAME)
model = ViTForImageClassification.from_pretrained(MODEL_NAME)

# Load and display the image
url = "https://cdn.akamai.steamstatic.com/steam/apps/821880/header.jpg?t=1652241767"
# The timeout keeps the script from hanging forever on a stalled connection;
# the context manager releases the connection once the image is decoded.
with requests.get(url, stream=True, timeout=30) as response:
    # Surface HTTP errors (404, 403, ...) directly instead of letting PIL fail
    # later while trying to decode an error page as an image.
    response.raise_for_status()
    # The raw stream is not content-decoded by default; enable it so a
    # gzip/deflate-encoded response still yields valid image bytes.
    response.raw.decode_content = True
    # convert() forces the full decode while the stream is still open and
    # normalizes grayscale/RGBA/CMYK/palette images to plain RGB.
    image = Image.open(response.raw).convert("RGB")
plt.imshow(image)
plt.show()

# Extract features from the image
inputs = feature_extractor(images=image, return_tensors="pt")

# Make predictions
# Inference only: skip autograd bookkeeping to save memory and time.
with torch.no_grad():
    outputs = model(**inputs)
logits = outputs.logits
# Index of the highest-scoring class along the last (class) dimension.
predicted_class_idx = logits.argmax(-1).item()

# Get and print the predicted class name
predicted_class = model.config.id2label[predicted_class_idx]
print(f'Predicted class: {predicted_class}')