alibayram committed on
Commit
051f92c
·
1 Parent(s): d9d330c

Refactor sketch recognition app: enhance image handling, improve error messages, and update app description

Browse files
Files changed (2)
  1. app.py +47 -40
  2. requirements.txt +2 -0
app.py CHANGED
@@ -1,66 +1,73 @@
1
- # import dependencies
 
 
2
  import gradio as gr
3
  import tensorflow as tf
4
- import cv2
5
- import numpy as np
6
 
7
  # app title
8
- title = "Welcome to your first sketch recognition app!"
9
 
10
  # app description
11
  head = (
12
- "<center>"
13
- "<img src='mnist-classes.png' width=400><br>"
14
- "The robot was trained to classify numbers (from 0 to 9). To test it, write your number in the space provided."
15
- "</center>"
 
16
  )
17
 
18
  # GitHub repository link
19
  ref = "Find the whole code [here](https://github.com/ovh/ai-training-examples/tree/main/apps/gradio/sketch-recognition)."
20
 
21
- # image size: 28x28
22
  img_size = 28
23
 
24
- # classes name (from 0 to 9)
25
  labels = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
26
 
27
- # load model (trained on MNIST dataset)
28
- model = tf.keras.models.load_model("./sketch_recognition_numbers_model.h5")
 
 
 
 
29
 
30
- # prediction function for sketch recognition
31
  def predict(img):
32
- if img is not None:
33
- # Convert to numpy array if not already
34
- img = np.array(img)
35
-
36
- # Ensure grayscale
37
- if len(img.shape) == 3:
38
- img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
39
-
40
- # Resize to required dimensions
41
- img = cv2.resize(img, (img_size, img_size))
42
-
43
- # Normalize and reshape
44
- img = img.astype('float32') / 255.0
45
- img = img.reshape(1, img_size, img_size, 1)
46
 
47
- # model predictions
48
- preds = model.predict(img)[0]
 
49
 
50
- # return the probability for each class
51
- return {label: float(pred) for label, pred in zip(labels, preds)}
52
- return None
 
 
53
 
54
- # top 3 of classes
55
- label = gr.Label(num_top_classes=3)
56
 
57
- # open Gradio interface for sketch recognition
 
 
 
 
 
 
 
 
 
58
  interface = gr.Interface(
59
- fn=predict,
60
- inputs=gr.Sketchpad(height=280, width=280), # Changed from shape to height and width
61
- outputs=label,
62
- title=title,
63
- description=head,
64
  article=ref
65
  )
 
66
  interface.launch()
 
1
+ import os
2
+ import numpy as np
3
+ import cv2
4
  import gradio as gr
5
  import tensorflow as tf
6
+ from PIL import Image
 
7
 
8
  # app title
9
+ title = "Welcome on your first sketch recognition app!"
10
 
11
  # app description
12
  head = (
13
+ "<center>"
14
+ "<img src='./mnist-classes.png' width=400>"
15
+ "<p>The robot was trained to classify numbers (0 to 9). "
16
+ "To test it, write your number in the space provided!</p>"
17
+ "</center>"
18
  )
19
 
20
  # GitHub repository link
21
  ref = "Find the whole code [here](https://github.com/ovh/ai-training-examples/tree/main/apps/gradio/sketch-recognition)."
22
 
23
+ # Image size
24
  img_size = 28
25
 
26
+ # Classes
27
  labels = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
28
 
29
+ # Load model
30
+ model_path = "./sketch_recognition_numbers_model.h5"
31
+ try:
32
+ model = tf.keras.models.load_model(model_path)
33
+ except Exception as e:
34
+ raise FileNotFoundError(f"Model file '{model_path}' not found or failed to load. {str(e)}")
35
 
 
36
  def predict(img):
37
+ # If no image is provided, return an error message
38
+ if img is None:
39
+ return {"error": "No image provided."}
 
 
 
 
 
 
 
 
 
 
 
40
 
41
+ # Ensure the image is a PIL Image
42
+ if not isinstance(img, Image.Image):
43
+ img = Image.fromarray(np.uint8(img))
44
 
45
+ # Convert to grayscale
46
+ img = img.convert("L")
47
+
48
+ # Convert PIL Image to a NumPy array of type uint8
49
+ img = np.array(img, dtype=np.uint8)
50
 
51
+ # Resize to (28x28)
52
+ img = cv2.resize(img, (img_size, img_size))
53
 
54
+ # Reshape to match model input shape (1, 28, 28, 1)
55
+ img = img.reshape(1, img_size, img_size, 1)
56
+
57
+ # Model predictions
58
+ preds = model.predict(img)[0]
59
+
60
+ # Return probabilities for each class
61
+ return {label: float(pred) for label, pred in zip(labels, preds)}
62
+
63
+ # Use gr.Sketchpad to ensure a PIL image is returned
64
  interface = gr.Interface(
65
+ fn=predict,
66
+ inputs=gr.Sketchpad(type="pil"),
67
+ outputs=gr.Label(num_top_classes=3),
68
+ title=title,
69
+ description=head,
70
  article=ref
71
  )
72
+
73
  interface.launch()
requirements.txt CHANGED
@@ -1,3 +1,5 @@
1
  tensorflow
2
  opencv-python-headless
3
  numpy
 
 
 
1
  tensorflow
2
  opencv-python-headless
3
  numpy
4
+ # PIL
5
+ Pillow