File size: 2,184 Bytes
d281558
db2ade3
 
3529fc7
c4706c1
 
c241385
d281558
d180ba7
1b5445d
3529fc7
1b5445d
 
 
d281558
c4706c1
1b5445d
d281558
1b5445d
 
 
 
 
d281558
1b5445d
 
e23da42
b78f591
1b5445d
 
3529fc7
d281558
 
19acbc1
d281558
 
 
 
 
 
 
 
 
 
 
 
c4706c1
 
 
 
 
 
 
 
 
 
 
 
 
d281558
 
f1a1c4b
8ff623d
d281558
f1a1c4b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import streamlit as st
import tensorflow as tf
import numpy as np
from PIL import Image
import pandas as pd
import matplotlib.pyplot as plt

# Load the trained model
model_path = "pokemon-model_transferlearning1.keras"


@st.cache_resource
def _load_model(path):
    """Load the Keras model stored at *path* and cache the loaded object.

    Streamlit re-executes this script from top to bottom on every widget
    interaction. Without caching, each rerun would deserialize the model
    file again; ``st.cache_resource`` keeps one loaded instance and hands
    it back on subsequent reruns.
    """
    return tf.keras.models.load_model(path)


model = _load_model(model_path)

# Define the core prediction function
def predict_pokemon(image):
    """Classify a Pokemon image and return one probability per class.

    Args:
        image: A PIL image (any object providing ``convert`` and ``resize``
            with PIL semantics).

    Returns:
        A dict mapping each class name ('Chansey', 'Growlithe', 'Lapras')
        to its softmax probability, rounded to two decimals.

    Raises:
        ValueError: If the model does not return exactly one score per
            class name.
    """
    class_names = ['Chansey', 'Growlithe', 'Lapras']

    # Convert to RGB *before* resizing. Pillow always resamples palette
    # ("P") and bilevel ("1") images with nearest-neighbour, so resizing
    # first would degrade e.g. palette PNG uploads; after conversion the
    # regular resampling filter is used.
    rgb_image = image.convert('RGB')
    resized = rgb_image.resize((150, 150))

    # Add the batch dimension: (150, 150, 3) -> (1, 150, 150, 3).
    batch = np.expand_dims(np.array(resized), axis=0)

    # Predict
    prediction = model.predict(batch)

    # Apply softmax to get probabilities for each class.
    # NOTE(review): this treats the model output as raw scores. If the
    # model's last layer already applies softmax, this second softmax
    # flattens the distribution -- confirm against the model definition.
    probabilities = tf.nn.softmax(prediction, axis=1).numpy()[0]

    # zip() would silently drop extra scores or class names; fail loudly
    # instead of reporting probabilities against the wrong labels.
    if len(probabilities) != len(class_names):
        raise ValueError(
            f"Model returned {len(probabilities)} scores, "
            f"expected {len(class_names)}"
        )

    # Map probabilities to Pokemon classes
    return {
        pokemon_class: round(float(probability), 2)
        for pokemon_class, probability in zip(class_names, probabilities)
    }

# Streamlit interface: page header, image upload, and prediction display.
st.title("Pokemon Classifier")
st.write("Welches Pokemon hast du ausgewählt?")

# Upload image
uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "png"])

# Everything below only runs once the user has provided a file.
if uploaded_image is not None:
    image = Image.open(uploaded_image)
    st.image(image, caption='Uploaded Image.', use_column_width=True)
    st.write("")
    st.write("Classifying...")

    predictions = predict_pokemon(image)

    # Display predictions as a DataFrame
    st.write("### Prediction Probabilities")
    df = pd.DataFrame(predictions.items(), columns=["Pokemon", "Probability"])
    st.dataframe(df)

    # Display predictions as a horizontal bar chart on a fixed 0..1 axis
    st.write("### Prediction Chart")
    fig, ax = plt.subplots()
    ax.barh(df["Pokemon"], df["Probability"], color='skyblue')
    ax.set_xlim(0, 1)
    ax.set_xlabel('Probability')
    ax.set_title('Prediction Probabilities')
    st.pyplot(fig)
    # pyplot keeps every figure it creates registered until it is closed.
    # This script is re-executed on each interaction, so release the figure
    # after rendering to avoid accumulating open figures in the server.
    plt.close(fig)

# Sidebar gallery: show the bundled sample images under an "Examples" title.
example_images = [
    "pokemon/00000000.png",
    "pokemon/00000001.png",
    "pokemon/00000002.png",
]
sidebar = st.sidebar
sidebar.title("Examples")
for example_image in example_images:
    # Each sample is scaled to the sidebar column width.
    sidebar.image(example_image, use_column_width=True)