File size: 2,466 Bytes
051f92c a49ba8d d4b4b25 cf0b1f5 a49ba8d e6fdb4c a49ba8d e6fdb4c e1b26df e6fdb4c d4b4b25 e1b26df e6fdb4c e1b26df d4b4b25 e6fdb4c 6c3c8f8 e6fdb4c c40b85e cf0b1f5 676005c 69bd373 676005c cf0b1f5 676005c cf0b1f5 892b132 d6bce58 6c3c8f8 363a74f 6c3c8f8 d6bce58 6c3c8f8 cf0b1f5 6c3c8f8 051f92c e6fdb4c c40b85e 051f92c 69bd373 e6fdb4c 051f92c c40b85e cf0b1f5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
import numpy as np
import gradio as gr
import tensorflow as tf
import cv2
# Heading displayed at the top of the app
title = "Welcome to your first sketch recognition app!"

# HTML description shown under the title: class overview image + usage hint
head = "".join(
    [
        "<center>",
        "<img src='./mnist-classes.png' width=400>",
        "<p>The model is trained to classify numbers (from 0 to 9). ",
        "To test it, draw your number in the space provided.</p>",
        "</center>",
    ]
)

# Markdown footer pointing to the GitHub repository
ref = (
    "Find the complete code "
    "[here](https://github.com/ovh/ai-training-examples/tree/main/apps/gradio/sketch-recognition)."
)

# Human-readable class names, keyed by the digit (0 to 9) they stand for
labels = dict(
    enumerate(
        (
            "zero",
            "one",
            "two",
            "three",
            "four",
            "five",
            "six",
            "seven",
            "eight",
            "nine",
        )
    )
)
# Restore the digit classifier (trained on the MNIST dataset) from its HDF5 file
model = tf.keras.models.load_model(
    "./sketch_recognition_numbers_model.h5",
)
def predict(data):
    """Classify a sketched digit and return its most likely classes.

    Args:
        data: Sketchpad payload; a mapping whose ``'composite'`` entry holds
            the drawn image as a pixel array (2-D grayscale, or 3-D with
            1, 3 (RGB) or 4 (RGBA) channels).

    Returns:
        A dict mapping class name (from ``labels``) to its predicted score
        as a plain Python ``float``, for the (up to) three best-scoring
        classes in descending score order. An empty dict is returned when
        there is no image to classify.
    """
    # NOTE(review): presumably Gradio may hand over None / an empty composite
    # when nothing was drawn - confirm; either way there is nothing to predict.
    if data is None:
        return {}
    composite = data["composite"]
    if composite is None:
        return {}

    # Convert to NumPy array
    img = np.asarray(composite)
    if img.size == 0:
        return {}

    # Cheap diagnostics only: a per-pixel dump is extremely slow and fails
    # outright on 2-D (grayscale) input, which has no channel axis.
    print("non-zero values", np.count_nonzero(img))
    print("img.shape", img.shape)

    # Collapse any channel axis to a single grayscale plane. The rank check
    # matters: on a 2-D image ``shape[-1]`` is the width, not a channel count.
    if img.ndim == 3:
        channels = img.shape[-1]
        if channels == 4:  # RGBA
            img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
            channels = 3
        if channels == 3:  # RGB
            img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        elif channels == 1:  # already grayscale, just drop the channel axis
            img = img[..., 0]

    # Resize image to 28x28
    img = cv2.resize(img, (28, 28))

    # Normalize pixel values to [0, 1]
    img = img / 255.0

    # Reshape to match model input: (batch, height, width, channels)
    img = img.reshape(1, 28, 28, 1)

    # Model predictions (first and only item of the batch)
    preds = np.asarray(model.predict(img)[0])
    print("preds", preds)

    # Rank classes by index rather than by score value, so that tied scores
    # cannot overwrite each other; the stable sort keeps the lower class
    # index first among ties.
    top_classes = np.argsort(-preds, kind="stable")[:3]

    # Plain floats keep the result serializable for the Label output.
    labels_map = {
        labels[int(class_id)]: float(preds[class_id])
        for class_id in top_classes
    }
    print("labels_map", labels_map)
    return labels_map
# Output widget: display only the three highest-scoring classes
label = gr.Label(num_top_classes=3)

# Build the Gradio interface for sketch recognition: the sketchpad drawing is
# passed to predict() and its result is rendered by the label widget
interface = gr.Interface(
    fn=predict,
    inputs=gr.Sketchpad(type="numpy", image_mode="L", brush=gr.Brush()),
    outputs=label,
    title=title,
    description=head,
    article=ref,
)

# Start the app; share=True additionally requests a public share link
interface.launch(share=True)