nathan ayers commited on
Commit
14c8643
·
verified ·
1 Parent(s): 0e06c36

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -19
app.py CHANGED
@@ -1,26 +1,25 @@
1
- from fastapi import FastAPI, File, UploadFile
2
- from fastapi.responses import JSONResponse
3
  import pickle
4
  import numpy as np
5
  from PIL import Image
 
6
 
7
- app = FastAPI()
8
  model = pickle.load(open("mnist_model.pkl", "rb"))
9
 
10
- def preprocess_image(file_bytes) -> np.ndarray:
11
- # 1) Load into PIL, convert to grayscale 'L'
12
- img = Image.open(file_bytes).convert("L")
13
- # 2) Resize to 28×28 (use ANTIALIAS for quality)
14
- img = img.resize((28,28), Image.ANTIALIAS)
15
- # 3) Convert to numpy array (uint8), flatten to length-784
16
- arr = np.array(img).astype("uint8").reshape(1, -1)
17
- # 4) Optionally invert colors if your MNIST is white-on-black:
18
- # arr = 255 - arr
19
- return arr
20
-
21
- @app.post("/predict-image/")
22
- async def predict_image(file: UploadFile = File(...)):
23
- # read the incoming UploadFile into BytesIO
24
- arr = preprocess_image(file.file)
25
  pred = model.predict(arr)[0]
26
- return JSONResponse({"prediction": int(pred)})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import pickle
2
  import numpy as np
3
  from PIL import Image
4
+ import gradio as gr
5
 
6
+ # load your pickled RandomForest
7
  model = pickle.load(open("mnist_model.pkl", "rb"))
8
 
9
+ def classify_digit(img: Image.Image) -> str:
10
+ # convert to 28×28 grayscale
11
+ gray = img.convert("L").resize((28, 28))
12
+ arr = np.array(gray).reshape(1, -1)
 
 
 
 
 
 
 
 
 
 
 
13
  pred = model.predict(arr)[0]
14
+ return f"Predicted digit: {pred}"
15
+
16
+ demo = gr.Interface(
17
+ fn=classify_digit,
18
+ inputs=gr.inputs.Image(type="pil", label="Upload a 28×28 digit"),
19
+ outputs=gr.outputs.Textbox(label="Prediction"),
20
+ title="Digit Classifier",
21
+ description="Upload a handwritten MNIST digit and get a prediction!"
22
+ )
23
+
24
+ if __name__ == "__main__":
25
+ demo.launch()