mansesa3 commited on
Commit
1b5445d
·
verified ·
1 Parent(s): c71507b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -24
app.py CHANGED
@@ -1,32 +1,40 @@
1
- import os
2
- os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
3
  import gradio as gr
4
  import tensorflow as tf
5
  import numpy as np
6
  from PIL import Image
7
 
 
8
 
9
- # Modell laden
10
- model = tf.keras.models.load_model('pokemon_classifier_model.h5')
11
 
12
- class_names = ['Chansey', 'Growlithe', 'Lapras']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- def predict(image):
15
- image = image.resize((150, 150))
16
- img_array = tf.keras.preprocessing.image.img_to_array(image)
17
- img_array = np.expand_dims(img_array, axis=0)
18
- predictions = model.predict(img_array)
19
- score = tf.nn.softmax(predictions[0])
20
- return {class_names[i]: float(score[i]) for i in range(3)}
21
-
22
- # Gradio Interface
23
- interface = gr.Interface(
24
- fn=predict,
25
- inputs=gr.inputs.Image(shape=(150, 150)),
26
- outputs=gr.outputs.Label(num_top_classes=3),
27
- title="Pokémon Classifier",
28
- description="Upload an image of Chansey, Growlithe, or Lapras"
29
- )
30
-
31
- if __name__ == "__main__":
32
- interface.launch()
 
 
 
1
  import gradio as gr
2
  import tensorflow as tf
3
  import numpy as np
4
  from PIL import Image
5
 
6
+ #!pip install tensorflow tensorflow-datasets gradio pillow matplotlib
7
 
8
+ model_path = "pokemon-model_transferlearning.keras"
9
+ model = tf.keras.models.load_model(model_path)
10
 
11
+ # Define the core prediction function
12
+ def predict_pokemon(image):
13
+ # Preprocess image
14
+ image = Image.fromarray(image.astype('uint8')) # Convert numpy array to PIL image
15
+ image = image.resize((150, 150)) # Resize the image to 150x150
16
+ image = np.array(image)
17
+ image = np.expand_dims(image, axis=0) # Add batch dimension
18
+
19
+ # Predict
20
+ prediction = model.predict(image)
21
+
22
+ # Apply softmax to get probabilities for each class
23
+ probabilities = tf.nn.softmax(prediction)
24
+
25
+ # Map probabilities to Pokemon classes
26
+ pokemon_classes = ['Articuno', 'Bulbasaur', 'Charmander']
27
+ probabilities_dict = {pokemon_class: round(float(probability), 2) for pokemon_class, probability in zip(pokemon_classes, probabilities[0])}
28
+
29
+ return probabilities_dict
30
 
31
+ # Create the Gradio interface
32
+ input_image = gr.Image()
33
+ iface = gr.Interface(
34
+ fn=predict_pokemon,
35
+ inputs=input_image,
36
+ outputs=gr.Label(),
37
+ live=True,
38
+ examples=["images/01.jpg", "images/02.png", "images/03.png", "images/04.jpg", "images/06.png", "images/06.png"],
39
+ description="A simple mlp classification model for image classification using the mnist dataset.")
40
+ iface.launch()