alibayram committed on
Commit
cd17133
·
1 Parent(s): 5efe090

Refactor sketch recognition app: enhance image preprocessing, improve error handling, and update app description

Browse files
Files changed (1) hide show
  1. app.py +61 -31
app.py CHANGED
@@ -1,47 +1,77 @@
1
- # import dependencies
 
 
2
  import gradio as gr
3
  import tensorflow as tf
4
- import cv2
5
 
6
- # app title
7
- title = "Welcome on your first sketch recognition app!"
8
 
9
- # app description
10
- head = (
11
- "<center>"
12
- "<img src='./mnist-classes.png' width=400>"
13
- "The robot was trained to classify numbers (from 0 to 9). To test it, write your number in the space provided."
14
- "</center>"
 
 
15
  )
 
16
 
17
- # GitHub repository link
18
- ref = "Find the whole code [here](https://github.com/ovh/ai-training-examples/tree/main/apps/gradio/sketch-recognition)."
19
-
20
- # image size: 28x28
21
  img_size = 28
22
-
23
- # classes name (from 0 to 9)
24
  labels = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
25
 
26
- # load model (trained on MNIST dataset)
27
- model = tf.keras.models.load_model("./sketch_recognition_numbers_model.h5")
 
 
 
 
28
 
29
- # prediction function for sketch recognition
30
- def predict(img):
31
 
32
- # image shape: 28x28x1
33
- img = cv2.resize(img, (img_size, img_size))
34
- img = img.reshape(1, img_size, img_size, 1)
 
 
 
 
 
 
 
35
 
36
- # model predictions
37
- preds = model.predict(img)[0]
38
 
39
- # return the probability for each class
40
- return {label: float(pred) for label, pred in zip(labels, preds)}
41
 
42
- # top 3 of classes
43
- label = gr.Label(num_top_classes=3)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
- # open Gradio interface for sketch recognition
46
- interface = gr.Interface(fn=predict, inputs="sketchpad", outputs=label, title=title, description=head, article=ref)
47
  interface.launch()
 
1
import os

# Disable oneDNN optimizations for consistent results.
# The flag is set *before* TensorFlow is imported: setting it afterwards is
# not guaranteed to take effect, because TensorFlow may already have read its
# oneDNN configuration while loading.
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"

import cv2  # noqa: E402
import gradio as gr  # noqa: E402
import numpy as np  # noqa: E402
import tensorflow as tf  # noqa: E402
9
 
10
# Static text shown in the Gradio UI.
title = "Welcome to your first sketch recognition app!"

# HTML header: sample image followed by a short usage hint.
_description_parts = (
    "<center>",
    "<img src='mnist-classes.png' width=400>",
    "<p>The robot was trained to classify numbers (from 0 to 9). ",
    "To test it, write your number in the space provided!</p>",
    "</center>",
)
description = "".join(_description_parts)

# Markdown footer linking to the upstream example code.
article = "Find the complete code [here](https://github.com/ovh/ai-training-examples/tree/main/apps/gradio/sketch-recognition)."

# Side length, in pixels, of the square image the preprocessing produces.
img_size = 28

# Class names, one per digit, in index order.
labels = [
    "zero",
    "one",
    "two",
    "three",
    "four",
    "five",
    "six",
    "seven",
    "eight",
    "nine",
]
24
 
25
# Load the trained MNIST model.
# A missing file and a file that exists but cannot be deserialized are
# reported separately, so the error message points at the real problem.
model_path = "./sketch_recognition_numbers_model.h5"
if not os.path.isfile(model_path):
    raise FileNotFoundError(f"Model file '{model_path}' not found.")
try:
    model = tf.keras.models.load_model(model_path)
except Exception as e:
    # Chain the original exception so the underlying Keras/h5 error stays
    # visible in the traceback.
    raise RuntimeError(f"Model file '{model_path}' failed to load. {e}") from e
31
 
 
 
32
 
33
def preprocess_image(img):
    """
    Turn a sketchpad image into a model-ready batch of one grayscale image.

    Args:
        img: The sketch. Accepted forms are a NumPy array or PIL image
            (anything ``np.asarray`` can convert), either 2-D grayscale or
            3-D with a channel axis, or the dict that Gradio 4+ sketchpads
            pass, in which case the flattened drawing is read from its
            ``"composite"`` entry.

    Returns:
        A float array of shape ``(1, img_size, img_size, 1)`` whose values
        are the resized pixel values divided by 255.

    Raises:
        ValueError: If no image was provided, or the array is neither 2-D
            nor 3-D.
    """
    # Gradio 4+ sketchpads deliver an editor dict rather than a bare image.
    if isinstance(img, dict):
        img = img.get("composite")

    # An untouched canvas arrives as None; fail with a clear message instead
    # of an opaque OpenCV error further down.
    if img is None:
        raise ValueError("No sketch provided.")

    img = np.asarray(img)

    # Ensure grayscale format
    if img.ndim == 3:
        if img.shape[2] == 1:
            # Already single-channel; just drop the channel axis.
            img = img[:, :, 0]
        else:
            img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    elif img.ndim != 2:
        raise ValueError(f"Expected a 2-D or 3-D image, got {img.ndim} dimension(s).")

    # Resize to img_size x img_size
    img = cv2.resize(img, (img_size, img_size))

    # Normalize pixel values to [0, 1]
    img = img / 255.0

    # Add batch and channel axes for the model input
    return img.reshape(1, img_size, img_size, 1)
52
+
53
+
54
def predict(img):
    """
    Predict the digit class probabilities from the input sketch image.

    Args:
        img: The sketch as delivered by the Gradio sketchpad input.

    Returns:
        A dict mapping each entry of ``labels`` to the corresponding model
        output (as a Python float) for the first, and only, image in the
        batch.

    Raises:
        gr.Error: If preprocessing or inference fails. Gradio shows the
            message to the user. The output component is a ``gr.Label``,
            which expects label-to-float confidences, so a dict carrying an
            error string is not a valid return value.
    """
    try:
        processed_img = preprocess_image(img)
        predictions = model.predict(processed_img)[0]
    except Exception as e:
        raise gr.Error(f"Prediction failed: {e}") from e
    return {label: float(pred) for label, pred in zip(labels, predictions)}
64
+
65
+
66
# Output widget: show only the three highest-scoring classes.
_top_classes = gr.Label(num_top_classes=3)

# Wire the sketchpad input to the classifier and attach the page texts.
interface = gr.Interface(
    predict,
    inputs="sketchpad",
    outputs=_top_classes,
    title=title,
    description=description,
    article=article,
)

# Start serving the app.
interface.launch()