BenjiELCA commited on
Commit
cbbc2ce
·
1 Parent(s): e49e1d2

models loaded from hugging face

Browse files
Files changed (4) hide show
  1. .gitignore +2 -0
  2. app.py +7 -0
  3. modules/streamlit_utils.py +48 -20
  4. requirements.txt +1 -0
.gitignore CHANGED
@@ -27,3 +27,5 @@ BPMN_creation.ipynb
27
  *.ipynb
28
 
29
  *.pmw
 
 
 
27
  *.ipynb
28
 
29
  *.pmw
30
+
31
+ best_models.txt
app.py CHANGED
@@ -4,20 +4,27 @@ import gc
4
  import numpy as np
5
 
6
  from modules.streamlit_utils import *
 
7
 
8
 
9
  def main():
 
 
 
 
10
  st.session_state.first_run = True
11
  is_mobile, screen_width = configure_page()
12
  display_banner(is_mobile)
13
  display_title(is_mobile)
14
  display_sidebar()
 
15
  initialize_session_state()
16
 
17
  cropped_image = None
18
 
19
  img_selected = load_example_image()
20
  uploaded_file = load_user_image(img_selected, is_mobile)
 
21
  if uploaded_file is not None:
22
  cropped_image = display_image(uploaded_file, screen_width, is_mobile)
23
 
 
4
  import numpy as np
5
 
6
  from modules.streamlit_utils import *
7
+ from modules.utils import error
8
 
9
 
10
  def main():
11
+ # Example usage
12
+ if 'model_loaded' not in st.session_state:
13
+ st.session_state.model_loaded = False
14
+
15
  st.session_state.first_run = True
16
  is_mobile, screen_width = configure_page()
17
  display_banner(is_mobile)
18
  display_title(is_mobile)
19
  display_sidebar()
20
+
21
  initialize_session_state()
22
 
23
  cropped_image = None
24
 
25
  img_selected = load_example_image()
26
  uploaded_file = load_user_image(img_selected, is_mobile)
27
+
28
  if uploaded_file is not None:
29
  cropped_image = display_image(uploaded_file, screen_width, is_mobile)
30
 
modules/streamlit_utils.py CHANGED
@@ -7,6 +7,7 @@ import psutil
7
  import numpy as np
8
  from pathlib import Path
9
  import gdown
 
10
 
11
  from modules.OCR import text_prediction, filter_text, mapping_text
12
  from modules.utils import class_dict, arrow_dict, object_dict
@@ -26,6 +27,8 @@ from streamlit_image_select import image_select
26
  from streamlit_js_eval import streamlit_js_eval
27
 
28
  from modules.toWizard import create_wizard_file
 
 
29
 
30
 
31
 
@@ -48,38 +51,60 @@ def read_xml_file(filepath):
48
 
49
 
50
 
 
 
 
 
51
  # Function to load the models only once and use session state to keep track of it
52
  def load_models():
53
- with st.spinner('Loading model...'):
54
  model_object = get_faster_rcnn_model(len(object_dict))
55
- model_arrow = get_arrow_model(len(arrow_dict),2)
56
 
57
- url_arrow = 'https://drive.google.com/uc?id=1vv1X_r_lZ8gnzMAIKxcVEb_T_Qb-NkyA'
58
- url_object = 'https://drive.google.com/uc?id=1ciSS7H5baqXf8sRilcAjWBM9jZ0DQhsP'
59
 
60
  # Define paths to save models
61
  output_arrow = 'model_arrow.pth'
62
  output_object = 'model_object.pth'
63
 
64
- # Download models using gdown
65
- if not Path(output_arrow).exists():
66
- # Download models using gdown
67
- gdown.download(url_arrow, output_arrow, quiet=False)
68
- else:
69
- print('Model arrow downloaded from local')
70
- if not Path(output_object).exists():
71
- gdown.download(url_object, output_object, quiet=False)
72
- else:
73
- print('Model object downloaded from local')
74
-
75
  # Load models
76
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
77
- model_arrow.load_state_dict(torch.load(output_arrow, map_location=device))
78
- model_object.load_state_dict(torch.load(output_object, map_location=device))
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  st.session_state.model_loaded = True
81
- st.session_state.model_arrow = model_arrow
82
- st.session_state.model_object = model_object
83
 
84
  return model_object, model_arrow
85
 
@@ -236,9 +261,12 @@ def display_sidebar():
236
  def initialize_session_state():
237
  if 'pool_bboxes' not in st.session_state:
238
  st.session_state.pool_bboxes = []
239
- if 'model_object' not in st.session_state or 'model_arrow' not in st.session_state:
 
 
240
  clear_memory()
241
  load_models()
 
242
 
243
  def load_example_image():
244
  with st.expander("Use example images"):
 
7
  import numpy as np
8
  from pathlib import Path
9
  import gdown
10
+ import os
11
 
12
  from modules.OCR import text_prediction, filter_text, mapping_text
13
  from modules.utils import class_dict, arrow_dict, object_dict
 
27
  from streamlit_js_eval import streamlit_js_eval
28
 
29
  from modules.toWizard import create_wizard_file
30
+ from huggingface_hub import hf_hub_download
31
+ import time
32
 
33
 
34
 
 
51
 
52
 
53
 
54
+
55
+ # Suppress the symlink warning
56
+ os.environ['HF_HUB_DISABLE_SYMLINKS_WARNING'] = '1'
57
+
58
  # Function to load the models only once and use session state to keep track of it
59
  def load_models():
60
+ with st.spinner('Loading model...'):
61
  model_object = get_faster_rcnn_model(len(object_dict))
62
+ model_arrow = get_arrow_model(len(arrow_dict), 2)
63
 
64
+ model_arrow_path = hf_hub_download(repo_id="BenjiELCA/BPMN_Detection", filename="model_arrow.pth")
65
+ model_object_path = hf_hub_download(repo_id="BenjiELCA/BPMN_Detection", filename="model_object.pth")
66
 
67
  # Define paths to save models
68
  output_arrow = 'model_arrow.pth'
69
  output_object = 'model_object.pth'
70
 
 
 
 
 
 
 
 
 
 
 
 
71
  # Load models
72
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 
73
 
74
+ # Load model arrow
75
+ if not Path(output_arrow).exists():
76
+ # Download model from Hugging Face Hub
77
+ model_arrow.load_state_dict(torch.load(model_arrow_path, map_location=device))
78
+ st.session_state.model_arrow = model_arrow
79
+ print('Model arrow downloaded from Hugging Face Hub')
80
+ # Save the model locally
81
+ torch.save(model_arrow.state_dict(), output_arrow)
82
+ elif 'model_arrow' not in st.session_state and Path(output_arrow).exists():
83
+ model_arrow.load_state_dict(torch.load(output_arrow, map_location=device))
84
+ st.session_state.model_arrow = model_arrow
85
+ print('Model arrow loaded from local file')
86
+
87
+
88
+ # Load model object
89
+ if not Path(output_object).exists():
90
+ # Download model from Hugging Face Hub
91
+ model_object.load_state_dict(torch.load(model_object_path, map_location=device))
92
+ st.session_state.model_object = model_object
93
+ print('Model object downloaded from Hugging Face Hub')
94
+ # Save the model locally
95
+ torch.save(model_object.state_dict(), output_object)
96
+ elif 'model_object' not in st.session_state and Path(output_object).exists():
97
+ model_object.load_state_dict(torch.load(output_object, map_location=device))
98
+ st.session_state.model_object = model_object
99
+ print('Model object loaded from local file')
100
+
101
+
102
+ # Move models to device
103
+ model_arrow.to(device)
104
+ model_object.to(device)
105
+
106
+ # Update session state
107
  st.session_state.model_loaded = True
 
 
108
 
109
  return model_object, model_arrow
110
 
 
261
  def initialize_session_state():
262
  if 'pool_bboxes' not in st.session_state:
263
  st.session_state.pool_bboxes = []
264
+ if 'model_loaded' not in st.session_state:
265
+ st.session_state.model_loaded = False
266
+ if not st.session_state.model_loaded:
267
  clear_memory()
268
  load_models()
269
+ st.rerun()
270
 
271
  def load_example_image():
272
  with st.expander("Use example images"):
requirements.txt CHANGED
@@ -12,4 +12,5 @@ opencv-python==4.9.0.80
12
  gdown
13
  streamlit_js_eval
14
  psutil
 
15
  git+https://github.com/Benjinoob14/streamlit-bpmn-annotation.git
 
12
  gdown
13
  streamlit_js_eval
14
  psutil
15
+ huggingface_hub
16
  git+https://github.com/Benjinoob14/streamlit-bpmn-annotation.git