File size: 2,376 Bytes
77cb8f3 1b9c6ed 77cb8f3 1b9c6ed d183073 f0c3a98 d183073 1b9c6ed 77cb8f3 1b9c6ed 77cb8f3 1b9c6ed 3a5138a 77cb8f3 1b9c6ed 77cb8f3 da1e06d 77cb8f3 1b9c6ed 77cb8f3 1b9c6ed 77cb8f3 1b9c6ed |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
import gradio as gr
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
from PIL import Image
import numpy as np
import torch
# Load the pretrained classifier and its matching preprocessor from the Hugging Face Hub.
model_name = "imjeffhi/pokemon_classifier"
model = AutoModelForImageClassification.from_pretrained(model_name)
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)

# Class-index -> Pokémon-name lookup, ordered by class index.
# This used to be a hard-coded list of three names ('Jolteon', 'Kakuna', 'Mr. Mime'),
# which did not match the model: almost every predicted index fell outside it and was
# reported as "Unknown". The model's config stores the index-to-name mapping the
# classifier was trained with, so build the list from that instead.
labels = [model.config.id2label[idx] for idx in range(model.config.num_labels)]
def preprocess_image(img_pil):
    """Convert an uploaded image into the tensors the classifier expects.

    Args:
        img_pil: The image received from the Gradio ``Image`` input. Despite the
            name, with ``gr.Image``'s default settings this is presumably a numpy
            array rather than a PIL image (TODO confirm); the feature extractor is
            given it unchanged either way.

    Returns:
        Whatever ``feature_extractor`` produces with ``return_tensors="pt"``
        (PyTorch tensors), suitable for unpacking as ``model(**inputs)``.
    """
    return feature_extractor(return_tensors="pt", images=img_pil)
def predict_classification(img_pil):
    """Classify one image and return the top class with its probability.

    Args:
        img_pil: A single image, passed unchanged to ``preprocess_image``.

    Returns:
        ``(predicted_class, confidence)``:
        ``predicted_class`` is the name of the highest-scoring class, or
        ``"Unknown"`` if no name is known for that index.
        ``confidence`` is that class's softmax probability, as a Python float.
    """
    inputs = preprocess_image(img_pil)
    with torch.no_grad():  # inference only; no autograd bookkeeping needed
        logits = model(**inputs).logits

    # One image per call, so the argmax over the class axis holds a single index.
    predicted_class_idx = logits.argmax(-1).item()

    # Prefer the index->name mapping stored with the model itself. Fall back to the
    # module-level ``labels`` list, and finally to "Unknown" when neither covers the
    # index.
    id2label = getattr(model.config, "id2label", None) or {}
    predicted_class = id2label.get(predicted_class_idx)
    if predicted_class is None:
        predicted_class = (
            labels[predicted_class_idx]
            if predicted_class_idx < len(labels)
            else "Unknown"
        )

    # Normalise over the same (last) axis used for the argmax above.
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    confidence = probabilities[0, predicted_class_idx].item()
    return predicted_class, confidence
def gradio_predict(img_pil):
    """Gradio callback: classify the uploaded image and format the result.

    Args:
        img_pil: Value of the ``gr.Image`` input. Gradio passes ``None`` when the
            user presses Submit without uploading an image.

    Returns:
        A single human-readable line for the output ``Textbox``.
    """
    # Without this guard, a missing upload reaches the feature extractor and the
    # request fails with an exception instead of telling the user what to do.
    if img_pil is None:
        return "Please upload an image of a Pokemon first."
    predicted_class, confidence = predict_classification(img_pil)
    return f"Predicted class: {predicted_class}, Confidence: {confidence:.4f}"
# Build the web UI: one image upload in, one text box out, flagging disabled.
input_image = gr.Image(label="Upload an image of a Pokemon")
output_text = gr.Textbox(label="Predicted Class and Confidence")

iface = gr.Interface(
    fn=gradio_predict,
    inputs=input_image,
    outputs=output_text,
    title="Pokemon Classifier",
    description=(
        "Upload an image of a Pokemon and the classifier will tell you which one "
        "it is and the confidence level of the prediction."
    ),
    allow_flagging="never",
)

# Start the Gradio server.
iface.launch()