mansesa3 commited on
Commit
d281558
·
verified ·
1 Parent(s): 93ccd32

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -18
app.py CHANGED
@@ -1,40 +1,50 @@
1
- import gradio as gr
2
  import tensorflow as tf
3
  import numpy as np
4
  from PIL import Image
5
 
6
- #!pip install tensorflow tensorflow-datasets gradio pillow matplotlib
7
-
8
  model_path = "pokemon-model_transferlearning.keras"
9
  model = tf.keras.models.load_model(model_path)
10
 
11
  # Define the core prediction function
12
  def predict_pokemon(image):
13
  # Preprocess image
14
- image = Image.fromarray(image.astype('uint8')) # Convert numpy array to PIL image
15
- image = image.resize((150, 150)) # Resize the image to 150x150
16
  image = np.array(image)
17
- image = np.expand_dims(image, axis=0) # Add batch dimension
18
 
19
  # Predict
20
  prediction = model.predict(image)
21
 
22
  # Apply softmax to get probabilities for each class
23
- probabilities = tf.nn.softmax(prediction)
24
 
25
  # Map probabilities to Pokemon classes
26
  pokemon_classes = ['Articuno', 'Bulbasaur', 'Charmander']
27
- probabilities_dict = {pokemon_class: round(float(probability), 2) for pokemon_class, probability in zip(pokemon_classes, probabilities[0])}
28
 
29
  return probabilities_dict
30
 
31
- # Create the Gradio interface
32
- input_image = gr.Image()
33
- iface = gr.Interface(
34
- fn=predict_pokemon,
35
- inputs=input_image,
36
- outputs=gr.Label(),
37
- live=True,
38
- examples=["images/01.jpg", "images/02.png", "images/03.png", "images/04.jpg", "images/06.png", "images/06.png"],
39
- description="A simple mlp classification model for image classification using the mnist dataset.")
40
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
  import tensorflow as tf
3
  import numpy as np
4
  from PIL import Image
5
 
6
+ # Load the trained model
 
7
  model_path = "pokemon-model_transferlearning.keras"
8
  model = tf.keras.models.load_model(model_path)
9
 
10
  # Define the core prediction function
11
  def predict_pokemon(image):
12
  # Preprocess image
13
+ image = image.resize((150, 150)) # Resize the image to 150x150
 
14
  image = np.array(image)
15
+ image = np.expand_dims(image, axis=0) # Add batch dimension
16
 
17
  # Predict
18
  prediction = model.predict(image)
19
 
20
  # Apply softmax to get probabilities for each class
21
+ probabilities = tf.nn.softmax(prediction, axis=1)
22
 
23
  # Map probabilities to Pokemon classes
24
  pokemon_classes = ['Articuno', 'Bulbasaur', 'Charmander']
25
+ probabilities_dict = {pokemon_class: round(float(probability), 2) for pokemon_class, probability in zip(pokemon_classes, probabilities.numpy()[0])}
26
 
27
  return probabilities_dict
28
 
29
+ # Streamlit interface
30
+ st.title("Pokemon Classifier")
31
+ st.write("A simple MLP classification model for image classification using a pretrained model.")
32
+
33
+ # Upload image
34
+ uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "png"])
35
+
36
+ if uploaded_image is not None:
37
+ image = Image.open(uploaded_image)
38
+ st.image(image, caption='Uploaded Image.', use_column_width=True)
39
+ st.write("")
40
+ st.write("Classifying...")
41
+
42
+ predictions = predict_pokemon(image)
43
+
44
+ st.write(predictions)
45
+
46
+ # Example images
47
+ st.sidebar.title("Examples")
48
+ example_images = ["images/01.jpg", "images/02.png", "images/03.png", "images/04.jpg", "images/05.png", "images/06.png"]
49
+ for example_image in example_images:
50
+ st.sidebar.image(example_image, use_column_width=True)