File size: 1,388 Bytes
4c5bd22
a49ba8d
d4b4b25
cf0b1f5
a49ba8d
4c5bd22
 
a49ba8d
4c5bd22
e1b26df
4c5bd22
4747b05
4c5bd22
 
d4b4b25
 
e1b26df
4c5bd22
e1b26df
4c5bd22
 
d4b4b25
4c5bd22
 
c40b85e
4c5bd22
4747b05
676005c
4c5bd22
 
69bd373
4c5bd22
 
 
cf0b1f5
4c5bd22
 
cf0b1f5
4c5bd22
 
cf0b1f5
4c5bd22
 
cf0b1f5
4c5bd22
 
cf0b1f5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# import dependencies
import gradio as gr
import tensorflow as tf
import cv2

# title handed to the Gradio interface
title = "Welcome on your first sketch recognition app!"

# HTML description handed to the Gradio interface: a centered picture
# followed by a short explanation of what the app does
head = "".join([
  "<center>",
  "<img src='./mnist-classes.png' width=400>",
  "The robot was trained to classify numbers (from 0 to 9). To test it, write your number in the space provided.",
  "</center>",
])

# Markdown sentence linking to the GitHub repository of the example
ref = "Find the whole code [here](https://github.com/ovh/ai-training-examples/tree/main/apps/gradio/sketch-recognition)."

# side length of the square image given to the model (28x28)
img_size = 28

# class names from 0 to 9: the name at index i stands for digit i
labels = "zero one two three four five six seven eight nine".split()

# load the saved Keras model once, at import time (trained on MNIST dataset);
# the path is relative, so the .h5 file is looked up from the current working directory
model = tf.keras.models.load_model("./sketch_recognition_numbers_model.h5")

# prediction function for sketch recognition:
# takes the image drawn by the user and returns a dict {class name: probability}
def predict(img):
  # resize the sketch to img_size x img_size, then add a batch axis and a channel axis
  # (shape 1 x 28 x 28 x 1); the reshape only succeeds when the resized image
  # holds exactly img_size * img_size values
  # NOTE(review): pixel values are passed to the model as received, without rescaling -
  # confirm this matches how the model was trained
  batch = cv2.resize(img, (img_size, img_size)).reshape(1, img_size, img_size, 1)

  # the model answers with one row of scores per image: keep the single row
  scores = model.predict(batch)[0]

  # pair each class name with its score, converted to a plain Python float
  return dict(zip(labels, map(float, scores)))

# output component: only the 3 most probable classes are displayed
label = gr.outputs.Label(num_top_classes=3)

# build the Gradio interface for sketch recognition:
# a sketchpad as input, the predict function in the middle, the label component as output
interface = gr.Interface(
  fn=predict,
  inputs="sketchpad",
  outputs=label,
  title=title,
  description=head,
  article=ref,
)

# start the app; share=True asks Gradio to also create a public share link
interface.launch(share=True)