mansesa3 commited on
Commit
b78f591
·
verified ·
1 Parent(s): e23da42

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -11,6 +11,7 @@ model = tf.keras.models.load_model(model_path)
11
  def predict_pokemon(image):
12
  # Preprocess image
13
  image = image.resize((150, 150)) # Resize the image to 150x150
 
14
  image = np.array(image)
15
  image = np.expand_dims(image, axis=0) # Add batch dimension
16
 
@@ -22,13 +23,13 @@ def predict_pokemon(image):
22
 
23
  # Map probabilities to Pokemon classes
24
  class_names = ['Chansey', 'Growlithe', 'Lapras']
25
- probabilities_dict = {pokemon_class: round(float(probability), 2) for pokemon_class, probability in zip(pokemon_classes, probabilities.numpy()[0])}
26
 
27
  return probabilities_dict
28
 
29
  # Streamlit interface
30
  st.title("Pokemon Classifier")
31
- st.write("A simple MLP classification model for image classification using a pretrained model.")
32
 
33
  # Upload image
34
  uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "png"])
@@ -48,3 +49,4 @@ st.sidebar.title("Examples")
48
  example_images = ["pokemon/train/chansey/00000000.png", "pokemon/train/growlithe/00000000.png", "pokemon/train/lapras/00000000.png"]
49
  for example_image in example_images:
50
  st.sidebar.image(example_image, use_column_width=True)
 
 
11
  def predict_pokemon(image):
12
  # Preprocess image
13
  image = image.resize((150, 150)) # Resize the image to 150x150
14
+ image = image.convert('RGB') # Ensure the image is in RGB format
15
  image = np.array(image)
16
  image = np.expand_dims(image, axis=0) # Add batch dimension
17
 
 
23
 
24
  # Map probabilities to Pokemon classes
25
  class_names = ['Chansey', 'Growlithe', 'Lapras']
26
+ probabilities_dict = {pokemon_class: round(float(probability), 2) for pokemon_class, probability in zip(class_names, probabilities.numpy()[0])}
27
 
28
  return probabilities_dict
29
 
30
  # Streamlit interface
31
  st.title("Pokemon Classifier")
32
+ st.write("Eine KI die Pokemons identifiziert :) Viel Spass")
33
 
34
  # Upload image
35
  uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "png"])
 
49
  example_images = ["pokemon/train/chansey/00000000.png", "pokemon/train/growlithe/00000000.png", "pokemon/train/lapras/00000000.png"]
50
  for example_image in example_images:
51
  st.sidebar.image(example_image, use_column_width=True)
52
+