Spaces:
Sleeping
Sleeping
models loaded from hugging face
Browse files- .gitignore +2 -0
- app.py +7 -0
- modules/streamlit_utils.py +48 -20
- requirements.txt +1 -0
.gitignore
CHANGED
@@ -27,3 +27,5 @@ BPMN_creation.ipynb
|
|
27 |
*.ipynb
|
28 |
|
29 |
*.pmw
|
|
|
|
|
|
27 |
*.ipynb
|
28 |
|
29 |
*.pmw
|
30 |
+
|
31 |
+
best_models.txt
|
app.py
CHANGED
@@ -4,20 +4,27 @@ import gc
|
|
4 |
import numpy as np
|
5 |
|
6 |
from modules.streamlit_utils import *
|
|
|
7 |
|
8 |
|
9 |
def main():
|
|
|
|
|
|
|
|
|
10 |
st.session_state.first_run = True
|
11 |
is_mobile, screen_width = configure_page()
|
12 |
display_banner(is_mobile)
|
13 |
display_title(is_mobile)
|
14 |
display_sidebar()
|
|
|
15 |
initialize_session_state()
|
16 |
|
17 |
cropped_image = None
|
18 |
|
19 |
img_selected = load_example_image()
|
20 |
uploaded_file = load_user_image(img_selected, is_mobile)
|
|
|
21 |
if uploaded_file is not None:
|
22 |
cropped_image = display_image(uploaded_file, screen_width, is_mobile)
|
23 |
|
|
|
4 |
import numpy as np
|
5 |
|
6 |
from modules.streamlit_utils import *
|
7 |
+
from modules.utils import error
|
8 |
|
9 |
|
10 |
def main():
|
11 |
+
# Example usage
|
12 |
+
if 'model_loaded' not in st.session_state:
|
13 |
+
st.session_state.model_loaded = False
|
14 |
+
|
15 |
st.session_state.first_run = True
|
16 |
is_mobile, screen_width = configure_page()
|
17 |
display_banner(is_mobile)
|
18 |
display_title(is_mobile)
|
19 |
display_sidebar()
|
20 |
+
|
21 |
initialize_session_state()
|
22 |
|
23 |
cropped_image = None
|
24 |
|
25 |
img_selected = load_example_image()
|
26 |
uploaded_file = load_user_image(img_selected, is_mobile)
|
27 |
+
|
28 |
if uploaded_file is not None:
|
29 |
cropped_image = display_image(uploaded_file, screen_width, is_mobile)
|
30 |
|
modules/streamlit_utils.py
CHANGED
@@ -7,6 +7,7 @@ import psutil
|
|
7 |
import numpy as np
|
8 |
from pathlib import Path
|
9 |
import gdown
|
|
|
10 |
|
11 |
from modules.OCR import text_prediction, filter_text, mapping_text
|
12 |
from modules.utils import class_dict, arrow_dict, object_dict
|
@@ -26,6 +27,8 @@ from streamlit_image_select import image_select
|
|
26 |
from streamlit_js_eval import streamlit_js_eval
|
27 |
|
28 |
from modules.toWizard import create_wizard_file
|
|
|
|
|
29 |
|
30 |
|
31 |
|
@@ -48,38 +51,60 @@ def read_xml_file(filepath):
|
|
48 |
|
49 |
|
50 |
|
|
|
|
|
|
|
|
|
51 |
# Function to load the models only once and use session state to keep track of it
|
52 |
def load_models():
|
53 |
-
with st.spinner('Loading model...'):
|
54 |
model_object = get_faster_rcnn_model(len(object_dict))
|
55 |
-
model_arrow = get_arrow_model(len(arrow_dict),2)
|
56 |
|
57 |
-
|
58 |
-
|
59 |
|
60 |
# Define paths to save models
|
61 |
output_arrow = 'model_arrow.pth'
|
62 |
output_object = 'model_object.pth'
|
63 |
|
64 |
-
# Download models using gdown
|
65 |
-
if not Path(output_arrow).exists():
|
66 |
-
# Download models using gdown
|
67 |
-
gdown.download(url_arrow, output_arrow, quiet=False)
|
68 |
-
else:
|
69 |
-
print('Model arrow downloaded from local')
|
70 |
-
if not Path(output_object).exists():
|
71 |
-
gdown.download(url_object, output_object, quiet=False)
|
72 |
-
else:
|
73 |
-
print('Model object downloaded from local')
|
74 |
-
|
75 |
# Load models
|
76 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
77 |
-
model_arrow.load_state_dict(torch.load(output_arrow, map_location=device))
|
78 |
-
model_object.load_state_dict(torch.load(output_object, map_location=device))
|
79 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
st.session_state.model_loaded = True
|
81 |
-
st.session_state.model_arrow = model_arrow
|
82 |
-
st.session_state.model_object = model_object
|
83 |
|
84 |
return model_object, model_arrow
|
85 |
|
@@ -236,9 +261,12 @@ def display_sidebar():
|
|
236 |
def initialize_session_state():
|
237 |
if 'pool_bboxes' not in st.session_state:
|
238 |
st.session_state.pool_bboxes = []
|
239 |
-
if '
|
|
|
|
|
240 |
clear_memory()
|
241 |
load_models()
|
|
|
242 |
|
243 |
def load_example_image():
|
244 |
with st.expander("Use example images"):
|
|
|
7 |
import numpy as np
|
8 |
from pathlib import Path
|
9 |
import gdown
|
10 |
+
import os
|
11 |
|
12 |
from modules.OCR import text_prediction, filter_text, mapping_text
|
13 |
from modules.utils import class_dict, arrow_dict, object_dict
|
|
|
27 |
from streamlit_js_eval import streamlit_js_eval
|
28 |
|
29 |
from modules.toWizard import create_wizard_file
|
30 |
+
from huggingface_hub import hf_hub_download
|
31 |
+
import time
|
32 |
|
33 |
|
34 |
|
|
|
51 |
|
52 |
|
53 |
|
54 |
+
|
55 |
+
# Suppress the symlink warning
|
56 |
+
os.environ['HF_HUB_DISABLE_SYMLINKS_WARNING'] = '1'
|
57 |
+
|
58 |
# Function to load the models only once and use session state to keep track of it
|
59 |
def load_models():
|
60 |
+
with st.spinner('Loading model...'):
|
61 |
model_object = get_faster_rcnn_model(len(object_dict))
|
62 |
+
model_arrow = get_arrow_model(len(arrow_dict), 2)
|
63 |
|
64 |
+
model_arrow_path = hf_hub_download(repo_id="BenjiELCA/BPMN_Detection", filename="model_arrow.pth")
|
65 |
+
model_object_path = hf_hub_download(repo_id="BenjiELCA/BPMN_Detection", filename="model_object.pth")
|
66 |
|
67 |
# Define paths to save models
|
68 |
output_arrow = 'model_arrow.pth'
|
69 |
output_object = 'model_object.pth'
|
70 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
# Load models
|
72 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
|
|
73 |
|
74 |
+
# Load model arrow
|
75 |
+
if not Path(output_arrow).exists():
|
76 |
+
# Download model from Hugging Face Hub
|
77 |
+
model_arrow.load_state_dict(torch.load(model_arrow_path, map_location=device))
|
78 |
+
st.session_state.model_arrow = model_arrow
|
79 |
+
print('Model arrow downloaded from Hugging Face Hub')
|
80 |
+
# Save the model locally
|
81 |
+
torch.save(model_arrow.state_dict(), output_arrow)
|
82 |
+
elif 'model_arrow' not in st.session_state and Path(output_arrow).exists():
|
83 |
+
model_arrow.load_state_dict(torch.load(output_arrow, map_location=device))
|
84 |
+
st.session_state.model_arrow = model_arrow
|
85 |
+
print('Model arrow loaded from local file')
|
86 |
+
|
87 |
+
|
88 |
+
# Load model object
|
89 |
+
if not Path(output_object).exists():
|
90 |
+
# Download model from Hugging Face Hub
|
91 |
+
model_object.load_state_dict(torch.load(model_object_path, map_location=device))
|
92 |
+
st.session_state.model_object = model_object
|
93 |
+
print('Model object downloaded from Hugging Face Hub')
|
94 |
+
# Save the model locally
|
95 |
+
torch.save(model_object.state_dict(), output_object)
|
96 |
+
elif 'model_object' not in st.session_state and Path(output_object).exists():
|
97 |
+
model_object.load_state_dict(torch.load(output_object, map_location=device))
|
98 |
+
st.session_state.model_object = model_object
|
99 |
+
print('Model object loaded from local file')
|
100 |
+
|
101 |
+
|
102 |
+
# Move models to device
|
103 |
+
model_arrow.to(device)
|
104 |
+
model_object.to(device)
|
105 |
+
|
106 |
+
# Update session state
|
107 |
st.session_state.model_loaded = True
|
|
|
|
|
108 |
|
109 |
return model_object, model_arrow
|
110 |
|
|
|
261 |
def initialize_session_state():
|
262 |
if 'pool_bboxes' not in st.session_state:
|
263 |
st.session_state.pool_bboxes = []
|
264 |
+
if 'model_loaded' not in st.session_state:
|
265 |
+
st.session_state.model_loaded = False
|
266 |
+
if not st.session_state.model_loaded:
|
267 |
clear_memory()
|
268 |
load_models()
|
269 |
+
st.rerun()
|
270 |
|
271 |
def load_example_image():
|
272 |
with st.expander("Use example images"):
|
requirements.txt
CHANGED
@@ -12,4 +12,5 @@ opencv-python==4.9.0.80
|
|
12 |
gdown
|
13 |
streamlit_js_eval
|
14 |
psutil
|
|
|
15 |
git+https://github.com/Benjinoob14/streamlit-bpmn-annotation.git
|
|
|
12 |
gdown
|
13 |
streamlit_js_eval
|
14 |
psutil
|
15 |
+
huggingface_hub
|
16 |
git+https://github.com/Benjinoob14/streamlit-bpmn-annotation.git
|