alibayram commited on
Commit
6c3c8f8
·
1 Parent(s): 892b132

Refactor prediction function: change labels to a dictionary for better mapping and update return format to include top 3 class probabilities with corresponding class names

Browse files
Files changed (1) hide show
  1. app.py +21 -10
app.py CHANGED
@@ -20,8 +20,18 @@ ref = "Find the complete code [here](https://github.com/ovh/ai-training-examples
20
 
21
 
22
  # Class names (from 0 to 9)
23
- labels = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
24
-
 
 
 
 
 
 
 
 
 
 
25
  # Load model (trained on MNIST dataset)
26
  model = tf.keras.models.load_model("./sketch_recognition_numbers_model.h5")
27
 
@@ -49,17 +59,18 @@ def predict(data):
49
  # Model predictions
50
  preds = model.predict(img)[0]
51
 
52
- top_3_classes = np.argsort(preds)[-3:][::-1]
53
- top_3_probs = preds[top_3_classes]
54
 
55
- class_names = [labels[i] for i in top_3_classes]
56
 
57
- print("class_names, top_3_probs, top_3_classes" , class_names, top_3_probs, top_3_classes)
 
 
 
58
 
59
- """ return {class_names[i]: top_3_probs[i] for i in range(3)} """
60
- """ # return the probability for each classe
61
- return {label: float(pred) for label, pred in zip(labels, preds)} """
62
- return {class_names[i]: top_3_probs[i] for i in range(3)}
63
 
64
  # Top 3 classes
65
  label = gr.Label(num_top_classes=3)
 
20
 
21
 
22
  # Class names (from 0 to 9)
23
+ labels = {
24
+ 0: "zero",
25
+ 1: "one",
26
+ 2: "two",
27
+ 3: "three",
28
+ 4: "four",
29
+ 5: "five",
30
+ 6: "six",
31
+ 7: "seven",
32
+ 8: "eight",
33
+ 9: "nine"
34
+ }
35
  # Load model (trained on MNIST dataset)
36
  model = tf.keras.models.load_model("./sketch_recognition_numbers_model.h5")
37
 
 
59
  # Model predictions
60
  preds = model.predict(img)[0]
61
 
62
+ print("preds", preds)
63
+ values_map = {preds[i]: i for i in range(len(preds))}
64
 
65
+ sorted_values = sorted(preds, reverse=True)
66
 
67
+ labels_map = dict()
68
+ for i in range(3):
69
+ print("sorted_values[i]", sorted_values[i], values_map[sorted_values[i]])
70
+ labels_map[labels[values_map[sorted_values[i]]]] = sorted_values[i]
71
 
72
+ print("labels_map", labels_map)
73
+ return labels_map
 
 
74
 
75
  # Top 3 classes
76
  label = gr.Label(num_top_classes=3)