File size: 1,009 Bytes
8d28437
 
 
e7ea0c7
8d28437
e7ea0c7
8d28437
 
e7ea0c7
8d28437
e7ea0c7
 
 
8d28437
e7ea0c7
 
 
 
 
8d28437
 
 
e7ea0c7
8d28437
e7ea0c7
 
 
8d28437
 
e7ea0c7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import gradio as gr
import tensorflow as tf
import numpy as np
from PIL import Image

# Restore the trained Keras model once at import time; `predict` reuses this
# module-level object for every request instead of reloading it.
_MODEL_FILE = "denis_mnist_cnn_model.h5"
model = tf.keras.models.load_model(_MODEL_FILE)

# Prediction pipeline: adapt the uploaded image to the model, then run it.
def _match_channels(pixels, channels):
    """Convert an (H, W, C) image to the channel count the model declares.

    Args:
        pixels: Image data whose last axis is the channel axis.
        channels: Channel count taken from the model's input shape, or
            ``None`` when the model does not declare one.

    Returns:
        ``pixels`` unchanged when no conversion applies, otherwise the
        converted image (a TensorFlow tensor for colour conversions).
    """
    current = pixels.shape[-1]
    if channels is None or current == channels:
        return pixels
    if current == 4:
        # Drop the alpha plane first; the colour conversions below only
        # understand 1- or 3-channel data.
        pixels = pixels[..., :3]
        current = 3
    if channels == 1 and current == 3:
        return tf.image.rgb_to_grayscale(pixels)
    if channels == 3 and current == 1:
        return tf.image.grayscale_to_rgb(pixels)
    return pixels


def predict(image):
    """Run the loaded model on one uploaded image and return its raw output.

    The image is resized and channel-adjusted to whatever ``model.input_shape``
    declares, scaled to the 0-1 range and given a leading batch axis. The
    batch is flattened only when the model itself takes a non-image (e.g.
    2-D) input; previously every image was flattened unconditionally, which
    handed a 2-D tensor to models that expect a 4-D image batch.

    Assumes a single-input model with channels-last layout
    ``(batch, height, width, channels)`` -- TODO confirm for other models.

    Args:
        image: The picture supplied by the Gradio image component (anything
            ``np.asarray`` can turn into an ``(H, W)`` or ``(H, W, C)``
            array), or ``None`` when nothing was uploaded.

    Returns:
        A dict with a single key ``"prediction"`` holding the model output
        converted to nested Python lists.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Surface a readable message in the UI instead of an opaque
        # TensorFlow shape/type error.
        raise gr.Error("Please provide an image before submitting.")

    pixels = np.asarray(image, dtype=np.float32)
    if pixels.ndim == 2:
        # tf.image.resize needs an explicit channel axis.
        pixels = pixels[..., np.newaxis]

    input_shape = model.input_shape
    takes_image_tensor = len(input_shape) == 4

    # Historical defaults, kept for models that do not declare a spatial
    # input shape (or declare it with unknown dimensions).
    height, width, channels = 224, 224, None
    if takes_image_tensor:
        declared_height, declared_width, channels = input_shape[1:]
        height = declared_height or height
        width = declared_width or width

    pixels = _match_channels(pixels, channels)
    resized = tf.image.resize(pixels, (height, width))
    batch = np.expand_dims(np.asarray(resized), axis=0) / 255.0

    if not takes_image_tensor:
        # The model takes flat feature vectors: one row per image.
        batch = batch.reshape(1, -1)

    prediction = model.predict(batch)
    return {"prediction": prediction.tolist()}

# Build the web UI: a single image input wired to `predict`, with the
# returned dict rendered as JSON.
interface = gr.Interface(fn=predict, inputs="image", outputs="json")

# Start serving; share=True also asks Gradio for a public share link.
interface.launch(share=True)