alibayram commited on
Commit
cf0b1f5
·
1 Parent(s): 1944562

Refactor prediction function: enhance image preprocessing, convert to grayscale, and update model input handling

Browse files
Files changed (2) hide show
  1. app.py +31 -5
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import numpy as np
2
  import gradio as gr
3
  import tensorflow as tf
 
4
 
5
  # App title
6
  title = "Welcome to your first sketch recognition app!"
@@ -17,8 +18,6 @@ head = (
17
  # GitHub repository link
18
  ref = "Find the complete code [here](https://github.com/ovh/ai-training-examples/tree/main/apps/gradio/sketch-recognition)."
19
 
20
- # Image size: 28x28
21
- img_size = 28
22
 
23
  # Class names (from 0 to 9)
24
  labels = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
@@ -26,7 +25,7 @@ labels = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight"
26
  # Load model (trained on MNIST dataset)
27
  model = tf.keras.models.load_model("./sketch_recognition_numbers_model.h5")
28
 
29
- # Prediction function for sketch recognition
30
  def predict(data):
31
  print(data['composite'].shape)
32
  # Reshape image to 28x28
@@ -40,7 +39,34 @@ def predict(data):
40
  # Get class names
41
  class_names = [labels[i] for i in top_3_classes]
42
  # Return class names and probabilities
43
- return {class_names[i]: top_3_probs[i] for i in range(3)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  # Top 3 classes
46
  label = gr.Label(num_top_classes=3)
@@ -54,4 +80,4 @@ interface = gr.Interface(
54
  description=head,
55
  article=ref
56
  )
57
- interface.launch()
 
1
  import numpy as np
2
  import gradio as gr
3
  import tensorflow as tf
4
+ import cv2
5
 
6
  # App title
7
  title = "Welcome to your first sketch recognition app!"
 
18
  # GitHub repository link
19
  ref = "Find the complete code [here](https://github.com/ovh/ai-training-examples/tree/main/apps/gradio/sketch-recognition)."
20
 
 
 
21
 
22
  # Class names (from 0 to 9)
23
  labels = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
 
25
  # Load model (trained on MNIST dataset)
26
  model = tf.keras.models.load_model("./sketch_recognition_numbers_model.h5")
27
 
28
+ """ # Prediction function for sketch recognition
29
  def predict(data):
30
  print(data['composite'].shape)
31
  # Reshape image to 28x28
 
39
  # Get class names
40
  class_names = [labels[i] for i in top_3_classes]
41
  # Return class names and probabilities
42
+ return {class_names[i]: top_3_probs[i] for i in range(3)} """
43
+
44
+ def predict(data):
45
+ # Extract the 'image' key from the input dictionary
46
+ img = data['image']
47
+ # Convert to NumPy array
48
+ img = np.array(img)
49
+
50
+ # Handle RGBA or RGB images
51
+ if img.shape[-1] == 4: # RGBA
52
+ img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
53
+ if img.shape[-1] == 3: # RGB
54
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
55
+
56
+ # Resize image to 28x28
57
+ img = cv2.resize(img, (28, 28))
58
+
59
+ # Normalize pixel values to [0, 1]
60
+ img = img / 255.0
61
+
62
+ # Reshape to match model input
63
+ img = img.reshape(1, 28, 28, 1)
64
+
65
+ # Model predictions
66
+ preds = model.predict(img)[0]
67
+
68
+ # Return the probability for each class
69
+ return {label: float(pred) for label, pred in zip(labels, preds)}
70
 
71
  # Top 3 classes
72
  label = gr.Label(num_top_classes=3)
 
80
  description=head,
81
  article=ref
82
  )
83
+ interface.launch(share=True)
requirements.txt CHANGED
@@ -1,2 +1,3 @@
1
  tensorflow
 
2
  numpy
 
1
  tensorflow
2
+ opencv-python-headless
3
  numpy