Refactor predict function: streamline image extraction and add debug prints for image shape and content
Browse files
app.py
CHANGED
@@ -36,10 +36,10 @@ labels = {
|
|
36 |
model = tf.keras.models.load_model("./sketch_recognition_numbers_model.h5")
|
37 |
|
38 |
def predict(data):
|
39 |
-
# Extract the 'composite' key from the input dictionary
|
40 |
-
img = data['composite']
|
41 |
# Convert to NumPy array
|
42 |
-
img = np.array(
|
|
|
|
|
43 |
|
44 |
# Handle RGBA or RGB images
|
45 |
if img.shape[-1] == 4: # RGBA
|
@@ -56,6 +56,8 @@ def predict(data):
|
|
56 |
# Reshape to match model input
|
57 |
img = img.reshape(1, 28, 28, 1)
|
58 |
|
|
|
|
|
59 |
# Model predictions
|
60 |
preds = model.predict(img)[0]
|
61 |
|
|
|
36 |
model = tf.keras.models.load_model("./sketch_recognition_numbers_model.h5")
|
37 |
|
38 |
def predict(data):
|
|
|
|
|
39 |
# Convert to NumPy array
|
40 |
+
img = np.array(data['composite'])
|
41 |
+
|
42 |
+
print("img.shape", img.shape)
|
43 |
|
44 |
# Handle RGBA or RGB images
|
45 |
if img.shape[-1] == 4: # RGBA
|
|
|
56 |
# Reshape to match model input
|
57 |
img = img.reshape(1, 28, 28, 1)
|
58 |
|
59 |
+
print("img", img)
|
60 |
+
|
61 |
# Model predictions
|
62 |
preds = model.predict(img)[0]
|
63 |
|