nathan ayers commited on
Commit
3b750f3
·
verified ·
1 Parent(s): 14c8643

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -3,11 +3,11 @@ import numpy as np
3
  from PIL import Image
4
  import gradio as gr
5
 
6
- # load your pickled RandomForest
7
  model = pickle.load(open("mnist_model.pkl", "rb"))
8
 
9
  def classify_digit(img: Image.Image) -> str:
10
- # convert to 28×28 grayscale
11
  gray = img.convert("L").resize((28, 28))
12
  arr = np.array(gray).reshape(1, -1)
13
  pred = model.predict(arr)[0]
@@ -15,10 +15,10 @@ def classify_digit(img: Image.Image) -> str:
15
 
16
  demo = gr.Interface(
17
  fn=classify_digit,
18
- inputs=gr.inputs.Image(type="pil", label="Upload a 28×28 digit"),
19
- outputs=gr.outputs.Textbox(label="Prediction"),
20
- title="Digit Classifier",
21
- description="Upload a handwritten MNIST digit and get a prediction!"
22
  )
23
 
24
  if __name__ == "__main__":
 
3
  from PIL import Image
4
  import gradio as gr
5
 
6
+ # load your pickled RandomForest (make sure mnist_model.pkl lives in /app)
7
  model = pickle.load(open("mnist_model.pkl", "rb"))
8
 
9
  def classify_digit(img: Image.Image) -> str:
10
+ # convert to 28×28 grayscale array
11
  gray = img.convert("L").resize((28, 28))
12
  arr = np.array(gray).reshape(1, -1)
13
  pred = model.predict(arr)[0]
 
15
 
16
  demo = gr.Interface(
17
  fn=classify_digit,
18
+ inputs=gr.Image(type="pil", label="Upload a 28×28 digit"),
19
+ outputs=gr.Textbox(label="Prediction"),
20
+ title="MNIST Digit Classifier",
21
+ description="Upload a handwritten digit image (28×28) to get a live prediction!"
22
  )
23
 
24
  if __name__ == "__main__":