BenjiELCA commited on
Commit
2d1db93
·
1 Parent(s): 9ff1243

example image proposed

Browse files
app.py CHANGED
@@ -7,6 +7,8 @@ from PIL import Image, ImageEnhance
7
  from htlm_webpage import display_bpmn_xml
8
  import gc
9
  import psutil
 
 
10
 
11
  from OCR import text_prediction, filter_text, mapping_text, rescale
12
  from train import prepare_model
@@ -22,6 +24,7 @@ from streamlit_image_comparison import image_comparison
22
  from xml.dom import minidom
23
  from streamlit_cropper import st_cropper
24
  from streamlit_drawable_canvas import st_canvas
 
25
  from utils import find_closest_object
26
  from train import get_faster_rcnn_model, get_arrow_model
27
  import gdown
@@ -43,12 +46,13 @@ def read_xml_file(filepath):
43
 
44
  # Function to modify bounding box positions based on the given sizes
45
  def modif_box_pos(pred, size):
46
- for i, (x1, y1, x2, y2) in enumerate(pred['boxes']):
 
47
  center = [(x1 + x2) / 2, (y1 + y2) / 2]
48
- label = class_dict[pred['labels'][i]]
49
  if label in size:
50
- pred['boxes'][i] = [center[0] - size[label][0] / 2, center[1] - size[label][1] / 2, center[0] + size[label][0] / 2, center[1] + size[label][1] / 2]
51
- return pred
52
 
53
  # Function to create a BPMN XML file from prediction results
54
  def create_XML(full_pred, text_mapping, scale):
@@ -69,7 +73,6 @@ def create_XML(full_pred, text_mapping, scale):
69
  'exclusiveGateway': (60, 60),
70
  'event': (43.2, 43.2),
71
  'parallelGateway': (60, 60),
72
- 'sequenceFlow': (180, 12),
73
  'dataObject': (48, 72),
74
  'dataStore': (72, 72),
75
  'subProcess': (144, 108),
@@ -89,7 +92,8 @@ def create_XML(full_pred, text_mapping, scale):
89
  })
90
 
91
  #modify the boxes positions
92
- full_pred = modif_box_pos(full_pred, size_elements)
 
93
 
94
  # Create BPMN collaboration element
95
  collaboration = ET.SubElement(definitions, 'bpmn:collaboration', id='collaboration_1')
@@ -144,6 +148,7 @@ def create_XML(full_pred, text_mapping, scale):
144
  pretty_xml_as_string = reparsed.toprettyxml(indent=" ")
145
 
146
  full_pred['boxes'] = rescale(1/scale, full_pred['boxes'])
 
147
 
148
  return pretty_xml_as_string
149
 
@@ -314,8 +319,22 @@ def main():
314
  #Create the layout for the app
315
  col1, col2 = st.columns(2)
316
  with col1:
 
 
 
 
 
 
 
 
 
 
 
317
  # Create a file uploader for the user to upload an image
318
- uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
 
 
 
319
 
320
  # Display the uploaded image if the user has uploaded an image
321
  if uploaded_file is not None:
@@ -342,7 +361,7 @@ def main():
342
  st.session_state.crop_image = cropped_image
343
  with st.spinner('Processing...'):
344
  perform_inference(model_object, model_arrow, st.session_state.crop_image, score_threshold)
345
- st.session_state.prediction = modif_box_pos(st.session_state.prediction, object_dict)
346
  st.balloons()
347
  else:
348
  #delete the prediction
 
7
  from htlm_webpage import display_bpmn_xml
8
  import gc
9
  import psutil
10
+ import copy
11
+
12
 
13
  from OCR import text_prediction, filter_text, mapping_text, rescale
14
  from train import prepare_model
 
24
  from xml.dom import minidom
25
  from streamlit_cropper import st_cropper
26
  from streamlit_drawable_canvas import st_canvas
27
+ from streamlit_image_select import image_select
28
  from utils import find_closest_object
29
  from train import get_faster_rcnn_model, get_arrow_model
30
  import gdown
 
46
 
47
  # Function to modify bounding box positions based on the given sizes
48
  def modif_box_pos(pred, size):
49
+ modified_pred = copy.deepcopy(pred) # Make a deep copy of the prediction
50
+ for i, (x1, y1, x2, y2) in enumerate(modified_pred['boxes']):
51
  center = [(x1 + x2) / 2, (y1 + y2) / 2]
52
+ label = class_dict[modified_pred['labels'][i]]
53
  if label in size:
54
+ modified_pred['boxes'][i] = [center[0] - size[label][0] / 2, center[1] - size[label][1] / 2, center[0] + size[label][0] / 2, center[1] + size[label][1] / 2]
55
+ return modified_pred['boxes']
56
 
57
  # Function to create a BPMN XML file from prediction results
58
  def create_XML(full_pred, text_mapping, scale):
 
73
  'exclusiveGateway': (60, 60),
74
  'event': (43.2, 43.2),
75
  'parallelGateway': (60, 60),
 
76
  'dataObject': (48, 72),
77
  'dataStore': (72, 72),
78
  'subProcess': (144, 108),
 
92
  })
93
 
94
  #modify the boxes positions
95
+ old_boxes = copy.deepcopy(full_pred)
96
+ full_pred['boxes'] = modif_box_pos(full_pred, size_elements)
97
 
98
  # Create BPMN collaboration element
99
  collaboration = ET.SubElement(definitions, 'bpmn:collaboration', id='collaboration_1')
 
148
  pretty_xml_as_string = reparsed.toprettyxml(indent=" ")
149
 
150
  full_pred['boxes'] = rescale(1/scale, full_pred['boxes'])
151
+ full_pred['boxes'] = old_boxes
152
 
153
  return pretty_xml_as_string
154
 
 
319
  #Create the layout for the app
320
  col1, col2 = st.columns(2)
321
  with col1:
322
+ with st.expander("Use example images"):
323
+ img_selected = image_select("If you have no image and just want to test the demo, click on one of these images", ["./images/None.jpg", "./images/example1.jpg", "./images/example2.jpg", "./images/example3.jpg"],
324
+ captions=["None", "Example 1", "Example 2", "Example 3"], index=0, use_container_width=False, return_value="original")
325
+
326
+ if img_selected== './images/None.jpg':
327
+ print('No example image selected')
328
+ #delete the prediction
329
+ if 'prediction' in st.session_state:
330
+ del st.session_state['prediction']
331
+ img_selected = None
332
+
333
  # Create a file uploader for the user to upload an image
334
+ if img_selected is not None:
335
+ uploaded_file = img_selected
336
+ else:
337
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
338
 
339
  # Display the uploaded image if the user has uploaded an image
340
  if uploaded_file is not None:
 
361
  st.session_state.crop_image = cropped_image
362
  with st.spinner('Processing...'):
363
  perform_inference(model_object, model_arrow, st.session_state.crop_image, score_threshold)
364
+ #st.session_state.prediction = modif_box_pos(st.session_state.prediction, object_dict)
365
  st.balloons()
366
  else:
367
  #delete the prediction
eval.py CHANGED
@@ -239,10 +239,7 @@ def create_links(keypoints, boxes, labels, class_dict):
239
  if labels[i]==list(class_dict.values()).index('sequenceFlow') or labels[i]==list(class_dict.values()).index('messageFlow'):
240
  closest1, point_start = find_closest_object(keypoints[i][0], boxes, labels)
241
  closest2, point_end = find_closest_object(keypoints[i][1], boxes, labels)
242
-
243
- print('closest1:', closest1, 'closest2:', closest2)
244
- print('point_start:', point_start, 'point_end:', point_end)
245
-
246
  if closest1 is not None and closest2 is not None:
247
  best_points.append([point_start, point_end])
248
  links.append([closest1, closest2])
 
239
  if labels[i]==list(class_dict.values()).index('sequenceFlow') or labels[i]==list(class_dict.values()).index('messageFlow'):
240
  closest1, point_start = find_closest_object(keypoints[i][0], boxes, labels)
241
  closest2, point_end = find_closest_object(keypoints[i][1], boxes, labels)
242
+
 
 
 
243
  if closest1 is not None and closest2 is not None:
244
  best_points.append([point_start, point_end])
245
  links.append([closest1, closest2])
images/example1.jpg ADDED
images/example2.jpg ADDED
images/example3.jpg ADDED
images/example4.jpg ADDED
images/none.jpg ADDED