alibayram committed on
Commit
e6fdb4c
·
1 Parent(s): 051f92c

Refactor sketch recognition app: update app title and description, streamline image processing, and enhance prediction function

Browse files
Files changed (2) hide show
  1. app.py +32 -44
  2. requirements.txt +1 -3
app.py CHANGED
@@ -1,73 +1,61 @@
1
- import os
2
  import numpy as np
3
- import cv2
4
  import gradio as gr
5
  import tensorflow as tf
6
- from PIL import Image
7
 
8
- # app title
9
- title = "Welcome on your first sketch recognition app!"
10
 
11
- # app description
12
  head = (
13
- "<center>"
14
- "<img src='./mnist-classes.png' width=400>"
15
- "<p>The robot was trained to classify numbers (0 to 9). "
16
- "To test it, write your number in the space provided!</p>"
17
- "</center>"
18
  )
19
 
20
  # GitHub repository link
21
- ref = "Find the whole code [here](https://github.com/ovh/ai-training-examples/tree/main/apps/gradio/sketch-recognition)."
22
 
23
- # Image size
24
  img_size = 28
25
 
26
- # Classes
27
  labels = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
28
 
29
- # Load model
30
- model_path = "./sketch_recognition_numbers_model.h5"
31
- try:
32
- model = tf.keras.models.load_model(model_path)
33
- except Exception as e:
34
- raise FileNotFoundError(f"Model file '{model_path}' not found or failed to load. {str(e)}")
35
-
36
- def predict(img):
37
- # If no image is provided, return an error message
38
- if img is None:
39
- return {"error": "No image provided."}
40
-
41
- # Ensure the image is a PIL Image
42
- if not isinstance(img, Image.Image):
43
- img = Image.fromarray(np.uint8(img))
44
 
 
 
 
 
 
 
45
  # Convert to grayscale
46
- img = img.convert("L")
47
-
48
- # Convert PIL Image to a NumPy array of type uint8
49
- img = np.array(img, dtype=np.uint8)
50
-
51
- # Resize to (28x28)
52
  img = cv2.resize(img, (img_size, img_size))
53
-
54
- # Reshape to match model input shape (1, 28, 28, 1)
 
55
  img = img.reshape(1, img_size, img_size, 1)
56
-
57
  # Model predictions
58
  preds = model.predict(img)[0]
59
-
60
- # Return probabilities for each class
61
  return {label: float(pred) for label, pred in zip(labels, preds)}
62
 
63
- # Use gr.Sketchpad to ensure a PIL image is returned
 
 
 
64
  interface = gr.Interface(
65
  fn=predict,
66
- inputs=gr.Sketchpad(type="pil"),
67
- outputs=gr.Label(num_top_classes=3),
68
  title=title,
69
  description=head,
70
  article=ref
71
  )
72
-
73
  interface.launch()
 
 
1
  import numpy as np
 
2
  import gradio as gr
3
  import tensorflow as tf
4
+ import cv2
5
 
6
+ # App title
7
+ title = "Welcome to your first sketch recognition app!"
8
 
9
+ # App description
10
  head = (
11
+ "<center>"
12
+ "<img src='./mnist-classes.png' width=400>"
13
+ "<p>The model is trained to classify numbers (from 0 to 9). "
14
+ "To test it, draw your number in the space provided.</p>"
15
+ "</center>"
16
  )
17
 
18
  # GitHub repository link
19
+ ref = "Find the complete code [here](https://github.com/ovh/ai-training-examples/tree/main/apps/gradio/sketch-recognition)."
20
 
21
+ # Image size: 28x28
22
  img_size = 28
23
 
24
+ # Class names (from 0 to 9)
25
  labels = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
26
 
27
+ # Load model (trained on MNIST dataset)
28
+ model = tf.keras.models.load_model("./sketch_recognition_numbers_model.h5")
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
+ # Prediction function for sketch recognition
31
def predict(data):
    """Classify a hand-drawn digit and return the probability of each class.

    Args:
        data: The value Gradio hands over from the sketchpad input. This may
            be ``None`` (nothing drawn yet), a dictionary carrying the drawing
            under the ``'image'`` key, or the image itself.
            NOTE(review): newer Gradio sketchpads appear to deliver the
            flattened drawing under ``'composite'`` instead of ``'image'``;
            that key is accepted as a fallback -- confirm against the Gradio
            version actually deployed.

    Returns:
        A dictionary mapping every name in ``labels`` to its predicted
        probability as a ``float``. An empty dictionary is returned when no
        image is available, so the label output simply stays blank instead of
        the request failing.
    """
    # Gradio calls the function with None when the canvas is empty or cleared.
    if data is None:
        return {}

    # Extract the drawing from the input dictionary (or take it as-is).
    if isinstance(data, dict):
        img = data.get('image')
        if img is None:
            img = data.get('composite')
    else:
        img = data
    if img is None:
        return {}

    # Convert to NumPy array
    img = np.array(img)

    # Convert to grayscale. Gradio supplies channels in RGB(A) order, so the
    # RGB conversion codes are used rather than OpenCV's native BGR ones.
    # A 2-D array is already single-channel and must not go through cvtColor.
    if img.ndim == 3:
        channels = img.shape[2]
        if channels == 4:
            img = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
        elif channels == 3:
            img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        else:
            img = img[:, :, 0]

    # Resize image to 28x28
    img = cv2.resize(img, (img_size, img_size))
    # Normalize pixel values to the [0, 1] range
    img = img / 255.0
    # Reshape image to match model input: (batch, height, width, channels)
    img = img.reshape(1, img_size, img_size, 1)

    # Model predictions (first and only item of the batch)
    preds = model.predict(img)[0]
    # Return the probability for each class
    return {label: float(pred) for label, pred in zip(labels, preds)}
48
 
49
+ # Top 3 classes
50
+ label = gr.Label(num_top_classes=3)
51
+
52
+ # Open Gradio interface for sketch recognition
53
  interface = gr.Interface(
54
  fn=predict,
55
+ inputs=gr.Sketchpad(),
56
+ outputs=label,
57
  title=title,
58
  description=head,
59
  article=ref
60
  )
 
61
  interface.launch()
requirements.txt CHANGED
@@ -1,5 +1,3 @@
1
  tensorflow
2
  opencv-python-headless
3
- numpy
4
- # PIL
5
- Pillow
 
1
  tensorflow
2
  opencv-python-headless
3
+ numpy