ariG23498 HF staff commited on
Commit
4e79da4
·
1 Parent(s): 9ac731e

add: fix prediction shape

Browse files
Files changed (1) hide show
  1. app.py +1 -0
app.py CHANGED
@@ -61,6 +61,7 @@ def get_results(image):
61
  plt.axis("off")
62
 
63
  prediction = tf.nn.softmax(logits, axis=-1)
 
64
 
65
  return plt, {labels[i]: float(prediction[i]) for i in range(10)}
66
 
 
61
  plt.axis("off")
62
 
63
  prediction = tf.nn.softmax(logits, axis=-1)
64
+ prediction = prediction.numpy()[0]
65
 
66
  return plt, {labels[i]: float(prediction[i]) for i in range(10)}
67