# mnist-streamlit / app.py
# Author: Henrique Schumann
# Commit: 2b5fea7 ("first commit", unverified)
# Size: 1.43 kB
import pickle
import cv2
import joblib
import numpy as np
import streamlit as st
import torch
from streamlit_drawable_canvas import st_canvas
@st.cache_resource
def _load_model():
    """Deserialize the digit classifier, restore its weights, and return it in eval mode.

    ``model.joblib`` contains a pickled byte string of the model object (it is
    passed to ``pickle.loads``); the trained parameters are then restored from
    ``model_weights.pth`` onto the CPU.
    """
    # SECURITY: joblib.load and pickle.loads can execute arbitrary code embedded
    # in the file. Only ship model artifacts that come from a trusted source.
    binary = joblib.load("model.joblib")
    model = pickle.loads(binary)
    model.load_state_dict(
        torch.load("model_weights.pth", map_location=torch.device("cpu"))
    )
    model.eval()
    return model


# Streamlit re-executes the whole script on every interaction (every canvas
# stroke here, because the canvas uses update_streamlit=True). Caching the
# loader means the model is deserialized once per server process instead of on
# every rerun. The cached object is shared across sessions; this file only runs
# forward passes on it under torch.no_grad().
ML_MODEL = _load_model()
def predict_number(img, k=5, *, model=None):
    """Classify a digit image and return the ``k`` most likely classes.

    Args:
        img: 2-D array-like of pixel values in the 0-255 range (the caller in
            this file passes a 28x28 image), or ``None``.
        k: How many of the most likely classes to report. Clamped to the number
            of classes the model outputs, so it never exceeds it.
        model: Callable mapping a ``(1, 1, H, W)`` float32 tensor to logits
            whose first row is the score vector for that image. Defaults to the
            module-level ``ML_MODEL``.

    Returns:
        ``None`` when ``img`` is ``None``. Otherwise a dict mapping
        ``"is number <class>"`` to its softmax probability, ordered from most
        to least likely.
    """
    # Previously this returned the tuple (None, None), which did not match the
    # dict returned on the normal path.
    if img is None:
        return None
    if model is None:
        model = ML_MODEL
    # Add batch and channel dimensions -> (1, 1, H, W) and scale 0-255 to [0, 1].
    # NOTE(review): only /255 scaling is applied; confirm this matches the
    # preprocessing the model was trained with.
    inp = torch.tensor(img, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.0
    with torch.no_grad():
        output = model(inp)
    # output[0] is the score vector for the single image in the batch.
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    # torch.topk raises if k exceeds the number of elements, so clamp it.
    values, indices = torch.topk(probabilities, min(k, probabilities.numel()))
    return {f"is number {i.item()}": v.item() for i, v in zip(indices, values)}
# Drawing surface: black freehand strokes on a white 200x200 canvas.
# update_streamlit=True sends the canvas back to the script after each stroke,
# so the prediction below refreshes live.
_CANVAS_OPTIONS = dict(
    fill_color="rgba(255, 165, 0, 0.3)",
    stroke_width=17,
    stroke_color="#000000",
    background_color="#ffffff",
    background_image=None,
    update_streamlit=True,
    height=200,
    width=200,
    drawing_mode="freedraw",
    key="canvas",
)


def _canvas_to_digit(rgba):
    """Convert the canvas pixel array into the 28x28 image fed to the classifier.

    Keeps only channel 0 of the canvas array, smooths it with a 10x10 box blur,
    shrinks it to 28x28, and inverts the pixel values (presumably so the dark
    strokes become light on a dark background like MNIST digits; confirm).
    """
    first_channel = rgba[:, :, 0]
    blurred = cv2.blur(first_channel, (10, 10))
    small = cv2.resize(blurred, (28, 28))
    return cv2.bitwise_not(small)


canvas_result = st_canvas(**_CANVAS_OPTIONS)

if canvas_result.image_data is not None:
    image_data = _canvas_to_digit(canvas_result.image_data)
    st.image(image_data)
    st.write(predict_number(image_data))