alibayram commited on
Commit
892b132
·
1 Parent(s): d6bce58

Refactor prediction function: streamline prediction logic and update return format to include probabilities for all classes

Browse files
Files changed (1) hide show
  1. app.py +5 -23
app.py CHANGED
@@ -25,22 +25,6 @@ labels = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight"
25
  # Load model (trained on MNIST dataset)
26
  model = tf.keras.models.load_model("./sketch_recognition_numbers_model.h5")
27
 
28
- """ # Prediction function for sketch recognition
29
- def predict(data):
30
- print(data['composite'].shape)
31
- # Reshape image to 28x28
32
- img = np.reshape(data['composite'], (1, img_size, img_size, 1))
33
- # Make prediction
34
- pred = model.predict(img)
35
- # Get top class
36
- top_3_classes = np.argsort(pred[0])[-3:][::-1]
37
- # Get top 3 probabilities
38
- top_3_probs = pred[0][top_3_classes]
39
- # Get class names
40
- class_names = [labels[i] for i in top_3_classes]
41
- # Return class names and probabilities
42
- return {class_names[i]: top_3_probs[i] for i in range(3)} """
43
-
44
  def predict(data):
45
  # Extract the 'composite' key from the input dictionary
46
  img = data['composite']
@@ -63,20 +47,18 @@ def predict(data):
63
  img = img.reshape(1, 28, 28, 1)
64
 
65
  # Model predictions
66
- preds = model.predict(img)
67
-
68
- print(preds)
69
-
70
- preds = preds[0]
71
- print(preds)
72
 
73
  top_3_classes = np.argsort(preds)[-3:][::-1]
74
  top_3_probs = preds[top_3_classes]
75
 
76
  class_names = [labels[i] for i in top_3_classes]
77
 
78
- print(class_names, top_3_probs, top_3_classes)
79
 
 
 
 
80
  return {class_names[i]: top_3_probs[i] for i in range(3)}
81
 
82
  # Top 3 classes
 
25
  # Load model (trained on MNIST dataset)
26
  model = tf.keras.models.load_model("./sketch_recognition_numbers_model.h5")
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  def predict(data):
29
  # Extract the 'composite' key from the input dictionary
30
  img = data['composite']
 
47
  img = img.reshape(1, 28, 28, 1)
48
 
49
  # Model predictions
50
+ preds = model.predict(img)[0]
 
 
 
 
 
51
 
52
  top_3_classes = np.argsort(preds)[-3:][::-1]
53
  top_3_probs = preds[top_3_classes]
54
 
55
  class_names = [labels[i] for i in top_3_classes]
56
 
57
+ print("class_names, top_3_probs, top_3_classes" , class_names, top_3_probs, top_3_classes)
58
 
59
+ """ return {class_names[i]: top_3_probs[i] for i in range(3)} """
60
+ """ # return the probability for each classe
61
+ return {label: float(pred) for label, pred in zip(labels, preds)} """
62
  return {class_names[i]: top_3_probs[i] for i in range(3)}
63
 
64
  # Top 3 classes