Refactor prediction function: update data logging, reshape input handling, and simplify sketchpad configuration
Browse files
app.py
CHANGED
@@ -28,16 +28,12 @@ model = tf.keras.models.load_model("./sketch_recognition_numbers_model.h5")
|
|
28 |
|
29 |
# Prediction function for sketch recognition
|
30 |
def predict(data):
|
31 |
-
print(data)
|
32 |
-
# ValueError: cannot reshape array of size 1 into shape (1,28,28,1)
|
33 |
-
|
34 |
# Reshape image to 28x28
|
35 |
img = np.reshape(data['composite'], (1, img_size, img_size, 1))
|
36 |
# Make prediction
|
37 |
pred = model.predict(img)
|
38 |
# Get top class
|
39 |
-
top_class = np.argmax
|
40 |
-
# Get top 3 classes
|
41 |
top_3_classes = np.argsort(pred[0])[-3:][::-1]
|
42 |
# Get top 3 probabilities
|
43 |
top_3_probs = pred[0][top_3_classes]
|
@@ -52,7 +48,7 @@ label = gr.Label(num_top_classes=3)
|
|
52 |
# Open Gradio interface for sketch recognition
|
53 |
interface = gr.Interface(
|
54 |
fn=predict,
|
55 |
-
inputs=gr.Sketchpad(
|
56 |
outputs=label,
|
57 |
title=title,
|
58 |
description=head,
|
|
|
28 |
|
29 |
# Prediction function for sketch recognition
|
30 |
def predict(data):
|
31 |
+
print(data.shape)
|
|
|
|
|
32 |
# Reshape image to 28x28
|
33 |
img = np.reshape(data['composite'], (1, img_size, img_size, 1))
|
34 |
# Make prediction
|
35 |
pred = model.predict(img)
|
36 |
# Get top class
|
|
|
|
|
37 |
top_3_classes = np.argsort(pred[0])[-3:][::-1]
|
38 |
# Get top 3 probabilities
|
39 |
top_3_probs = pred[0][top_3_classes]
|
|
|
48 |
# Open Gradio interface for sketch recognition
|
49 |
interface = gr.Interface(
|
50 |
fn=predict,
|
51 |
+
inputs=gr.Sketchpad(type='numpy'),
|
52 |
outputs=label,
|
53 |
title=title,
|
54 |
description=head,
|