nathan ayers commited on
Commit
3fa79b6
·
verified ·
1 Parent(s): d0210a7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -20
app.py CHANGED
@@ -1,27 +1,26 @@
 
 
1
  import pickle
2
  import numpy as np
3
  from PIL import Image
4
- import gradio as gr
5
 
6
- # 1) Load your pretrained model
7
  model = pickle.load(open("mnist_model.pkl", "rb"))
8
 
9
- # 2) Define a prediction function
10
- def classify_digit(img):
11
- # convert to grayscale 28×28
12
- img = img.convert("L").resize((28, 28))
13
- arr = np.array(img).reshape(1, -1)
14
- pred = model.predict(arr)[0]
15
- return f"Predicted digit: {pred}"
16
-
17
- # 3) Wire up Gradio
18
- iface = gr.Interface(
19
- fn=classify_digit,
20
- inputs=gr.Image(type="pil", label="Upload a 28×28 digit"),
21
- outputs=gr.Textbox(label="Prediction"),
22
- title="MNIST Digit Classifier",
23
- description="Upload a handwritten digit and get a prediction!"
24
- )
25
 
26
- if __name__ == "__main__":
27
- iface.launch()
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile
2
+ from fastapi.responses import JSONResponse
3
  import pickle
4
  import numpy as np
5
  from PIL import Image
 
6
 
7
+ app = FastAPI()
8
  model = pickle.load(open("mnist_model.pkl", "rb"))
9
 
10
+ def preprocess_image(file_bytes) -> np.ndarray:
11
+ # 1) Load into PIL, convert to grayscale 'L'
12
+ img = Image.open(file_bytes).convert("L")
13
+ # 2) Resize to 28×28 (use ANTIALIAS for quality)
14
+ img = img.resize((28,28), Image.ANTIALIAS)
15
+ # 3) Convert to numpy array (uint8), flatten to length-784
16
+ arr = np.array(img).astype("uint8").reshape(1, -1)
17
+ # 4) Optionally invert colors if your MNIST is white-on-black:
18
+ # arr = 255 - arr
19
+ return arr
 
 
 
 
 
 
20
 
21
+ @app.post("/predict-image/")
22
+ async def predict_image(file: UploadFile = File(...)):
23
+ # read the incoming UploadFile into BytesIO
24
+ arr = preprocess_image(file.file)
25
+ pred = model.predict(arr)[0]
26
+ return JSONResponse({"prediction": int(pred)})