Refactor sketch recognition app: disable oneDNN optimizations and update model path
Browse files
app.py
CHANGED
@@ -1,4 +1,6 @@
|
|
1 |
-
|
|
|
|
|
2 |
import gradio as gr
|
3 |
import tensorflow as tf
|
4 |
import cv2
|
@@ -24,23 +26,17 @@ img_size = 28
|
|
24 |
labels = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
|
25 |
|
26 |
# load model (trained on MNIST dataset)
|
27 |
-
model = tf.keras.models.load_model("
|
28 |
|
29 |
# prediction function for sketch recognition
|
30 |
def predict(img):
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
# model predictions
|
37 |
-
preds = model.predict(img)[0]
|
38 |
-
|
39 |
-
# return the probability for each classe
|
40 |
-
return {label: float(pred) for label, pred in zip(labels, preds)}
|
41 |
|
42 |
# top 3 of classes
|
43 |
-
label = gr.
|
44 |
|
45 |
# open Gradio interface for sketch recognition
|
46 |
interface = gr.Interface(fn=predict, inputs="sketchpad", outputs=label, title=title, description=head, article=ref)
|
|
|
1 |
+
import os
|
2 |
+
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0" # Disable oneDNN optimizations
|
3 |
+
|
4 |
import gradio as gr
|
5 |
import tensorflow as tf
|
6 |
import cv2
|
|
|
26 |
labels = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
|
27 |
|
28 |
# load model (trained on MNIST dataset)
|
29 |
+
model = tf.keras.models.load_model("model/sketch_recognition_numbers_model.h5")
|
30 |
|
31 |
# prediction function for sketch recognition
|
32 |
def predict(img):
|
33 |
+
img = cv2.resize(img, (img_size, img_size))
|
34 |
+
img = img.reshape(1, img_size, img_size, 1)
|
35 |
+
preds = model.predict(img)[0]
|
36 |
+
return {label: float(pred) for label, pred in zip(labels, preds)}
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
# top 3 of classes
|
39 |
+
label = gr.Label(num_top_classes=3)
|
40 |
|
41 |
# open Gradio interface for sketch recognition
|
42 |
interface = gr.Interface(fn=predict, inputs="sketchpad", outputs=label, title=title, description=head, article=ref)
|