Henrique Schumann commited on
Commit
2b5fea7
·
unverified ·
1 Parent(s): fd89124

first commit

Browse files
Files changed (7) hide show
  1. .gitignore +1 -0
  2. .vscode/settings.json +3 -0
  3. app.py +55 -0
  4. model.joblib +3 -0
  5. model_weights.pth +3 -0
  6. requirements.in +7 -0
  7. requirements.txt +136 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ env/
.vscode/settings.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "python.formatting.provider": "black"
3
+ }
app.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+
3
+ import cv2
4
+ import joblib
5
+ import numpy as np
6
+ import streamlit as st
7
+ import torch
8
+ from streamlit_drawable_canvas import st_canvas
9
+
10
+ BINARY = joblib.load("model.joblib")
11
+ ML_MODEL = pickle.loads(BINARY)
12
+ ML_MODEL.load_state_dict(
13
+ torch.load("model_weights.pth", map_location=torch.device("cpu"))
14
+ )
15
+ ML_MODEL.eval()
16
+
17
+
18
+ def predict_number(img):
19
+ if img is None:
20
+ return None, None
21
+
22
+ inp = torch.tensor(img, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.0
23
+
24
+ with torch.no_grad():
25
+ output = ML_MODEL(inp)
26
+
27
+ probabilities = torch.nn.functional.softmax(output[0], dim=0)
28
+ values, indices = torch.topk(probabilities, 5)
29
+ confidences = {f"is number {i.item()}": v.item() for i, v in zip(indices, values)}
30
+
31
+ return confidences
32
+
33
+
34
+ canvas_result = st_canvas(
35
+ fill_color="rgba(255, 165, 0, 0.3)",
36
+ stroke_width=17,
37
+ stroke_color="#000000",
38
+ background_color="#ffffff",
39
+ background_image=None,
40
+ update_streamlit=True,
41
+ height=200,
42
+ width=200,
43
+ drawing_mode="freedraw",
44
+ key="canvas",
45
+ )
46
+
47
+ if canvas_result.image_data is not None:
48
+ image_data = canvas_result.image_data[:, :, 0]
49
+ image_data = np.squeeze(image_data)
50
+ image_data = cv2.blur(image_data, (10, 10))
51
+ image_data = cv2.resize(image_data, (28, 28))
52
+ image_data = cv2.bitwise_not(image_data)
53
+ st.image(image_data)
54
+ confidences = predict_number(image_data)
55
+ st.write(confidences)
model.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cc1baedcc0f4701818c3e56565ab3ec5e358d87634467c4746d65d826cfae3f7
3
+ size 13104323
model_weights.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1a478efd0f0d88c27a8bf9f6b1fcf33c0716a91622aea62e53be93247670766a
3
+ size 13101515
requirements.in ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ joblib
4
+ streamlit
5
+ streamlit-drawable-canvas
6
+ opencv-python
7
+ cloudpickle
requirements.txt ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # This file is autogenerated by pip-compile with Python 3.10
3
+ # by the following command:
4
+ #
5
+ # pip-compile
6
+ #
7
+ altair==4.2.2
8
+ # via streamlit
9
+ attrs==22.2.0
10
+ # via jsonschema
11
+ blinker==1.5
12
+ # via streamlit
13
+ cachetools==5.3.0
14
+ # via streamlit
15
+ certifi==2022.12.7
16
+ # via requests
17
+ charset-normalizer==3.0.1
18
+ # via requests
19
+ click==8.1.3
20
+ # via streamlit
21
+ cloudpickle==2.2.1
22
+ # via -r requirements.in
23
+ decorator==5.1.1
24
+ # via validators
25
+ entrypoints==0.4
26
+ # via altair
27
+ gitdb==4.0.10
28
+ # via gitpython
29
+ gitpython==3.1.31
30
+ # via streamlit
31
+ idna==3.4
32
+ # via requests
33
+ importlib-metadata==6.0.0
34
+ # via streamlit
35
+ jinja2==3.1.2
36
+ # via
37
+ # altair
38
+ # pydeck
39
+ joblib==1.2.0
40
+ # via -r requirements.in
41
+ jsonschema==4.17.3
42
+ # via altair
43
+ markdown-it-py==2.2.0
44
+ # via rich
45
+ markupsafe==2.1.2
46
+ # via jinja2
47
+ mdurl==0.1.2
48
+ # via markdown-it-py
49
+ numpy==1.24.2
50
+ # via
51
+ # altair
52
+ # opencv-python
53
+ # pandas
54
+ # pyarrow
55
+ # pydeck
56
+ # streamlit
57
+ # streamlit-drawable-canvas
58
+ # torchvision
59
+ opencv-python==4.7.0.72
60
+ # via -r requirements.in
61
+ packaging==23.0
62
+ # via streamlit
63
+ pandas==1.5.3
64
+ # via
65
+ # altair
66
+ # streamlit
67
+ pillow==9.4.0
68
+ # via
69
+ # streamlit
70
+ # streamlit-drawable-canvas
71
+ # torchvision
72
+ protobuf==3.20.3
73
+ # via streamlit
74
+ pyarrow==11.0.0
75
+ # via streamlit
76
+ pydeck==0.8.0
77
+ # via streamlit
78
+ pygments==2.14.0
79
+ # via rich
80
+ pympler==1.0.1
81
+ # via streamlit
82
+ pyrsistent==0.19.3
83
+ # via jsonschema
84
+ python-dateutil==2.8.2
85
+ # via
86
+ # pandas
87
+ # streamlit
88
+ pytz==2022.7.1
89
+ # via pandas
90
+ pytz-deprecation-shim==0.1.0.post0
91
+ # via tzlocal
92
+ requests==2.28.2
93
+ # via
94
+ # streamlit
95
+ # torchvision
96
+ rich==13.3.1
97
+ # via streamlit
98
+ semver==2.13.0
99
+ # via streamlit
100
+ six==1.16.0
101
+ # via python-dateutil
102
+ smmap==5.0.0
103
+ # via gitdb
104
+ streamlit==1.19.0
105
+ # via
106
+ # -r requirements.in
107
+ # streamlit-drawable-canvas
108
+ streamlit-drawable-canvas==0.9.2
109
+ # via -r requirements.in
110
+ toml==0.10.2
111
+ # via streamlit
112
+ toolz==0.12.0
113
+ # via altair
114
+ torch==1.13.1
115
+ # via
116
+ # -r requirements.in
117
+ # torchvision
118
+ torchvision==0.14.1
119
+ # via -r requirements.in
120
+ tornado==6.2
121
+ # via streamlit
122
+ typing-extensions==4.5.0
123
+ # via
124
+ # streamlit
125
+ # torch
126
+ # torchvision
127
+ tzdata==2022.7
128
+ # via pytz-deprecation-shim
129
+ tzlocal==4.2
130
+ # via streamlit
131
+ urllib3==1.26.14
132
+ # via requests
133
+ validators==0.20.0
134
+ # via streamlit
135
+ zipp==3.15.0
136
+ # via importlib-metadata