"""Gradio demo that classifies an uploaded image with the
``imjeffhi/pokemon_classifier`` Hugging Face model and reports the predicted
Pokémon together with its softmax confidence."""

import gradio as gr
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
from PIL import Image
import numpy as np
import torch

# Load the model and feature extractor
model_name = "imjeffhi/pokemon_classifier"
model = AutoModelForImageClassification.from_pretrained(model_name)
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)

# Class-index -> Pokémon-name mapping shipped with the model's configuration.
# A previous hard-coded three-name list only covered indices 0-2, so almost
# every prediction came out as "Unknown" (and indices 0-2 were mislabeled).
id2label = model.config.id2label
labels = [id2label[idx] for idx in sorted(id2label)]


def _to_rgb_pil(img):
    """Return *img* (a PIL image or a NumPy array) as a 3-channel RGB PIL image."""
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img)
    # RGBA / palette / grayscale uploads would otherwise hand the feature
    # extractor an unexpected number of channels.
    return img.convert("RGB")


def preprocess_image(img_pil):
    """Convert an image into model-ready inputs.

    Args:
        img_pil: PIL image (a NumPy image array is also accepted) to classify.

    Returns:
        The feature extractor's output with PyTorch tensors
        (``return_tensors="pt"``), suitable for ``model(**inputs)``.
    """
    return feature_extractor(images=_to_rgb_pil(img_pil), return_tensors="pt")


def predict_classification(img_pil):
    """Classify a single image.

    Args:
        img_pil: PIL image (or NumPy image array) to classify.

    Returns:
        ``(predicted_class, confidence)``: the label taken from the model's
        ``id2label`` config ("Unknown" if the index has no entry) and the
        softmax probability of that class as a Python float.
    """
    inputs = preprocess_image(img_pil)
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_idx = logits.argmax(-1).item()
    predicted_class = id2label.get(predicted_class_idx, "Unknown")
    # Only one image is passed in, so row 0 holds its class probabilities.
    confidence = torch.softmax(logits, dim=-1)[0, predicted_class_idx].item()
    return predicted_class, confidence


def gradio_predict(img_pil):
    """Gradio callback: run the classifier and format the result as text."""
    if img_pil is None:
        # Gradio passes None when the form is submitted without an image.
        return "Please upload an image."
    predicted_class, confidence = predict_classification(img_pil)
    return f"Predicted class: {predicted_class}, Confidence: {confidence:.4f}"


# Create Gradio interface. type="pil" matches what the functions above expect;
# Gradio's default for gr.Image is a NumPy array.
input_image = gr.Image(type="pil", label="Upload an image of a Pokemon")
output_text = gr.Textbox(label="Predicted Class and Confidence")

iface = gr.Interface(
    fn=gradio_predict,
    inputs=input_image,
    outputs=output_text,
    title="Pokemon Classifier",
    description="Upload an image of a Pokemon and the classifier will tell you which one it is and the confidence level of the prediction.",
    allow_flagging="never",
)

iface.launch()