alibayram commited on
Commit
fec34a0
·
1 Parent(s): 1d98dcd

Refactor prediction function: update data logging, reshape input handling, and simplify sketchpad configuration

Browse files
Files changed (1) hide show
  1. app.py +2 -6
app.py CHANGED
@@ -28,16 +28,12 @@ model = tf.keras.models.load_model("./sketch_recognition_numbers_model.h5")
28
 
29
  # Prediction function for sketch recognition
30
  def predict(data):
31
- print(data)
32
- # ValueError: cannot reshape array of size 1 into shape (1,28,28,1)
33
-
34
  # Reshape image to 28x28
35
  img = np.reshape(data['composite'], (1, img_size, img_size, 1))
36
  # Make prediction
37
  pred = model.predict(img)
38
  # Get top class
39
- top_class = np.argmax
40
- # Get top 3 classes
41
  top_3_classes = np.argsort(pred[0])[-3:][::-1]
42
  # Get top 3 probabilities
43
  top_3_probs = pred[0][top_3_classes]
@@ -52,7 +48,7 @@ label = gr.Label(num_top_classes=3)
52
  # Open Gradio interface for sketch recognition
53
  interface = gr.Interface(
54
  fn=predict,
55
- inputs=gr.Sketchpad(crop_size=(28,28), type='numpy', image_mode='L', brush=gr.Brush()),
56
  outputs=label,
57
  title=title,
58
  description=head,
 
28
 
29
  # Prediction function for sketch recognition
30
  def predict(data):
31
+ print(data.shape)
 
 
32
  # Reshape image to 28x28
33
  img = np.reshape(data['composite'], (1, img_size, img_size, 1))
34
  # Make prediction
35
  pred = model.predict(img)
36
  # Get top class
 
 
37
  top_3_classes = np.argsort(pred[0])[-3:][::-1]
38
  # Get top 3 probabilities
39
  top_3_probs = pred[0][top_3_classes]
 
48
  # Open Gradio interface for sketch recognition
49
  interface = gr.Interface(
50
  fn=predict,
51
+ inputs=gr.Sketchpad(type='numpy'),
52
  outputs=label,
53
  title=title,
54
  description=head,