"""Zero-shot image classification with CLIP via Hugging Face Transformers.

Downloads an image from the web, scores it against a fixed set of text
labels with openai/clip-vit-large-patch14, and prints the best match.
"""

# Third-party imports (grouped; torch and matplotlib are used further down).
import matplotlib.pyplot as plt
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification

# Load the pre-trained CLIP model and its matching processor.
# The processor handles both image preprocessing and text tokenization,
# so it must come from the same checkpoint as the model.
checkpoint = "openai/clip-vit-large-patch14"
model = AutoModelForZeroShotImageClassification.from_pretrained(checkpoint)
processor = AutoProcessor.from_pretrained(checkpoint)
# Download and display the image.
# timeout= keeps the script from hanging indefinitely on a dead connection,
# and raise_for_status() surfaces HTTP errors (404, 500, ...) instead of
# silently handing an error page to PIL as if it were image bytes.
url = "https://cdn.akamai.steamstatic.com/steam/apps/1026420/header.jpg?t=1657716289"
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)
plt.imshow(image)
plt.show()
# Candidate class names for zero-shot classification; CLIP scores the
# image against each of these text prompts.
candidate_labels = [
    "tree",
    "car",
    "bike",
    "cat",
]

# Tokenize the labels and preprocess the image into model-ready tensors.
inputs = processor(
    text=candidate_labels,
    images=image,
    return_tensors="pt",
    padding=True,
)
# Run inference. torch.no_grad() disables gradient tracking, which is
# unnecessary at inference time and would otherwise waste time and memory
# building an autograd graph.
with torch.no_grad():
    outputs = model(**inputs)

# logits_per_image has shape [num_images, num_labels]; softmax over the
# label dimension (dim=1) converts similarity scores into probabilities.
logits = outputs.logits_per_image
probs = logits.softmax(dim=1)

# Report the highest-probability label for the (single) input image.
predicted_class_idx = probs.argmax(-1).item()
predicted_class = candidate_labels[predicted_class_idx]
print(f'Predicted class: {predicted_class}')