# Spaces status: Sleeping
# Zero-shot image classification demo: download one image, score it against a
# small set of text labels with a CLIP checkpoint, and print the best label.

# Import necessary libraries
import matplotlib.pyplot as plt
import requests
import torch
from PIL import Image
from transformers import AutoModelForZeroShotImageClassification, AutoProcessor

# Load the pre-trained model and processor
checkpoint = "openai/clip-vit-large-patch14"
model = AutoModelForZeroShotImageClassification.from_pretrained(checkpoint)
processor = AutoProcessor.from_pretrained(checkpoint)

# Load and display the image
url = "https://cdn.akamai.steamstatic.com/steam/apps/1026420/header.jpg?t=1657716289"
# The timeout keeps the script from hanging forever on a stalled download, and
# the context manager releases the streamed connection once the image is read.
with requests.get(url, stream=True, timeout=30) as response:
    # Fail loudly on an HTTP error instead of handing an error page to PIL.
    response.raise_for_status()
    # Ask the raw stream to undo any gzip/deflate content encoding on read.
    response.raw.decode_content = True
    image = Image.open(response.raw)
    # Image.open can defer decoding; force it while the response is still open.
    image.load()
plt.imshow(image)
plt.show()

# Specify candidate labels for zero-shot classification
candidate_labels = ["tree", "car", "bike", "cat"]

# Prepare inputs for the model
inputs = processor(text=candidate_labels, images=image, return_tensors="pt", padding=True)

# Make predictions; this is inference only, so skip autograd bookkeeping.
with torch.no_grad():
    outputs = model(**inputs)
logits = outputs.logits_per_image  # shape: [batch_size, num_classes]
probs = logits.softmax(dim=1)  # Convert to probabilities

# Get and print the most likely class
# .item() requires a single-element tensor, i.e. exactly one input image.
predicted_class_idx = probs.argmax(-1).item()
predicted_class = candidate_labels[predicted_class_idx]
print(f'Predicted class: {predicted_class}')