alibayram commited on
Commit
5082ec7
·
1 Parent(s): d4b4b25

Refactor sketch recognition app: disable oneDNN optimizations and update model path

Browse files
Files changed (1) hide show
  1. app.py +9 -13
app.py CHANGED
@@ -1,4 +1,6 @@
1
- # import dependencies
 
 
2
  import gradio as gr
3
  import tensorflow as tf
4
  import cv2
@@ -24,23 +26,17 @@ img_size = 28
24
  labels = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
25
 
26
  # load model (trained on MNIST dataset)
27
- model = tf.keras.models.load_model("./sketch_recognition_numbers_model.h5")
28
 
29
  # prediction function for sketch recognition
30
  def predict(img):
31
-
32
- # image shape: 28x28x1
33
- img = cv2.resize(img, (img_size, img_size))
34
- img = img.reshape(1, img_size, img_size, 1)
35
-
36
- # model predictions
37
- preds = model.predict(img)[0]
38
-
39
- # return the probability for each classe
40
- return {label: float(pred) for label, pred in zip(labels, preds)}
41
 
42
  # top 3 of classes
43
- label = gr.outputs.Label(num_top_classes=3)
44
 
45
  # open Gradio interface for sketch recognition
46
  interface = gr.Interface(fn=predict, inputs="sketchpad", outputs=label, title=title, description=head, article=ref)
 
1
+ import os
2
+ os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0" # Disable oneDNN optimizations
3
+
4
  import gradio as gr
5
  import tensorflow as tf
6
  import cv2
 
26
  labels = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
27
 
28
  # load model (trained on MNIST dataset)
29
+ model = tf.keras.models.load_model("model/sketch_recognition_numbers_model.h5")
30
 
31
  # prediction function for sketch recognition
32
  def predict(img):
33
+ img = cv2.resize(img, (img_size, img_size))
34
+ img = img.reshape(1, img_size, img_size, 1)
35
+ preds = model.predict(img)[0]
36
+ return {label: float(pred) for label, pred in zip(labels, preds)}
 
 
 
 
 
 
37
 
38
  # top 3 of classes
39
+ label = gr.Label(num_top_classes=3)
40
 
41
  # open Gradio interface for sketch recognition
42
  interface = gr.Interface(fn=predict, inputs="sketchpad", outputs=label, title=title, description=head, article=ref)