hb-setosys commited on
Commit
3a813cf
·
verified ·
1 Parent(s): e7ea0c7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -13
app.py CHANGED
@@ -3,28 +3,36 @@ import tensorflow as tf
3
  import numpy as np
4
  from PIL import Image
5
 
6
- # Load your model
7
  model = tf.keras.models.load_model("denis_mnist_cnn_model.h5")
8
 
9
- # Define the prediction function
10
  def predict(image):
11
- image = np.array(image) # Convert to numpy array
12
- image = tf.image.resize(image, (224, 224)) # Resize to the model's expected input size
13
- image = np.expand_dims(image, axis=0) # Add batch dimension (model expects a batch of images)
14
- image = image / 255.0 # Normalize pixel values
15
 
16
- # Check if the model needs flattening
17
- if len(image.shape) == 4: # Check if image has a batch dimension
18
- image = tf.keras.layers.Flatten()(image) # Flatten the image if necessary
 
19
 
 
 
 
 
 
 
 
20
  prediction = model.predict(image)
 
 
21
  return {"prediction": prediction.tolist()}
22
 
23
- # Create the Gradio interface
24
  interface = gr.Interface(
25
- fn=predict,
26
- inputs="image", # Image input
27
- outputs="json", # Output as JSON
28
  )
29
 
30
  # Launch the interface
 
3
  import numpy as np
4
  from PIL import Image
5
 
6
+ # Load the model (ensure you have the correct model path)
7
  model = tf.keras.models.load_model("denis_mnist_cnn_model.h5")
8
 
9
+ # Define a function to preprocess input and make predictions
10
  def predict(image):
11
+ # Convert image to a numpy array
12
+ image = np.array(image)
 
 
13
 
14
+ # Resize the image to the expected shape (28, 28, 3)
15
+ image = tf.image.resize(image, (28, 28)) # Resize to 28x28 pixels
16
+ image = np.expand_dims(image, axis=-1) # Add the channel dimension if grayscale
17
+ image = np.repeat(image, 3, axis=-1) # Convert grayscale to RGB (if model was trained on RGB images)
18
 
19
+ # Normalize the image
20
+ image = image / 255.0
21
+
22
+ # Add batch dimension
23
+ image = np.expand_dims(image, axis=0) # Add batch dimension to match the model's expected input shape (1, 28, 28, 3)
24
+
25
+ # Perform prediction
26
  prediction = model.predict(image)
27
+
28
+ # Return prediction as JSON
29
  return {"prediction": prediction.tolist()}
30
 
31
+ # Create a Gradio interface
32
  interface = gr.Interface(
33
+ fn=predict,
34
+ inputs="image", # Image input for testing
35
+ outputs="json" # JSON output for prediction results
36
  )
37
 
38
  # Launch the interface