mansesa3 commited on
Commit
c4706c1
·
verified ·
1 Parent(s): b78f591

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -4
app.py CHANGED
@@ -2,6 +2,8 @@ import streamlit as st
2
  import tensorflow as tf
3
  import numpy as np
4
  from PIL import Image
 
 
5
 
6
  # Load the trained model
7
  model_path = "pokemon-model_transferlearning1.keras"
@@ -11,7 +13,7 @@ model = tf.keras.models.load_model(model_path)
11
  def predict_pokemon(image):
12
  # Preprocess image
13
  image = image.resize((150, 150)) # Resize the image to 150x150
14
- image = image.convert('RGB') # Ensure the image is in RGB format
15
  image = np.array(image)
16
  image = np.expand_dims(image, axis=0) # Add batch dimension
17
 
@@ -29,7 +31,7 @@ def predict_pokemon(image):
29
 
30
  # Streamlit interface
31
  st.title("Pokemon Classifier")
32
- st.write("Eine KI die Pokemons identifiziert :) Viel Spass")
33
 
34
  # Upload image
35
  uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "png"])
@@ -42,11 +44,22 @@ if uploaded_image is not None:
42
 
43
  predictions = predict_pokemon(image)
44
 
45
- st.write(predictions)
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  # Example images
48
  st.sidebar.title("Examples")
49
  example_images = ["pokemon/train/chansey/00000000.png", "pokemon/train/growlithe/00000000.png", "pokemon/train/lapras/00000000.png"]
50
  for example_image in example_images:
51
  st.sidebar.image(example_image, use_column_width=True)
52
-
 
2
  import tensorflow as tf
3
  import numpy as np
4
  from PIL import Image
5
+ import pandas as pd
6
+ import matplotlib.pyplot as plt
7
 
8
  # Load the trained model
9
  model_path = "pokemon-model_transferlearning1.keras"
 
13
  def predict_pokemon(image):
14
  # Preprocess image
15
  image = image.resize((150, 150)) # Resize the image to 150x150
16
+ image = image.convert('RGB') # Ensure image has 3 channels
17
  image = np.array(image)
18
  image = np.expand_dims(image, axis=0) # Add batch dimension
19
 
 
31
 
32
  # Streamlit interface
33
  st.title("Pokemon Classifier")
34
+ st.write("A simple MLP classification model for image classification using a pretrained model.")
35
 
36
  # Upload image
37
  uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "png"])
 
44
 
45
  predictions = predict_pokemon(image)
46
 
47
+ # Display predictions as a DataFrame
48
+ st.write("### Prediction Probabilities")
49
+ df = pd.DataFrame(predictions.items(), columns=["Pokemon", "Probability"])
50
+ st.dataframe(df)
51
+
52
+ # Display predictions as a bar chart
53
+ st.write("### Prediction Chart")
54
+ fig, ax = plt.subplots()
55
+ ax.barh(df["Pokemon"], df["Probability"], color='skyblue')
56
+ ax.set_xlim(0, 1)
57
+ ax.set_xlabel('Probability')
58
+ ax.set_title('Prediction Probabilities')
59
+ st.pyplot(fig)
60
 
61
  # Example images
62
  st.sidebar.title("Examples")
63
  example_images = ["pokemon/train/chansey/00000000.png", "pokemon/train/growlithe/00000000.png", "pokemon/train/lapras/00000000.png"]
64
  for example_image in example_images:
65
  st.sidebar.image(example_image, use_column_width=True)