File size: 1,388 Bytes
4c5bd22 a49ba8d d4b4b25 cf0b1f5 a49ba8d 4c5bd22 a49ba8d 4c5bd22 e1b26df 4c5bd22 4747b05 4c5bd22 d4b4b25 e1b26df 4c5bd22 e1b26df 4c5bd22 d4b4b25 4c5bd22 c40b85e 4c5bd22 4747b05 676005c 4c5bd22 69bd373 4c5bd22 cf0b1f5 4c5bd22 cf0b1f5 4c5bd22 cf0b1f5 4c5bd22 cf0b1f5 4c5bd22 cf0b1f5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
# import dependencies
import gradio as gr
import tensorflow as tf
import cv2
# App title, passed to the Gradio interface as `title`.
title = "Welcome to your first sketch recognition app!"
# App description (HTML), passed to the Gradio interface as `description`:
# an illustration of the digit classes followed by usage instructions.
head = (
    "<center>"
    "<img src='./mnist-classes.png' width=400>"
    "The robot was trained to classify numbers (from 0 to 9). To test it, write your number in the space provided."
    "</center>"
)
# Footer text (Markdown link), passed to the Gradio interface as `article`.
ref = "Find the whole code [here](https://github.com/ovh/ai-training-examples/tree/main/apps/gradio/sketch-recognition)."
# Side length in pixels of the square image fed to the model (28x28).
img_size = 28
# Class names in digit order: labels[d] is the name of digit d (0-9).
labels = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
# Keras classifier (trained on the MNIST dataset), loaded once at import time
# so that every call to predict() reuses the same in-memory model.
# The "./" path is resolved against the process's current working directory,
# so the app has to be started from the folder that holds the .h5 file.
model = tf.keras.models.load_model(filepath="./sketch_recognition_numbers_model.h5")
# Prediction function for sketch recognition, called by Gradio on each submit.
def predict(img):
    """Classify a hand-drawn digit and return a score for every class.

    Args:
        img: The sketchpad drawing as a NumPy array. It must hold a single
            channel (2-D, or with a trailing channel axis of size 1), because
            it is reshaped to one channel below.

    Returns:
        A dict mapping each name in ``labels`` to the model's output for that
        class (presumably a softmax probability; this depends on the model's
        last layer, so confirm against the training code).
    """
    # Bring the drawing to the model's input resolution, then add the batch
    # and channel axes: (img_size, img_size) -> (1, img_size, img_size, 1).
    img = cv2.resize(img, (img_size, img_size))
    img = img.reshape(1, img_size, img_size, 1)
    # NOTE(review): pixel values are passed through unscaled; confirm this
    # matches the preprocessing used when the model was trained.
    # verbose=0 stops Keras from printing a progress bar on every request.
    preds = model.predict(img, verbose=0)[0]
    # float() turns NumPy scalars into plain Python floats for the Label output.
    return {label: float(pred) for label, pred in zip(labels, preds)}
# Output widget: display the 3 highest-scoring classes returned by predict().
# NOTE(review): `gr.outputs` is the legacy Gradio 2.x/3.x namespace (deprecated
# in 3.x, removed in 4.x); pin the gradio version or migrate to `gr.Label`.
label = gr.outputs.Label(num_top_classes=3)
# Gradio interface: each sketchpad drawing is passed to predict(), and the
# returned {class name: score} dict is rendered by the Label output.
interface = gr.Interface(
    fn=predict,
    inputs="sketchpad",
    outputs=label,
    title=title,
    description=head,
    article=ref,
)
# Start the web app; share=True also asks Gradio for a public share link.
interface.launch(share=True)