File size: 2,184 Bytes
d281558 db2ade3 3529fc7 c4706c1 c241385 d281558 d180ba7 1b5445d 3529fc7 1b5445d d281558 c4706c1 1b5445d d281558 1b5445d d281558 1b5445d e23da42 b78f591 1b5445d 3529fc7 d281558 19acbc1 d281558 c4706c1 d281558 f1a1c4b 8ff623d d281558 f1a1c4b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
import streamlit as st
import tensorflow as tf
import numpy as np
from PIL import Image
import pandas as pd
import matplotlib.pyplot as plt
# Load the trained model.
# Path to the fine-tuned Keras model file, resolved relative to the working directory.
model_path = "pokemon-model_transferlearning1.keras"


@st.cache_resource
def _load_model(path):
    """Load the Keras model stored at *path*, caching it across Streamlit reruns.

    Streamlit re-executes this script on every user interaction; without the
    cache the (slow) ``load_model`` call would be repeated each time.
    """
    return tf.keras.models.load_model(path)


model = _load_model(model_path)
# Define the core prediction function.
# Class labels, in the same order as the model's output units.
CLASS_NAMES = ['Chansey', 'Growlithe', 'Lapras']


def predict_pokemon(image, keras_model=None):
    """Classify a PIL image and return per-class probabilities.

    The image is converted to 3-channel RGB, resized to 150x150, wrapped in a
    batch of one and passed to the model. The raw model output is turned into
    probabilities with a softmax and mapped onto ``CLASS_NAMES``.

    Args:
        image: A PIL ``Image`` in any mode Pillow can convert to RGB.
        keras_model: Optional object exposing ``predict``; defaults to the
            module-level ``model``. Mainly useful for testing.

    Returns:
        dict mapping each class name to its probability, rounded to 2 decimals.

    Raises:
        ValueError: if the model's output width does not match ``CLASS_NAMES``.
    """
    if keras_model is None:
        keras_model = model

    # Convert before resizing: Pillow resizes "P" (palette) and "1" images with
    # nearest-neighbour only, so resizing first degrades e.g. palette PNGs.
    rgb_image = image.convert('RGB').resize((150, 150))
    # NOTE(review): pixels are fed as raw 0-255 values; confirm the model
    # normalizes internally (e.g. a Rescaling layer) or was trained on this range.
    batch = np.expand_dims(np.array(rgb_image), axis=0)  # (1, 150, 150, 3)

    # verbose=0 keeps Keras from printing a progress bar for every request.
    scores = np.asarray(keras_model.predict(batch, verbose=0), dtype=np.float64)[0]
    if scores.shape != (len(CLASS_NAMES),):
        # zip() would otherwise silently drop or misalign classes.
        raise ValueError(
            f"model returned {scores.shape} scores, expected {len(CLASS_NAMES)} "
            f"for classes {CLASS_NAMES}"
        )

    # NOTE(review): if the model's final layer already applies softmax, this
    # second softmax flattens the distribution toward uniform - confirm against
    # the training code.
    exp_scores = np.exp(scores - scores.max())  # shift by max for numerical stability
    probabilities = exp_scores / exp_scores.sum()

    return {
        pokemon_class: round(float(probability), 2)
        for pokemon_class, probability in zip(CLASS_NAMES, probabilities)
    }
# Streamlit interface.
# Streamlit re-runs this script from the top on every widget interaction, so
# the statements below execute once per rerun.
st.title("Pokemon Classifier")
st.write("Welches Pokemon hast du ausgewählt?")

# Upload image
uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "png"])
if uploaded_image is not None:
    try:
        image = Image.open(uploaded_image)
        # Image.open only parses the header; load() forces a full decode so a
        # truncated/corrupt upload fails here instead of inside predict_pokemon.
        image.load()
    except OSError:
        # Pillow's UnidentifiedImageError is a subclass of OSError.
        st.error("The uploaded file could not be read as an image.")
    else:
        st.image(image, caption='Uploaded Image.', use_column_width=True)
        st.write("")
        st.write("Classifying...")
        predictions = predict_pokemon(image)

        # Display predictions as a DataFrame
        st.write("### Prediction Probabilities")
        # list(): build from a concrete sequence rather than a dict_items view.
        df = pd.DataFrame(list(predictions.items()), columns=["Pokemon", "Probability"])
        st.dataframe(df)

        # Display predictions as a bar chart
        st.write("### Prediction Chart")
        fig, ax = plt.subplots()
        ax.barh(df["Pokemon"], df["Probability"], color='skyblue')
        ax.set_xlim(0, 1)
        ax.set_xlabel('Probability')
        ax.set_title('Prediction Probabilities')
        st.pyplot(fig)
        # pyplot keeps every figure alive until it is closed; without this,
        # each rerun leaks a figure in the long-running server process.
        plt.close(fig)
# Sidebar: a static gallery of sample images that ship alongside the app.
st.sidebar.title("Examples")
example_images = [
    "pokemon/00000000.png",
    "pokemon/00000001.png",
    "pokemon/00000002.png",
]
for example_image in example_images:
    st.sidebar.image(example_image, use_column_width=True)
|