Spaces:
Sleeping
Sleeping
import pickle | |
import cv2 | |
import joblib | |
import numpy as np | |
import streamlit as st | |
import torch | |
from streamlit_drawable_canvas import st_canvas | |
# "model.joblib" stores a bytes payload which is itself a pickled model
# object, hence the two-step joblib -> pickle load.
# SECURITY NOTE(review): pickle.loads executes arbitrary code embedded in the
# payload; this is only acceptable because the artifact ships with the app.
# Never point this at a user-supplied or downloaded file.
BINARY = joblib.load("model.joblib")
ML_MODEL = pickle.loads(BINARY)

# Restore the trained parameters, remapping any GPU-saved tensors onto the CPU
# so the app also runs on hosts without CUDA.
_CPU_DEVICE = torch.device("cpu")
ML_MODEL.load_state_dict(torch.load("model_weights.pth", map_location=_CPU_DEVICE))

# Inference only: switch the module to evaluation mode.
ML_MODEL.eval()
def predict_number(img, model=None, top_k=5):
    """Classify a single-channel image and return the most likely classes.

    Args:
        img: 2-D array-like of pixel intensities on a 0-255 scale (the caller
            in this script passes a 28x28 array), or ``None`` when there is
            nothing to classify.
        model: Callable mapping a ``(1, 1, H, W)`` float tensor to a tensor
            whose first element holds the per-class scores. Defaults to the
            module-level ``ML_MODEL``.
        top_k: Maximum number of classes to report. Clamped to the number of
            classes the model emits, so small models do not make
            ``torch.topk`` raise.

    Returns:
        ``None`` if ``img`` is ``None``; otherwise a dict mapping
        ``"is number <class index>"`` to its softmax probability, ordered from
        most to least likely.
    """
    if img is None:
        # The success path returns a single dict, so the "no input" case
        # returns a single None (previously an inconsistent 2-tuple).
        return None

    if model is None:
        model = ML_MODEL

    # Add batch and channel axes -> (1, 1, H, W) and rescale 0-255 to 0-1.
    inp = torch.tensor(img, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.0

    # No gradients are needed for inference.
    with torch.no_grad():
        output = model(inp)
        probabilities = torch.nn.functional.softmax(output[0], dim=0)

    # torch.topk raises when k exceeds the number of entries, so clamp it.
    k = max(0, min(top_k, probabilities.size(-1)))
    values, indices = torch.topk(probabilities, k)

    return {f"is number {i.item()}": v.item() for i, v in zip(indices, values)}
# Free-draw canvas: thick black strokes on a white 200x200 background.
# update_streamlit=True pushes the drawing back to the script as it changes.
_CANVAS_OPTIONS = {
    "fill_color": "rgba(255, 165, 0, 0.3)",
    "stroke_width": 17,
    "stroke_color": "#000000",
    "background_color": "#ffffff",
    "background_image": None,
    "update_streamlit": True,
    "height": 200,
    "width": 200,
    "drawing_mode": "freedraw",
    "key": "canvas",
}
canvas_result = st_canvas(**_CANVAS_OPTIONS)
# Turn the current drawing into a 28x28 single-channel image, show it, and
# display the model's predictions for it.
if canvas_result.image_data is not None:
    # Keep only channel 0 of the canvas pixels (red, assuming the canvas
    # returns RGBA - TODO confirm); the strokes are black on white, so one
    # channel carries the whole drawing.
    single_channel = np.squeeze(canvas_result.image_data[:, :, 0])

    # Blur before shrinking so the strokes are smoothed before the 200 -> 28
    # downscale.
    blurred = cv2.blur(single_channel, (10, 10))
    downscaled = cv2.resize(blurred, (28, 28))

    # Invert to light strokes on a dark background (presumably the polarity
    # the model was trained on - verify against the training data).
    image_data = cv2.bitwise_not(downscaled)

    st.image(image_data)
    confidences = predict_number(image_data)
    st.write(confidences)