File size: 2,376 Bytes
77cb8f3
1b9c6ed
 
 
77cb8f3
1b9c6ed
 
 
 
 
 
d183073
f0c3a98
d183073
 
 
1b9c6ed
 
 
77cb8f3
 
1b9c6ed
 
 
77cb8f3
 
 
 
1b9c6ed
 
3a5138a
 
 
 
 
 
 
77cb8f3
1b9c6ed
 
77cb8f3
da1e06d
77cb8f3
 
 
1b9c6ed
 
 
 
 
77cb8f3
1b9c6ed
 
 
77cb8f3
 
1b9c6ed
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import gradio as gr
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
from PIL import Image
import numpy as np
import torch

# Hub repository id of the fine-tuned Pokémon image classifier.
model_name = "imjeffhi/pokemon_classifier"

# Load the classification model and its matching preprocessor from the same
# repository id, so the preprocessing always agrees with the checkpoint.
model, feature_extractor = (
    AutoModelForImageClassification.from_pretrained(model_name),
    AutoFeatureExtractor.from_pretrained(model_name),
)

# Mapping from the model's output class index to a Pokémon name.
# The checkpoint carries its own index -> name mapping in `model.config.id2label`,
# so the names are read from there instead of being hard-coded. A hand-written
# list (previously only three names) does not line up with the model's class
# indices: most predictions fell past its end and were reported as "Unknown",
# and the few that landed inside it could get the wrong name.
# Indexing by range(num_labels) keeps labels[i] aligned with logit column i and
# fails loudly at startup if the mapping has gaps.
labels = [model.config.id2label[i] for i in range(model.config.num_labels)]

# Function to preprocess the image
def preprocess_image(img_pil):
    inputs = feature_extractor(images=img_pil, return_tensors="pt")
    return inputs

def predict_classification(img_pil):
    """Classify one image and report how confident the model is.

    Args:
        img_pil: The image to classify; handed unchanged to ``preprocess_image``.
            Only a single image is supported, because ``.item()`` below needs
            the arg-max tensor to hold exactly one element.

    Returns:
        A ``(predicted_class, confidence)`` tuple. ``predicted_class`` is the
        entry of ``labels`` at the arg-max logit index, or ``"Unknown"`` when
        that index lies beyond the end of ``labels``. ``confidence`` is the
        softmax probability at that index, as a NumPy scalar.
    """
    model_inputs = preprocess_image(img_pil)

    # Inference only: disable gradient tracking.
    with torch.no_grad():
        logits = model(**model_inputs).logits

    best_idx = logits.argmax(-1).item()

    # Guard against the model predicting an index that `labels` does not cover.
    predicted_class = labels[best_idx] if best_idx < len(labels) else "Unknown"

    probabilities = torch.nn.functional.softmax(logits, dim=1).numpy()
    confidence = probabilities[0][best_idx]
    return predicted_class, confidence

def gradio_predict(img_pil):
    """Gradio callback: classify an image and format the result for display.

    Args:
        img_pil: The image delivered by the Gradio image input, or ``None``
            when the user submits without providing an image.

    Returns:
        ``"Predicted class: <name>, Confidence: <p>"`` with the confidence
        shown to four decimal places, or a prompt asking for an image when
        none was provided.
    """
    # An empty image input arrives as None. Without this guard it would be
    # passed into preprocessing and surface as an opaque error in the UI.
    if img_pil is None:
        return "Please upload an image of a Pokemon."
    predicted_class, confidence = predict_classification(img_pil)
    return f"Predicted class: {predicted_class}, Confidence: {confidence:.4f}"

# Create Gradio interface.
# type="pil" makes the image component hand the callback a PIL image. This
# matches the `img_pil` contract of gradio_predict / predict_classification /
# preprocess_image; gr.Image's default type is "numpy".
input_image = gr.Image(type="pil", label="Upload an image of a Pokemon")
output_text = gr.Textbox(label="Predicted Class and Confidence")

iface = gr.Interface(
    fn=gradio_predict,
    inputs=input_image,
    outputs=output_text,
    title="Pokemon Classifier",
    description="Upload an image of a Pokemon and the classifier will tell you which one it is and the confidence level of the prediction.",
    # NOTE(review): newer Gradio releases rename `allow_flagging` to
    # `flagging_mode`; confirm against the pinned Gradio version.
    allow_flagging="never"
)

# Start the web app (blocks while the server runs).
iface.launch()