Refactor prediction function: change labels to a dictionary for better mapping and update return format to include top 3 class probabilities with corresponding class names
Browse files
app.py
CHANGED
@@ -20,8 +20,18 @@ ref = "Find the complete code [here](https://github.com/ovh/ai-training-examples
|
|
20 |
|
21 |
|
22 |
# Class names (from 0 to 9)
|
23 |
-
labels =
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
# Load model (trained on MNIST dataset)
|
26 |
model = tf.keras.models.load_model("./sketch_recognition_numbers_model.h5")
|
27 |
|
@@ -49,17 +59,18 @@ def predict(data):
|
|
49 |
# Model predictions
|
50 |
preds = model.predict(img)[0]
|
51 |
|
52 |
-
|
53 |
-
|
54 |
|
55 |
-
|
56 |
|
57 |
-
|
|
|
|
|
|
|
58 |
|
59 |
-
""
|
60 |
-
|
61 |
-
return {label: float(pred) for label, pred in zip(labels, preds)} """
|
62 |
-
return {class_names[i]: top_3_probs[i] for i in range(3)}
|
63 |
|
64 |
# Top 3 classes
|
65 |
label = gr.Label(num_top_classes=3)
|
|
|
20 |
|
21 |
|
22 |
# Class names (from 0 to 9)
|
23 |
+
labels = {
|
24 |
+
0: "zero",
|
25 |
+
1: "one",
|
26 |
+
2: "two",
|
27 |
+
3: "three",
|
28 |
+
4: "four",
|
29 |
+
5: "five",
|
30 |
+
6: "six",
|
31 |
+
7: "seven",
|
32 |
+
8: "eight",
|
33 |
+
9: "nine"
|
34 |
+
}
|
35 |
# Load model (trained on MNIST dataset)
|
36 |
model = tf.keras.models.load_model("./sketch_recognition_numbers_model.h5")
|
37 |
|
|
|
59 |
# Model predictions
|
60 |
preds = model.predict(img)[0]
|
61 |
|
62 |
+
print("preds", preds)
|
63 |
+
values_map = {preds[i]: i for i in range(len(preds))}
|
64 |
|
65 |
+
sorted_values = sorted(preds, reverse=True)
|
66 |
|
67 |
+
labels_map = dict()
|
68 |
+
for i in range(3):
|
69 |
+
print("sorted_values[i]", sorted_values[i], values_map[sorted_values[i]])
|
70 |
+
labels_map[labels[values_map[sorted_values[i]]]] = sorted_values[i]
|
71 |
|
72 |
+
print("labels_map", labels_map)
|
73 |
+
return labels_map
|
|
|
|
|
74 |
|
75 |
# Top 3 classes
|
76 |
label = gr.Label(num_top_classes=3)
|