alibayram commited on
Commit
448af73
·
1 Parent(s): d9d6ebf

Enhance sketch recognition app: improve image processing in prediction function and add NumPy dependency

Browse files
Files changed (2) hide show
  1. app.py +28 -5
  2. requirements.txt +1 -0
app.py CHANGED
@@ -4,6 +4,7 @@ os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0" # Disable oneDNN optimizations
4
  import gradio as gr
5
  import tensorflow as tf
6
  import cv2
 
7
 
8
  # app title
9
  title = "Welcome on your first sketch recognition app!"
@@ -28,12 +29,34 @@ labels = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight"
28
  # load model (trained on MNIST dataset)
29
  model = tf.keras.models.load_model("./sketch_recognition_numbers_model.h5")
30
 
31
- # prediction function for sketch recognition
32
  def predict(img):
33
- img = cv2.resize(img, (img_size, img_size))
34
- img = img.reshape(1, img_size, img_size, 1)
35
- preds = model.predict(img)[0]
36
- return {label: float(pred) for label, pred in zip(labels, preds)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  # top 3 of classes
39
  label = gr.Label(num_top_classes=3)
 
4
  import gradio as gr
5
  import tensorflow as tf
6
  import cv2
7
+ import numpy as np
8
 
9
  # app title
10
  title = "Welcome on your first sketch recognition app!"
 
29
  # load model (trained on MNIST dataset)
30
  model = tf.keras.models.load_model("./sketch_recognition_numbers_model.h5")
31
 
32
+ # Prediction function for sketch recognition
33
  def predict(img):
34
+
35
+ try:
36
+ # Convert PIL image to NumPy array
37
+ img = np.array(img)
38
+
39
+ # Ensure grayscale format (convert from RGB if necessary)
40
+ if len(img.shape) == 3:
41
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
42
+
43
+ # Resize the image to 28x28
44
+ img = cv2.resize(img, (img_size, img_size))
45
+
46
+ # Normalize pixel values to [0, 1]
47
+ img = img / 255.0
48
+
49
+ # Reshape to match the model input shape
50
+ img = img.reshape(1, img_size, img_size, 1)
51
+
52
+ # Model predictions
53
+ preds = model.predict(img)[0]
54
+
55
+ # Return probabilities for each class
56
+ return {label: float(pred) for label, pred in zip(labels, preds)}
57
+
58
+ except Exception as e:
59
+ return {"error": f"Image processing failed: {str(e)}"}
60
 
61
  # top 3 of classes
62
  label = gr.Label(num_top_classes=3)
requirements.txt CHANGED
@@ -1,2 +1,3 @@
1
  tensorflow
2
  opencv-python
 
 
1
  tensorflow
2
  opencv-python
3
+ numpy