mnist / app.py
alibayram's picture
Refactor predict function: add debug prints for non-zero pixel values and update input sketchpad to grayscale mode
69bd373
raw
history blame
2.47 kB
import numpy as np
import gradio as gr
import tensorflow as tf
import cv2
# Heading shown at the top of the page
title = "Welcome to your first sketch recognition app!"

# HTML description rendered under the title (fragments are joined verbatim)
head = "".join(
    [
        "<center>",
        "<img src='./mnist-classes.png' width=400>",
        "<p>The model is trained to classify numbers (from 0 to 9). ",
        "To test it, draw your number in the space provided.</p>",
        "</center>",
    ]
)

# Markdown footer pointing at the source repository
ref = "Find the complete code [here](https://github.com/ovh/ai-training-examples/tree/main/apps/gradio/sketch-recognition)."

# Map each class index (0-9) to its English name; the position in the tuple
# is the class index.
labels = dict(
    enumerate(
        (
            "zero",
            "one",
            "two",
            "three",
            "four",
            "five",
            "six",
            "seven",
            "eight",
            "nine",
        )
    )
)
# Restore the Keras classifier (trained on the MNIST dataset) from its HDF5
# file once at import time so every prediction reuses the same model object.
model = tf.keras.models.load_model(
    filepath="./sketch_recognition_numbers_model.h5",
)
def predict(data):
    """Classify a sketched digit and return the three most likely classes.

    Args:
        data: Value delivered by the sketchpad input: a mapping whose
            ``'composite'`` entry holds the drawn image. The image may be
            2-D (grayscale) or 3-D with 3 (RGB) or 4 (RGBA) channels.
            ``None`` (or a ``None`` composite) means nothing was drawn.

    Returns:
        dict: class name -> score (plain ``float``) for the up-to-three
        highest scoring classes, or an empty dict when there is no image.
    """
    # Nothing drawn yet (or the canvas was cleared): there is no image to
    # classify, so return an empty result instead of crashing.
    if data is None or data.get("composite") is None:
        return {}

    # Convert to NumPy array
    img = np.array(data["composite"])

    # Vectorised debug summary. A per-pixel Python loop over img.shape[2]
    # would raise IndexError for 2-D grayscale input and is extremely slow.
    print("non-zero values", np.count_nonzero(img))
    print("non-zero pixel values", np.unique(img[img != 0]))
    print("img.shape", img.shape)

    # Collapse colour channels. Only inspect the last axis when the image
    # really has a channel axis; on a 2-D image shape[-1] is the width.
    if img.ndim == 3:
        if img.shape[-1] == 4:  # RGBA
            img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
        if img.shape[-1] == 3:  # RGB
            img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

    # NOTE(review): MNIST digits are light strokes on a dark background;
    # if the sketchpad yields dark strokes on a light canvas the image may
    # need inverting here - confirm against the trained model.

    # Resize image to 28x28
    img = cv2.resize(img, (28, 28))
    # Normalize pixel values to [0, 1]
    img = img / 255.0
    # Reshape to match model input: (batch, height, width, channels)
    img = img.reshape(1, 28, 28, 1)
    print("img", img)

    # Model predictions for the single image in the batch
    preds = np.asarray(model.predict(img)[0])
    print("preds", preds)

    # Rank by class index rather than by score value: keying a lookup on the
    # score loses classes whenever two of them share the same probability.
    # The stable sort keeps the lower class index first on ties.
    top_indices = np.argsort(-preds, kind="stable")[:3]

    labels_map = {}
    for class_index in top_indices:
        class_index = int(class_index)
        # Plain Python floats keep the result JSON-serialisable.
        score = float(preds[class_index])
        print("sorted_values[i]", score, class_index)
        labels_map[labels[class_index]] = score

    print("labels_map", labels_map)
    return labels_map
# Output widget: shows the three best-scoring classes
label = gr.Label(num_top_classes=3)

# Input widget: a grayscale drawing canvas that hands NumPy data to predict()
sketchpad = gr.Sketchpad(type='numpy', image_mode='L', brush=gr.Brush())

# Wire the canvas, the classifier and the label into one Gradio interface
interface = gr.Interface(
    fn=predict,
    inputs=sketchpad,
    outputs=label,
    title=title,
    description=head,
    article=ref,
)

# Start the web app (share=True requests a public link)
interface.launch(share=True)