File size: 2,466 Bytes
051f92c
a49ba8d
d4b4b25
cf0b1f5
a49ba8d
e6fdb4c
 
a49ba8d
e6fdb4c
e1b26df
e6fdb4c
 
 
 
 
d4b4b25
 
e1b26df
e6fdb4c
e1b26df
d4b4b25
e6fdb4c
6c3c8f8
 
 
 
 
 
 
 
 
 
 
 
e6fdb4c
 
c40b85e
cf0b1f5
 
676005c
 
69bd373
 
 
 
 
 
 
 
676005c
cf0b1f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
676005c
 
cf0b1f5
892b132
d6bce58
6c3c8f8
 
363a74f
6c3c8f8
d6bce58
6c3c8f8
 
 
 
cf0b1f5
6c3c8f8
 
051f92c
e6fdb4c
 
 
 
c40b85e
051f92c
69bd373
e6fdb4c
051f92c
 
c40b85e
 
cf0b1f5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import numpy as np
import gradio as gr
import tensorflow as tf
import cv2

# App title shown at the top of the interface
title = "Welcome to your first sketch recognition app!"

# App description: an HTML snippet passed to the interface as its description
head = "".join(
    [
        "<center>",
        "<img src='./mnist-classes.png' width=400>",
        "<p>The model is trained to classify numbers (from 0 to 9). ",
        "To test it, draw your number in the space provided.</p>",
        "</center>",
    ]
)

# GitHub repository link (Markdown), passed to the interface as its article
ref = (
    "Find the complete code "
    "[here](https://github.com/ovh/ai-training-examples/tree/main/apps/gradio/sketch-recognition)."
)

# Class names keyed by the model's output index (0 to 9)
labels = dict(
    enumerate(
        ("zero", "one", "two", "three", "four",
         "five", "six", "seven", "eight", "nine")
    )
)
# Load the Keras digit classifier (trained on the MNIST dataset) once at
# startup so every prediction request reuses the same model instance.
model = tf.keras.models.load_model(filepath="./sketch_recognition_numbers_model.h5")

def _to_grayscale(img):
    """Collapse a 2-D, 1-channel, RGB, or RGBA image array to 2-D grayscale."""
    if img.ndim == 2:
        # Already single-channel (e.g. the Sketchpad with image_mode='L').
        return img
    if img.ndim != 3:
        raise ValueError(f"Unsupported sketch array shape: {img.shape}")
    channels = img.shape[-1]
    if channels == 1:
        return img[..., 0]
    if channels == 4:
        img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
        channels = 3
    if channels == 3:
        return cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    raise ValueError(f"Unsupported number of channels: {channels}")


def _top_k_labels(preds, names, k=3):
    """Map the names of the ``k`` highest-scoring classes to their scores.

    Ties are broken by class index, and every class appears at most once,
    even when several classes share the same score.
    """
    k = min(k, len(preds))
    # Stable sort on the negated scores: descending order, ties by index.
    top = np.argsort(-np.asarray(preds), kind="stable")[:k]
    return {names[int(i)]: float(preds[i]) for i in top}


def predict(data):
    """Classify a hand-drawn digit from the Gradio sketchpad.

    Args:
        data: The sketchpad value; the drawing is read from its
            ``"composite"`` entry. ``None`` (or a ``None`` composite) is
            treated as nothing to classify.

    Returns:
        A dict mapping the names of the three highest-scoring classes to
        their model scores (highest first), or an empty dict when there is
        no image.
    """
    if data is None or data["composite"] is None:
        return {}

    img = _to_grayscale(np.asarray(data["composite"]))

    # Match the model input: a batch of one 28x28 single-channel image
    # with pixel values scaled to [0, 1].
    img = cv2.resize(img, (28, 28))
    img = img.astype(np.float32) / 255.0
    img = img.reshape(1, 28, 28, 1)

    # verbose=0 suppresses Keras' per-call progress bar.
    preds = model.predict(img, verbose=0)[0]
    return _top_k_labels(preds, labels, 3)

# Output component showing the three most probable classes from `predict`
label = gr.Label(num_top_classes=3)

# Gradio interface for sketch recognition
interface = gr.Interface(
    fn=predict,
    inputs=gr.Sketchpad(type='numpy', image_mode='L', brush=gr.Brush()),
    outputs=label,
    title=title,
    description=head,
    article=ref,
)

# Launch (blocking) only when executed as a script, so importing this
# module does not start a server as a side effect.
if __name__ == "__main__":
    interface.launch(share=True)