BenjiELCA committed on
Commit
615e9f1
·
1 Parent(s): 25f3264

put online demo

Browse files
Files changed (12) hide show
  1. .gitignore +13 -0
  2. OCR.py +415 -0
  3. demo_streamlit.py +339 -0
  4. display.py +181 -0
  5. eval.py +649 -0
  6. flask.py +6 -0
  7. htlm_webpage.py +141 -0
  8. packages.txt +1 -0
  9. requirements.txt +10 -0
  10. toXML.py +351 -0
  11. train.py +394 -0
  12. utils.py +936 -0
.gitignore ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ __pycache__/
3
+
4
+ temp/
5
+
6
+
7
+ VISION_KEY.json
8
+
9
+ *.pth
10
+
11
+ .streamlit/secrets.toml
12
+
13
+ backup/
OCR.py ADDED
@@ -0,0 +1,415 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ from azure.ai.vision.imageanalysis import ImageAnalysisClient
4
+ from azure.ai.vision.imageanalysis.models import VisualFeatures
5
+ from azure.core.credentials import AzureKeyCredential
6
+ import time
7
+ import numpy as np
8
+ import networkx as nx
9
+ from eval import iou
10
+ from utils import class_dict, proportion_inside
11
+ import json
12
+ from utils import rescale_boxes as rescale
13
+ import streamlit as st
14
+
15
# Azure AI Vision credentials, read from the Streamlit secrets store when this
# module is imported.
VISION_KEY = st.secrets["VISION_KEY"]
VISION_ENDPOINT = st.secrets["VISION_ENDPOINT"]

# For local execution, the same two values can be loaded from a
# VISION_KEY.json file instead of the Streamlit secrets:
#
#     with open("VISION_KEY.json", "r") as json_file:
#         json_data = json.load(json_file)
#     VISION_KEY = json_data["VISION_KEY"]
#     VISION_ENDPOINT = json_data["VISION_ENDPOINT"]
27
+
28
+
29
def sample_ocr_image_file(image_data):
    """
    Sends image data to the Azure Image Analysis service and returns the OCR result.

    Parameters:
    - image_data: Encoded image content forwarded as ``image_data`` to
      ``ImageAnalysisClient.analyze`` (``text_prediction`` passes raw bytes).

    Returns:
    - The object returned by ``ImageAnalysisClient.analyze`` for the READ feature.
    """
    # VISION_ENDPOINT and VISION_KEY are module-level constants resolved at
    # import time, so reading them here cannot raise KeyError; the previous
    # try/except around these reads was unreachable and has been removed.
    client = ImageAnalysisClient(
        endpoint=VISION_ENDPOINT,
        credential=AzureKeyCredential(VISION_KEY)
    )

    # Extract text (OCR) from the image. This is a synchronous (blocking) call.
    return client.analyze(
        image_data=image_data,
        visual_features=[VisualFeatures.READ]
    )
53
+
54
+
55
def text_prediction(image):
    """
    Runs OCR on an image and returns the raw OCR result.

    The image is JPEG-encoded into an in-memory buffer instead of a fixed
    'temp.jpg' in the working directory, so concurrent calls cannot overwrite
    each other's file and nothing is left on disk when the OCR call raises.

    Parameters:
    - image: Object exposing ``save(fp, format=...)``
      (presumably a PIL image - confirm against the caller).

    Returns:
    - The result of ``sample_ocr_image_file`` for the encoded image bytes.
    """
    import io

    buffer = io.BytesIO()
    # 'temp.jpg' previously selected the JPEG encoder through its extension.
    image.save(buffer, format='JPEG')
    return sample_ocr_image_file(buffer.getvalue())
64
+
65
def filter_text(ocr_result, threshold=0.5):
    """
    Filters OCR words and merges the surviving words of each line.

    A word is dropped when it is a lone symbol, a lone letter other than
    'a', 'A' or 'I', a lone digit, contains '+', '<' or '>', or has a
    confidence that is not strictly above ``threshold``.

    Parameters:
    - ocr_result: Mapping read as ``ocr_result['readResult']['blocks']``; each
      block has 'lines', each line has 'words', and each word has 'text',
      'confidence' and 'boundingPolygon' (a sequence of {'x', 'y'} points).
    - threshold (float): Minimum (exclusive) word confidence.

    Returns:
    - list: ``[bounding_boxes, texts]`` where ``bounding_boxes[k]`` is
      ``[x_min, y_min, x_max, y_max]`` around the kept words of the k-th kept
      line and ``texts[k]`` is those words joined with single spaces.
    """
    symbols = {"+",".",",","#","@","!","?","(",")","[","]","{","}","<",">","/","\\","|","-","_","=","&","^","%","$","£","€","¥","¢","¤","§","©","®","™","°","±","×","÷","¶","∆","∏","∑","∞","√","∫","≈","≠","≤","≥","≡","∼"}
    # Every one-character word is noise except 'a', 'A' and 'I'
    # (lowercase 'i' is cancelled explicitly).
    single_chars = "bcdefghjklmnopqrstuvwxyz1234567890"
    words_to_cancel = symbols | set(single_chars) | set(single_chars.upper()) | {"i"}
    characters_to_cancel = ("+", "<", ">")

    list_bbox = []
    list_text = []
    for block in ocr_result['readResult']['blocks']:
        for line in block['lines']:
            kept_words = []
            xs, ys = [], []
            for word in line['words']:
                text = word['text']
                if text in words_to_cancel or any(char in text for char in characters_to_cancel):
                    continue
                if word['confidence'] > threshold and text:
                    kept_words.append(text)
                    xs.extend(point['x'] for point in word['boundingPolygon'])
                    ys.extend(point['y'] for point in word['boundingPolygon'])
            if kept_words:  # Lines without any valid word are dropped entirely
                list_text.append(' '.join(kept_words))
                list_bbox.append([min(xs), min(ys), max(xs), max(ys)])

    return [list_bbox, list_text]
109
+
110
+
111
+
112
+
113
def get_box_points(box):
    """
    Returns the eight reference points of a box as an array of [x, y] rows.

    Parameters:
    - box: Sequence unpacked as (xmin, ymin, xmax, ymax).

    Returns:
    - np.array: The four corners in the order (xmin, ymin), (xmax, ymin),
      (xmin, ymax), (xmax, ymax), followed by the midpoints of the ymin edge,
      the ymax edge, the xmin edge and the xmax edge.
    """
    x_lo, y_lo, x_hi, y_hi = box
    x_mid = (x_lo + x_hi) / 2
    y_mid = (y_lo + y_hi) / 2

    corners = [[x_lo, y_lo], [x_hi, y_lo], [x_lo, y_hi], [x_hi, y_hi]]
    edge_midpoints = [[x_mid, y_lo], [x_mid, y_hi], [x_lo, y_mid], [x_hi, y_mid]]
    return np.array(corners + edge_midpoints)
126
+
127
def min_distance_between_boxes(box1, box2):
    """
    Returns the smallest Euclidean distance between the reference points
    (as produced by ``get_box_points``) of two boxes.

    Parameters:
    - box1, box2: Boxes accepted by ``get_box_points``.

    Returns:
    - The minimum point-to-point distance found.
    """
    first_points = get_box_points(box1)
    second_points = get_box_points(box2)

    pair_distances = (
        np.linalg.norm(p - q) for p in first_points for q in second_points
    )

    shortest = float('inf')
    for candidate in pair_distances:
        shortest = min(shortest, candidate)
    return shortest
139
+
140
+
141
def is_inside(box1, box2):
    """
    Tells whether the center of box1 lies within box2 (borders included).

    Parameters:
    - box1, box2: Sequences indexed as [xmin, ymin, xmax, ymax].
    """
    center_x = (box1[0] + box1[2]) / 2
    center_y = (box1[1] + box1[3]) / 2

    within_x = box2[0] <= center_x <= box2[2]
    return within_x and box2[1] <= center_y <= box2[3]
146
+
147
def are_close(box1, box2, threshold=50):
    """
    Determines whether two boxes are close to each other.

    The boxes are close when at least one reference point of box1 (corner or
    edge midpoint) lies strictly less than ``threshold`` away from a reference
    point of box2.

    Parameters:
    - box1, box2: Boxes indexed as [xmin, ymin, xmax, ymax].
    - threshold: Maximum (exclusive) point-to-point distance.

    Returns:
    - bool: True when such a pair of points exists.
    """
    # Reuse the shared helper rather than rebuilding the same eight points by
    # hand; the order of the points does not affect the result.
    points1 = get_box_points(box1)
    points2 = get_box_points(box2)
    return any(
        np.linalg.norm(p1 - p2) < threshold
        for p1 in points1
        for p2 in points2
    )
164
+
165
def find_closest_box(text_box, all_boxes, labels, threshold, iou_threshold=0.5):
    """
    Finds the index of the box a text box should be attached to.

    A sequenceFlow box for which ``proportion_inside(text_box, box)`` exceeds
    ``iou_threshold`` wins immediately. Otherwise the box whose center is
    nearest to the center of the text box is chosen, provided that distance is
    strictly below ``threshold``.

    Parameters:
    - text_box: Box indexed as [xmin, ymin, xmax, ymax].
    - all_boxes: Candidate boxes, same layout.
    - labels: Class index of each candidate box.
    - threshold: Maximum (exclusive) center-to-center distance.
    - iou_threshold: Minimum (exclusive) ``proportion_inside`` value for the
      sequenceFlow shortcut.

    Returns:
    - int or None: Index into ``all_boxes``, or None when nothing qualifies.
    """
    # Loop-invariant values, computed once instead of on every iteration.
    sequence_flow_label = list(class_dict.values()).index('sequenceFlow')
    center_text = np.array([(text_box[0] + text_box[2]) / 2, (text_box[1] + text_box[3]) / 2])

    # Check if the text is inside a sequenceFlow
    for j, box in enumerate(all_boxes):
        if proportion_inside(text_box, box) > iou_threshold and labels[j] == sequence_flow_label:
            return j

    min_distance = float('inf')
    closest_index = None
    for i, box in enumerate(all_boxes):
        center_box = np.array([(box[0] + box[2]) / 2, (box[1] + box[3]) / 2])
        # Euclidean distance between the two centers
        distance = np.linalg.norm(center_text - center_box)
        if distance < min_distance:
            min_distance = distance
            closest_index = i

    # Only accept the nearest box when it is within the allowed distance
    if min_distance < threshold:
        return closest_index
    return None
193
+
194
+
195
def is_vertical(box):
    """
    Tells whether a bounding box is more than twice as tall as it is wide,
    which is used as the sign of vertically written text.

    Parameters:
    - box: Sequence indexed as [xmin, ymin, xmax, ymax].
    """
    x_extent = box[2] - box[0]
    y_extent = box[3] - box[1]
    return y_extent > 2 * x_extent
200
+
201
def _group_into_sentences(indices, text_boxes, texts, close_kwargs, padding):
    """Clusters the given text boxes into sentences; returns (sentences, sentence_boxes)."""
    graph = nx.Graph()
    for i in indices:
        graph.add_node(i)
        for j in indices:
            if (i != j
                    and are_close(text_boxes[i], text_boxes[j], **close_kwargs)
                    and not is_vertical(text_boxes[i])
                    and not is_vertical(text_boxes[j])):
                graph.add_edge(i, j)

    sentences = []
    sentence_boxes = []
    for group in list(nx.connected_components(graph)):
        # Split the group into text lines, keyed by the vertical center of the
        # first box that opened the line.
        lines = {}
        for idx in list(group):
            box = text_boxes[idx]
            y_center = (box[1] + box[3]) / 2
            half_height = (box[3] - box[1]) / 2
            for line_y in lines:
                if abs(y_center - line_y) < half_height:
                    lines[line_y].append(idx)
                    break
            else:
                lines[y_center] = [idx]

        # Read lines in increasing y, and words in increasing xmin within a line.
        line_texts = []
        ordered_indices = []
        for line_y in sorted(lines):
            in_line = sorted(lines[line_y], key=lambda idx: text_boxes[idx][0])
            line_texts.append(' '.join(texts[idx] for idx in in_line))
            ordered_indices.extend(in_line)

        sentences.append(' '.join(line_texts))
        sentence_boxes.append([
            min(text_boxes[idx][0] for idx in ordered_indices) - padding,
            min(text_boxes[idx][1] for idx in ordered_indices) - padding,
            max(text_boxes[idx][2] for idx in ordered_indices) + padding,
            max(text_boxes[idx][3] for idx in ordered_indices) + padding,
        ])

    return sentences, sentence_boxes


def group_texts(task_boxes, text_boxes, texts, min_dist=50, iou_threshold=0.8, percentage_thresh=0.8):
    """
    Maps text boxes to task boxes and groups the texts into sentences.

    Each text box is assigned to the first task box for which
    ``proportion_inside(text_box, task_box)`` exceeds ``iou_threshold``; text
    boxes matching no task are "information" texts. Within each task, and
    within the information texts, non-vertical boxes that are close to each
    other are merged into one sentence.

    Parameters:
    - task_boxes: Task bounding boxes, indexed as [xmin, ymin, xmax, ymax].
    - text_boxes: OCR bounding boxes, same layout.
    - texts: Text of each OCR box, aligned with ``text_boxes``.
    - min_dist: Reference distance; information texts are merged when closer
      than ``percentage_thresh * min_dist``.
    - iou_threshold: Minimum (exclusive) proportion of a text box that must
      lie inside a task box.
    - percentage_thresh: Factor applied to ``min_dist`` for information texts.

    Returns:
    - tuple: (task sentences, task sentence boxes, information sentences,
      information sentence boxes). Task sentence boxes are padded by 5 on
      every side; information sentence boxes are tight.
    """
    task_to_texts = [[] for _ in task_boxes]
    information_texts = []  # Texts that are not inside any task box

    for idx, text_box in enumerate(text_boxes):
        for jdx, task_box in enumerate(task_boxes):
            if proportion_inside(text_box, task_box) > iou_threshold:
                task_to_texts[jdx].append(idx)
                break
        else:
            information_texts.append(idx)

    all_grouped_texts = []
    sentence_boxes = []
    for task_texts in task_to_texts:
        # Inside a task the default closeness threshold of are_close is used.
        # The padding is applied once per sentence; it used to be re-applied
        # for every word, so the box grew with the number of words.
        grouped, boxes = _group_into_sentences(task_texts, text_boxes, texts, {}, padding=5)
        all_grouped_texts.extend(grouped)
        sentence_boxes.extend(boxes)

    information_grouped_texts, info_sentence_boxes = _group_into_sentences(
        information_texts, text_boxes, texts,
        {'threshold': percentage_thresh * min_dist}, padding=0,
    )

    return all_grouped_texts, sentence_boxes, information_grouped_texts, info_sentence_boxes
315
+
316
+
317
def _append_text(text_mapping, bpmn_id, text):
    """Appends text to the entry of bpmn_id, separated by one space, ignoring empty text."""
    if not text:
        return
    existing = text_mapping.get(bpmn_id, '')
    text_mapping[bpmn_id] = f"{existing} {text}" if existing else text


def mapping_text(full_pred, text_pred, print_sentences=False,percentage_thresh=0.6,scale=1.0, iou_threshold=0.5):
    """
    Associates the OCR texts with the predicted BPMN elements.

    Texts inside a task become the task name. The remaining "information"
    texts are attached, in this order, to pools (vertical texts overlapping
    a pool), to events / data objects (overlapping or close), and finally to
    the nearest sequence or message flow.

    Parameters:
    - full_pred (dict): Prediction read through the keys 'boxes', 'labels'
      and 'BPMN_id'.
    - text_pred (list): ``[text_boxes, texts]`` as returned by ``filter_text``.
      ``text_pred[0]`` is replaced in place by its rescaled version.
    - print_sentences (bool): Print intermediate results for debugging.
    - percentage_thresh (float): Factor applied to the minimum element
      distance to obtain the closeness threshold for events / data objects.
    - scale (float): Scale passed to ``rescale`` for the boxes.
    - iou_threshold (float): Minimum (exclusive) ``proportion_inside`` value
      for a sentence to be assigned to a task.

    Returns:
    - dict: BPMN id -> text ('' when no text was assigned).
    """
    # TODO: rework this function so that it first handles the elements located
    # inside tasks and then applies a distance threshold to the other elements,
    # or else uses the distance between all elements and not only the tasks.

    # Resolve the class indices once instead of inside every loop iteration.
    class_names = list(class_dict.values())
    pool_label = class_names.index('pool')
    task_label = class_names.index('task')
    sequence_flow_label = class_names.index('sequenceFlow')
    message_flow_label = class_names.index('messageFlow')
    avoid = [pool_label, class_names.index('lane'), sequence_flow_label, message_flow_label, class_names.index('dataAssociation')]
    object_labels = (
        class_names.index('event'),
        class_names.index('messageEvent'),
        class_names.index('timerEvent'),
        class_names.index('dataObject'),
    )

    boxes = rescale(scale, full_pred['boxes'])
    labels = full_pred['labels']
    bpmn_ids = full_pred['BPMN_id']

    # Smallest distance between two shape elements (containers and flows are
    # ignored), capped at 200; it is the reference for all closeness thresholds.
    min_dist = 200
    for i in range(len(boxes)):
        if labels[i] in avoid:
            continue
        for j in range(i + 1, len(boxes)):
            if labels[j] in avoid:
                continue
            min_dist = min(min_dist, min_distance_between_boxes(boxes[i], boxes[j]))

    text_pred[0] = rescale(scale, text_pred[0])
    task_boxes = [box for i, box in enumerate(boxes) if labels[i] == task_label]
    grouped_sentences, sentence_bounding_boxes, info_texts, info_boxes = group_texts(task_boxes, text_pred[0], text_pred[1], min_dist=min_dist)
    # One entry per distinct BPMN id
    text_mapping = {bpmn_id: '' for bpmn_id in set(bpmn_ids)}

    if print_sentences:
        for sentence, box in zip(grouped_sentences, sentence_bounding_boxes):
            print("Task-related Text:", sentence)
            print("Bounding Box:", box)
        print("Information Texts:", info_texts)
        print("Information Bounding Boxes:", info_boxes)

    # Map the grouped sentences to the corresponding task
    for i, sentence_box in enumerate(sentence_bounding_boxes):
        for j in range(len(boxes)):
            if proportion_inside(sentence_box, boxes[j]) > iou_threshold and labels[j] == task_label:
                text_mapping[bpmn_ids[j]] = grouped_sentences[i]

    # Map the vertical information texts to the corresponding pool
    for i, info_box in enumerate(info_boxes):
        if not is_vertical(info_box):
            continue
        for j in range(len(boxes)):
            if proportion_inside(info_box, boxes[j]) > 0 and labels[j] == pool_label:
                if print_sentences:
                    print("Text:", info_texts[i], "associate with ", bpmn_ids[j])
                _append_text(text_mapping, bpmn_ids[j], info_texts[i])
                info_texts[i] = ''  # Clear the text to avoid re-use

    # Map the horizontal information texts to the corresponding object
    for i, info_box in enumerate(info_boxes):
        if is_vertical(info_box):
            continue  # Skip if the text is vertical
        for j in range(len(boxes)):
            if info_texts[i] == '':
                continue  # Already consumed by a previous element
            if (proportion_inside(info_box, boxes[j]) > 0 or are_close(info_box, boxes[j], threshold=percentage_thresh*min_dist)) and labels[j] in object_labels:
                _append_text(text_mapping, bpmn_ids[j], info_texts[i])
                info_texts[i] = ''  # Clear the text to avoid re-use

    # Map the remaining information texts to the corresponding flow
    for i, info_box in enumerate(info_boxes):
        if info_texts[i] == '' or is_vertical(info_box):
            continue  # Skip if there's no text left or the text is vertical
        # Find the closest box within the defined threshold
        closest_index = find_closest_box(info_box, boxes, labels, threshold=4*min_dist)
        if closest_index is not None and (labels[closest_index] == sequence_flow_label or labels[closest_index] == message_flow_label):
            _append_text(text_mapping, bpmn_ids[closest_index], info_texts[i])
            info_texts[i] = ''  # Clear the text to avoid re-use

    if print_sentences:
        print("Text Mapping:", text_mapping)
        print("Information Texts left:", info_texts)

    return text_mapping
demo_streamlit.py ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import streamlit.components.v1 as components
3
+ from PIL import Image
4
+ import torch
5
+ from torchvision.transforms import functional as F
6
+ from PIL import Image, ImageEnhance
7
+ from htlm_webpage import display_bpmn_xml
8
+ import gc
9
+ import psutil
10
+
11
+ from OCR import text_prediction, filter_text, mapping_text, rescale
12
+ from train import prepare_model
13
+ from utils import draw_annotations, create_loader, class_dict, arrow_dict, object_dict
14
+ from toXML import calculate_pool_bounds, add_diagram_elements
15
+ from pathlib import Path
16
+ from toXML import create_bpmn_object, create_flow_element
17
+ import xml.etree.ElementTree as ET
18
+ import numpy as np
19
+ from display import draw_stream
20
+ from eval import full_prediction
21
+ from streamlit_image_comparison import image_comparison
22
+ from xml.dom import minidom
23
+ from streamlit_cropper import st_cropper
24
+ from streamlit_drawable_canvas import st_canvas
25
+ from utils import find_closest_object
26
+ from train import get_faster_rcnn_model, get_arrow_model
27
+ import gdown
28
+
29
def get_memory_usage():
    """Returns the resident set size (RSS) of the current process in MB."""
    rss_bytes = psutil.Process().memory_info().rss
    return rss_bytes / (1024 ** 2)
33
+
34
def clear_memory():
    """Empties the whole Streamlit session state, then runs a garbage collection."""
    # Drop every stored value first so the collector can reclaim it.
    st.session_state.clear()
    gc.collect()
37
+
38
+ # Function to read XML content from a file
39
def read_xml_file(filepath):
    """
    Reads a UTF-8 encoded file and returns its whole content as a string.

    Parameters:
    - filepath: Path of the file to read.
    """
    return Path(filepath).read_text(encoding='utf-8')
43
+
44
+ # Function to modify bounding box positions based on the given sizes
45
def modif_box_pos(pred, size):
    """
    Resizes predicted boxes around their center to a fixed size per class.

    Parameters:
    - pred (dict): Prediction with 'boxes' (x1, y1, x2, y2) and 'labels';
      its boxes are modified in place.
    - size: Mapping looked up with the class name; ``size[name][0]`` is used
      along x and ``size[name][1]`` along y. Boxes whose class name is not in
      the mapping are left unchanged.

    Returns:
    - dict: The same ``pred`` object.
    """
    for i, (x1, y1, x2, y2) in enumerate(pred['boxes']):
        label = class_dict[pred['labels'][i]]
        if label not in size:
            continue
        center_x, center_y = (x1 + x2) / 2, (y1 + y2) / 2
        half_w = size[label][0] / 2
        half_h = size[label][1] / 2
        pred['boxes'][i] = [center_x - half_w, center_y - half_h, center_x + half_w, center_y + half_h]
    return pred
52
+
53
+ # Function to create a BPMN XML file from prediction results
54
def create_XML(full_pred, text_mapping, scale):
    """
    Builds a BPMN 2.0 XML document from the prediction results.

    Parameters:
    - full_pred (dict): Prediction read through the keys 'boxes', 'labels',
      'BPMN_id' and 'pool_dict' (pool index -> indices of its elements).
      Its 'boxes' entry is rescaled by ``scale`` while the document is built
      and rescaled by ``1/scale`` before returning.
    - text_mapping (dict): BPMN id -> name of the element.
    - scale (float): Scale applied to the boxes for the diagram coordinates.

    Returns:
    - str: The pretty-printed XML document.
    """
    namespaces = {
        'bpmn': 'http://www.omg.org/spec/BPMN/20100524/MODEL',
        'bpmndi': 'http://www.omg.org/spec/BPMN/20100524/DI',
        'di': 'http://www.omg.org/spec/DD/20100524/DI',
        'dc': 'http://www.omg.org/spec/DD/20100524/DC',
        'xsi': 'http://www.w3.org/2001/XMLSchema-instance'
    }

    # Size given to each element class in the generated diagram
    size_elements = {
        'start': (54, 54),
        'task': (150, 120),
        'message': (54, 54),
        'messageEvent': (54, 54),
        'end': (54, 54),
        'exclusiveGateway': (75, 75),
        'event': (54, 54),
        'parallelGateway': (75, 75),
        'sequenceFlow': (225, 15),
        'pool': (375, 150),
        'lane': (300, 150),
        'dataObject': (60, 90),
        'dataStore': (90, 90),
        'subProcess': (180, 135),
        'eventBasedGateway': (75, 75),
        'timerEvent': (60, 60),
    }

    definitions = ET.Element('bpmn:definitions', {
        'xmlns:xsi': namespaces['xsi'],
        'xmlns:bpmn': namespaces['bpmn'],
        'xmlns:bpmndi': namespaces['bpmndi'],
        'xmlns:di': namespaces['di'],
        'xmlns:dc': namespaces['dc'],
        'targetNamespace': "http://example.bpmn.com",
        'id': "simpleExample"
    })

    # Create BPMN collaboration element
    collaboration = ET.SubElement(definitions, 'bpmn:collaboration', id='collaboration_1')

    pools = list(full_pred['pool_dict'].items())

    # Create one BPMN process per pool, named after the pool
    process = []
    for idx, (pool_index, _) in enumerate(pools):
        pool_name = text_mapping[full_pred['BPMN_id'][pool_index]]
        process.append(ET.SubElement(definitions, 'bpmn:process', id=f'process_{idx+1}', isExecutable='false', name=pool_name))

    bpmndi = ET.SubElement(definitions, 'bpmndi:BPMNDiagram', id='BPMNDiagram_1')
    bpmnplane = ET.SubElement(bpmndi, 'bpmndi:BPMNPlane', id='BPMNPlane_1', bpmnElement='collaboration_1')

    full_pred['boxes'] = rescale(scale, full_pred['boxes'])

    # Add the participant and its diagram shape for each pool
    for idx, (pool_index, keep_elements) in enumerate(pools):
        pool_id = f'participant_{idx+1}'
        pool_name = text_mapping[full_pred['BPMN_id'][pool_index]]
        ET.SubElement(collaboration, 'bpmn:participant', id=pool_id, processRef=f'process_{idx+1}', name=pool_name)

        # Calculate the bounding box for the pool
        if len(keep_elements) == 0:
            min_x, min_y, max_x, max_y = full_pred['boxes'][pool_index]
            pool_width = max_x - min_x
            pool_height = max_y - min_y
        else:
            min_x, min_y, max_x, max_y = calculate_pool_bounds(full_pred, keep_elements, size_elements)
            pool_width = max_x - min_x + 100  # Adding padding
            pool_height = max_y - min_y + 100  # Adding padding

        add_diagram_elements(bpmnplane, pool_id, min_x - 50, min_y - 50, pool_width, pool_height)

    # Create BPMN elements for each pool
    for idx, (_, keep_elements) in enumerate(pools):
        create_bpmn_object(process[idx], bpmnplane, text_mapping, definitions, size_elements, full_pred, keep_elements)

    # Create message flow elements
    message_flows = [i for i, label in enumerate(full_pred['labels']) if class_dict[label] == 'messageFlow']
    for flow_index in message_flows:
        create_flow_element(bpmnplane, text_mapping, flow_index, size_elements, full_pred, collaboration, message=True)

    # Create sequence flow elements
    sequence_flow_label = list(class_dict.values()).index('sequenceFlow')
    for idx, (_, keep_elements) in enumerate(pools):
        for i in keep_elements:
            if full_pred['labels'][i] == sequence_flow_label:
                create_flow_element(bpmnplane, text_mapping, i, size_elements, full_pred, process[idx], message=False)

    # Generate pretty XML string
    rough_string = ET.tostring(definitions, 'utf-8')
    reparsed = minidom.parseString(rough_string)
    pretty_xml_as_string = reparsed.toprettyxml(indent=" ")

    # Undo the rescaling applied above
    full_pred['boxes'] = rescale(1/scale, full_pred['boxes'])

    return pretty_xml_as_string
148
+
149
+
150
+ # Function to load the models only once and use session state to keep track of it
151
def load_models():
    """
    Builds the object and arrow models, downloads their weights when they are
    not already on disk, loads the weights and stores both models in the
    Streamlit session state ('model_object', 'model_arrow', 'model_loaded').
    """
    with st.spinner('Loading model...'):
        model_object = get_faster_rcnn_model(len(object_dict))
        model_arrow = get_arrow_model(len(arrow_dict),2)

        url_arrow = 'https://drive.google.com/uc?id=1xwfvo7BgDWz-1jAiJC1DCF0Wp8YlFNWt'
        url_object = 'https://drive.google.com/uc?id=1GiM8xOXG6M6r8J9HTOeMJz9NKu7iumZi'

        # Define paths to save models
        output_arrow = 'model_arrow.pth'
        output_object = 'model_object.pth'

        # Download each weight file with gdown unless it already exists locally
        downloads = (
            (url_arrow, output_arrow, 'Model arrow downloaded from local'),
            (url_object, output_object, 'Model object downloaded from local'),
        )
        for url, output, local_message in downloads:
            if not Path(output).exists():
                gdown.download(url, output, quiet=False)
            else:
                print(local_message)

        # Load models
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # NOTE(review): torch.load unpickles files fetched from the network;
        # consider restricting it (e.g. weights_only=True) after confirming
        # that the checkpoints are plain state dicts.
        model_arrow.load_state_dict(torch.load(output_arrow, map_location=device))
        model_object.load_state_dict(torch.load(output_object, map_location=device))
        st.session_state.model_loaded = True
        st.session_state.model_arrow = model_arrow
        st.session_state.model_object = model_object
181
+
182
+ # Function to prepare the image for processing
183
def prepare_image(image, pad=True, new_size=(1333, 1333)):
    """
    Resizes an image to fit ``new_size`` while keeping its aspect ratio and
    optionally brightens and pads it to exactly ``new_size``.

    Parameters:
    - image: Image exposing ``size`` as (width, height).
    - pad (bool): When True, brightness is enhanced by a factor 1.5 and the
      image is padded on the right and bottom up to ``new_size``.
    - new_size (tuple): Target (width, height).

    Returns:
    - tuple: ((scaled_width, scaled_height), processed image).
    """
    source_w, source_h = image.size
    # Largest ratio that keeps both dimensions within the target size
    ratio = min(new_size[0] / source_w, new_size[1] / source_h)
    scaled_w, scaled_h = int(source_w * ratio), int(source_h * ratio)

    # F.resize expects (height, width)
    image = F.resize(image, (scaled_h, scaled_w))
    if not pad:
        return (scaled_w, scaled_h), image

    image = ImageEnhance.Brightness(image).enhance(1.5)  # Adjust the brightness if necessary
    # Padding order is [left, top, right, bottom]
    padding = [0, 0, new_size[0] - scaled_w, new_size[1] - scaled_h]
    image = F.pad(image, padding, fill=200, padding_mode='edge')
    return (scaled_w, scaled_h), image
199
+
200
+ # Function to display various options for image annotation
201
def display_options(image, score_threshold):
    """
    Shows the annotation toggles and displays the annotated image next to
    the original one.

    The prediction and the OCR texts are read from
    ``st.session_state.prediction`` and ``st.session_state.text_pred``.

    Parameters:
    - image: Image to annotate (converted with ``np.array`` for drawing).
    - score_threshold (float): Score threshold forwarded to ``draw_stream``.
    """
    # Five columns are requested to keep the layout; the last one stays empty.
    col1, col2, col3, col4, _ = st.columns(5)
    with col1:
        write_class = st.toggle("Write Class", value=True)
        draw_keypoints = st.toggle("Draw Keypoints", value=True)
        draw_boxes = st.toggle("Draw Boxes", value=True)
    with col2:
        draw_text = st.toggle("Draw Text", value=False)
        write_text = st.toggle("Write Text", value=False)
        draw_links = st.toggle("Draw Links", value=False)
    with col3:
        write_score = st.toggle("Write Score", value=True)
        write_idx = st.toggle("Write Index", value=False)
    with col4:
        # Options of the dropdown menu: the class names, with the first entry
        # replaced by 'all'
        dropdown_options = list(class_dict.values())
        dropdown_options[0] = 'all'
        selected_option = st.selectbox("Show class", dropdown_options)

    # Draw the annotated image with selected options
    annotated_image = draw_stream(
        np.array(image), prediction=st.session_state.prediction, text_predictions=st.session_state.text_pred,
        draw_keypoints=draw_keypoints, draw_boxes=draw_boxes, draw_links=draw_links, draw_twins=False, draw_grouped_text=draw_text,
        write_class=write_class, write_text=write_text, keypoints_correction=True, write_idx=write_idx, only_print=selected_option,
        score_threshold=score_threshold, write_score=write_score, resize=True, return_image=True, axis=True
    )

    # Display the original and annotated images side by side
    image_comparison(
        img1=annotated_image,
        img2=image,
        label1="Annotated Image",
        label2="Original Image",
        starting_position=99,
        width=1000,
    )
237
+
238
+ # Function to perform inference on the uploaded image using the loaded models
239
def perform_inference(model_object, model_arrow, image, score_threshold):
    """
    Runs detection and OCR on an image and stores the results in the
    Streamlit session state ('prediction', 'text_pred', 'text_mapping').

    Parameters:
    - model_object: Object detection model forwarded to ``full_prediction``.
    - model_arrow: Arrow detection model forwarded to ``full_prediction``.
    - image: Image to analyse; it must support ``convert('RGB')``.
    - score_threshold (float): Score threshold forwarded to ``full_prediction``.
    """
    # Resized (unpadded) image for display and OCR; resized, brightened and
    # padded image for the detection models.
    _, uploaded_image = prepare_image(image, pad=False)
    img_tensor = F.to_tensor(prepare_image(image.convert('RGB'))[1])

    # Display the original image while processing. The placeholder is always
    # created: it used to be bound only inside a conditional, which would
    # leave the name undefined for the calls below whenever the condition
    # was false.
    image_placeholder = st.empty()
    image_placeholder.image(uploaded_image, caption='Original Image', width=1000)

    # Prediction
    _, st.session_state.prediction = full_prediction(model_object, model_arrow, img_tensor, score_threshold=score_threshold, iou_threshold=0.5)

    # Perform OCR on the uploaded image
    ocr_results = text_prediction(uploaded_image)

    # Filter and map OCR results to prediction results
    st.session_state.text_pred = filter_text(ocr_results, threshold=0.5)
    st.session_state.text_mapping = mapping_text(st.session_state.prediction, st.session_state.text_pred, print_sentences=False, percentage_thresh=0.5)

    # Remove the original image display
    image_placeholder.empty()

    # Force garbage collection
    gc.collect()
264
+
265
@st.cache_data
def get_image(uploaded_file):
    """Opens the uploaded file as an image and returns it converted to RGB (cached by Streamlit)."""
    opened = Image.open(uploaded_file)
    return opened.convert('RGB')
268
+
269
+ def main():
270
+ st.set_page_config(layout="wide")
271
+ st.title("BPMN model recognition demo")
272
+
273
+ # Display current memory usage
274
+ memory_usage = get_memory_usage()
275
+ print(f"Current memory usage: {memory_usage:.2f} MB")
276
+
277
+ # Initialize the session state for storing pool bounding boxes
278
+ if 'pool_bboxes' not in st.session_state:
279
+ st.session_state.pool_bboxes = []
280
+
281
+ # Load the models using the defined function
282
+ if 'model_object' not in st.session_state or 'model_arrow' not in st.session_state:
283
+ clear_memory()
284
+ load_models()
285
+
286
+ model_arrow = st.session_state.model_arrow
287
+ model_object = st.session_state.model_object
288
+
289
+ #Create the layout for the app
290
+ col1, col2 = st.columns(2)
291
+ with col1:
292
+ # Create a file uploader for the user to upload an image
293
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
294
+
295
+ # Display the uploaded image if the user has uploaded an image
296
+ if uploaded_file is not None:
297
+ original_image = get_image(uploaded_file)
298
+ col1, col2 = st.columns(2)
299
+
300
+ # Create a cropper to allow the user to crop the image and display the cropped image
301
+ with col1:
302
+ cropped_image = st_cropper(original_image, realtime_update=True, box_color='#0000FF', should_resize_image=True, default_coords=(30, original_image.size[0]-30, 30, original_image.size[1]-30))
303
+ with col2:
304
+ st.image(cropped_image, caption="Cropped Image", use_column_width=False, width=500)
305
+
306
+ # Display the options for the user to set the score threshold and scale
307
+ if cropped_image is not None:
308
+ col1, col2, col3 = st.columns(3)
309
+ with col1:
310
+ score_threshold = st.slider("Set score threshold for prediction", min_value=0.0, max_value=1.0, value=0.5, step=0.05)
311
+ with col2:
312
+ st.session_state.scale = st.slider("Set scale for XML file", min_value=0.1, max_value=2.0, value=1.0, step=0.1)
313
+
314
+ # Launch the prediction when the user clicks the button
315
+ if st.button("Launch Prediction"):
316
+ st.session_state.crop_image = cropped_image
317
+ with st.spinner('Processing...'):
318
+ perform_inference(model_object, model_arrow, st.session_state.crop_image, score_threshold)
319
+ st.session_state.prediction = modif_box_pos(st.session_state.prediction, object_dict)
320
+
321
+ print('Detection completed!')
322
+
323
+
324
+ # If the prediction has been made and the user has uploaded an image, display the options for the user to annotate the image
325
+ if 'prediction' in st.session_state and uploaded_file is not None:
326
+ st.success('Detection completed!')
327
+ display_options(st.session_state.crop_image, score_threshold)
328
+
329
+ #if st.session_state.prediction_up==True:
330
+ st.session_state.bpmn_xml = create_XML(st.session_state.prediction.copy(), st.session_state.text_mapping, st.session_state.scale)
331
+
332
+ display_bpmn_xml(st.session_state.bpmn_xml)
333
+
334
+ # Force garbage collection after display
335
+ gc.collect()
336
+
337
# Script entry point: log the start-up, then build the Streamlit page.
if __name__ == "__main__":
    print('Starting the app...')
    main()
display.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils import draw_annotations, create_loader, class_dict, resize_boxes, resize_keypoints, find_other_keypoint
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+ import matplotlib.pyplot as plt
6
+ from OCR import group_texts
7
+
8
+
9
+
10
+
11
def draw_stream(image,
                prediction=None,
                text_predictions=None,
                class_dict=class_dict,
                draw_keypoints=False,
                draw_boxes=False,
                draw_text=False,
                draw_links=False,
                draw_twins=False,
                draw_grouped_text=False,
                write_class=False,
                write_score=False,
                write_text=False,
                score_threshold=0.4,
                write_idx=False,
                keypoints_correction=False,
                new_size=(1333, 1333),
                only_print=None,
                axis=False,
                return_image=False,
                resize=False):
    """
    Draw detection results (boxes, keypoints, OCR text, links) on a copy of an image.

    Parameters:
    - image: Image to annotate. When `prediction` is None it is converted with
      `squeeze(0).permute(1, 2, 0).cpu().numpy()` (so a batched tensor is expected in
      that case); otherwise `.copy()` and `.shape` are used on it directly.
    - prediction (dict): Model output. Reads 'boxes' and 'scores', and depending on the
      flags also 'labels', 'keypoints' and 'links'. When None, nothing that depends on
      it is drawn.
    - text_predictions (tuple): OCR output; index 0 holds the boxes, index 1 the texts.
    - class_dict (dict): Mapping from class IDs to class names.
    - draw_keypoints (bool): Draw the keypoints of sequenceFlow/messageFlow/dataAssociation items.
    - draw_boxes (bool): Draw bounding boxes.
    - draw_text (bool): Draw the OCR bounding boxes.
    - draw_links (bool): Draw lines between each element and its linked elements.
    - draw_twins (bool): Mark arrows whose two keypoints are closer than 10 pixels.
    - draw_grouped_text (bool): Draw the boxes returned by `group_texts`.
    - write_class (bool): Write class names near the boxes.
    - write_score (bool): Write scores near the boxes.
    - write_text (bool): Write the OCR recognized text.
    - score_threshold (float): Items with a lower score are not drawn.
    - write_idx (bool): Write the index of each item near its box.
    - keypoints_correction (bool): Accepted for compatibility; not used by this function.
    - new_size (tuple): Size the coordinates are assumed to refer to before rescaling.
    - only_print (str): Class name to restrict box drawing to ('all' or None disables the filter).
    - axis (bool): Keep the matplotlib axis visible when displaying.
    - return_image (bool): Return the annotated image instead of displaying it.
    - resize (bool): Rescale box / text coordinates from the `new_size` frame to the image frame.

    Returns:
    - The annotated image copy when `return_image` is true, otherwise None (the image
      is shown with matplotlib).
    """
    has_prediction = prediction is not None

    # Without a prediction the image is converted from a batched tensor to an array.
    if not has_prediction:
        image = image.squeeze(0).permute(1, 2, 0).cpu().numpy()

    image_copy = image.copy()
    scale = max(image.shape[0], image.shape[1]) / 1000

    original_size = (image.shape[0], image.shape[1])
    # Scale that fits the image into new_size while maintaining the aspect ratio
    scale_ = min(new_size[0] / original_size[0], new_size[1] / original_size[1])
    new_scaled_size = (int(original_size[0] * scale_), int(original_size[1] * scale_))

    # (width, height) pairs handed to resize_boxes / resize_keypoints
    source_size = (new_scaled_size[1], new_scaled_size[0])
    target_size = (image_copy.shape[1], image_copy.shape[0])

    # Never let the stroke width collapse to 0 on small images.
    line_width = max(1, int(2 * scale))

    def class_index(name):
        # A label id is the position of the class name among class_dict's values.
        return list(class_dict.values()).index(name)

    flow_ids = []

    def flow_label_ids():
        # Looked up lazily so a class_dict without arrow classes only fails when needed.
        if not flow_ids:
            flow_ids.extend(class_index(name) for name in ('sequenceFlow', 'messageFlow', 'dataAssociation'))
        return flow_ids

    def to_image_frame(box):
        return resize_boxes(np.array([box]), source_size, target_size)[0]

    if has_prediction:
        for i in range(len(prediction['boxes'])):
            box = prediction['boxes'][i]
            x1, y1, x2, y2 = box
            if resize:
                x1, y1, x2, y2 = to_image_frame(box)
            score = prediction['scores'][i]
            if score < score_threshold:
                continue
            if draw_boxes:
                if only_print is not None and only_print != 'all':
                    if prediction['labels'][i] != class_index(only_print):
                        continue
                cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 0), line_width)
            # NOTE(review): score / index / class captions are treated as independent of
            # draw_boxes (the source this was recovered from had lost its indentation) — confirm.
            if write_score:
                cv2.putText(image_copy, str(round(score, 2)), (int(x1), int(y1) + int(15*scale)), cv2.FONT_HERSHEY_SIMPLEX, scale/2, (100,100, 255), 2)
            if write_idx:
                cv2.putText(image_copy, str(i), (int(x1) + int(15*scale), int(y1) + int(15*scale)), cv2.FONT_HERSHEY_SIMPLEX, 2*scale, (0,0, 0), 2)

            if write_class and 'labels' in prediction:
                class_id = prediction['labels'][i]
                cv2.putText(image_copy, class_dict[class_id], (int(x1), int(y1) - int(2*scale)), cv2.FONT_HERSHEY_SIMPLEX, scale/2, (255, 100, 100), 2)

    # Draw keypoints if available
    if has_prediction and draw_keypoints and 'keypoints' in prediction:
        for i in range(len(prediction['keypoints'])):
            kp = prediction['keypoints'][i]
            if kp.shape[0] == 0:
                continue
            if prediction['labels'][i] not in flow_label_ids():
                continue
            if prediction['scores'][i] < score_threshold:
                continue
            for j in range(kp.shape[0]):
                # NOTE(review): keypoints are rescaled even when resize is False — confirm.
                x, y, v = resize_keypoints(np.array([kp[j]]), source_size, target_size)[0]
                # First keypoint and the following ones use different colors.
                color = (0, 0, 255) if j == 0 else (255, 0, 0)
                cv2.circle(image_copy, (int(x), int(y)), int(5*scale), color, -1)

    # Draw text predictions if available
    if (draw_text or write_text) and text_predictions is not None:
        for i in range(len(text_predictions[0])):
            x1, y1, x2, y2 = text_predictions[0][i]
            text = text_predictions[1][i]
            if resize:
                x1, y1, x2, y2 = to_image_frame([float(x1), float(y1), float(x2), float(y2)])
            if draw_text:
                cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), line_width)
            if write_text:
                cv2.putText(image_copy, text, (int(x1 + int(2*scale)), int((y1+y2)/2) ), cv2.FONT_HERSHEY_SIMPLEX, scale/2, (0,0, 0), 2)

    # Mark arrows whose two keypoints (almost) coincide
    if has_prediction and draw_twins:
        circle_color = (0, 255, 0)  # Green color for the circle
        circle_radius = int(10 * scale)  # Circle radius scaled by image scale

        for idx, (key1, key2) in enumerate(prediction['keypoints']):
            if prediction['labels'][idx] not in flow_label_ids():
                continue
            # Euclidean distance between the two keypoints
            distance = np.linalg.norm(key1[:2] - key2[:2])
            if distance < 10:
                # NOTE(review): call aligned with keypoint_correction in eval.py, which passes
                # (idx, keypoints, boxes); the previous call here passed (idx, prediction).
                # Confirm against the definition of find_other_keypoint in utils.
                x_new, y_new, x, y = find_other_keypoint(idx, prediction['keypoints'], prediction['boxes'])
                cv2.circle(image_copy, (int(x), int(y)), circle_radius, circle_color, -1)
                cv2.circle(image_copy, (int(x_new), int(y_new)), circle_radius, (0,0,0), -1)

    # Draw links between objects
    if has_prediction and draw_links:
        def box_center(index):
            # NOTE(review): link boxes are rescaled even when resize is False — confirm.
            scaled = to_image_frame(prediction['boxes'][index])
            return ((scaled[0] + scaled[2]) // 2, (scaled[1] + scaled[3]) // 2)

        for i, (start_idx, end_idx) in enumerate(prediction['links']):
            if start_idx is None or end_idx is None:
                continue
            start_center = box_center(start_idx)
            end_center = box_center(end_idx)
            current_center = box_center(i)
            # One segment from the start element to this element, one from this element to the end element
            cv2.line(image_copy, (int(start_center[0]), int(start_center[1])), (int(current_center[0]), int(current_center[1])), (0, 0, 255), line_width)
            cv2.line(image_copy, (int(current_center[0]), int(current_center[1])), (int(end_center[0]), int(end_center[1])), (255, 0, 0), line_width)

    # Draw the text groups computed by group_texts (needs both predictions and OCR output)
    if has_prediction and draw_grouped_text and text_predictions is not None:
        task_id = class_index('task')
        task_boxes = [box for i, box in enumerate(prediction['boxes']) if prediction['labels'][i] == task_id]
        grouped_sentences, sentence_bounding_boxes, info_texts, info_boxes = group_texts(task_boxes, text_predictions[0], text_predictions[1], percentage_thresh=1)
        for group in (info_boxes, sentence_bounding_boxes):
            for i in range(len(group)):
                x1, y1, x2, y2 = group[i]
                x1, y1, x2, y2 = to_image_frame([float(x1), float(y1), float(x2), float(y2)])
                cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), line_width)

    if return_image:
        return image_copy

    # Display the image
    plt.figure(figsize=(12, 12))
    plt.imshow(image_copy)
    if not axis:
        plt.axis('off')
    plt.show()
eval.py ADDED
@@ -0,0 +1,649 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from utils import class_dict, object_dict, arrow_dict, find_closest_object, find_other_keypoint, filter_overlap_boxes, iou
4
+ from tqdm import tqdm
5
+ from toXML import create_BPMN_id
6
+
7
+
8
+
9
+
10
def non_maximum_suppression(boxes, scores, labels=None, iou_threshold=0.5):
    """
    Greedy non-maximum suppression.

    Candidates are visited from the highest score down. Each visited candidate is kept
    and every remaining candidate whose IoU with it exceeds `iou_threshold` is dropped.
    When `labels` is given, boxes whose class name is 'lane' are kept without
    suppressing anything.

    Returns:
    - list: indices (into `boxes`) of the kept boxes, best score first.
    """
    order = np.argsort(scores)  # ascending, so the best remaining candidate is the last entry
    kept = []

    while len(order) > 0:
        top_pos = len(order) - 1
        best = order[top_pos]
        kept.append(best)

        if labels is not None and class_dict[labels[best]] == 'lane':
            # Lanes are kept but never suppress other boxes.
            order = np.delete(order, top_pos)
            continue

        # Positions (within `order`) of the candidates overlapping the kept box too much
        overlapping = [pos for pos in range(top_pos)
                       if iou(boxes[best], boxes[order[pos]]) > iou_threshold]
        order = np.delete(order, [top_pos] + overlapping)

    return kept
37
+
38
+
39
def keypoint_correction(keypoints, boxes, labels, model_dict=arrow_dict, distance_treshold=15):
    """
    Replace the keypoints of arrows whose two keypoints are (almost) on the same spot.

    Only items labelled sequenceFlow, messageFlow or dataAssociation (label ids are the
    positions of these names among `model_dict`'s values) are considered. When the two
    keypoints are closer than `distance_treshold`, both are overwritten in place with
    the coordinates returned by `find_other_keypoint`.

    Returns:
    - The same `keypoints` container, modified in place.
    """
    if len(keypoints) == 0:
        return keypoints

    class_names = list(model_dict.values())
    arrow_ids = [class_names.index(name) for name in ('sequenceFlow', 'messageFlow', 'dataAssociation')]

    for idx, (first_kp, second_kp) in enumerate(keypoints):
        if labels[idx] not in arrow_ids:
            continue
        # Euclidean distance between the two keypoints (x, y only)
        gap = np.linalg.norm(first_kp[:2] - second_kp[:2])
        if gap < distance_treshold:
            print('Key modified for index:', idx)
            x_new, y_new, x, y = find_other_keypoint(idx, keypoints, boxes)
            keypoints[idx][0][:2] = [x_new, y_new]
            keypoints[idx][1][:2] = [x, y]

    return keypoints
54
+
55
+
56
def object_prediction(model, image, score_threshold=0.5, iou_threshold=0.5):
    """
    Run the object detector on one image and post-process its detections.

    Steps: score filtering (`scores > score_threshold`), non-maximum suppression, then
    removal of the task boxes whose orientation (portrait vs. landscape) is the
    minority orientation among the detected tasks.

    Parameters:
    - model: Detection model; called with a batch of one image tensor.
    - image: Image tensor; `unsqueeze(0)` and `permute(1, 2, 0)` are applied to it.
    - score_threshold (float): Minimum score (exclusive) for a detection to be considered.
    - iou_threshold (float): IoU threshold passed to `non_maximum_suppression`.

    Returns:
    - (image, prediction): the image as a uint8 array (`image * 255`) and a dict with
      'boxes', 'scores' and 'labels' of the kept detections.
    """
    model.eval()
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    with torch.no_grad():
        predictions = model(image.unsqueeze(0).to(device))

    boxes = predictions[0]['boxes'].cpu().numpy()
    labels = predictions[0]['labels'].cpu().numpy()
    scores = predictions[0]['scores'].cpu().numpy()

    confident = np.where(scores > score_threshold)[0]
    boxes = boxes[confident]
    scores = scores[confident]
    labels = labels[confident]

    selected_boxes = non_maximum_suppression(boxes, scores, labels=labels, iou_threshold=iou_threshold)

    # Find the dominant orientation of the tasks and drop the tasks that do not follow it.
    task_id = list(object_dict.values()).index('task')
    is_task = labels == task_id
    widths = boxes[:, 2] - boxes[:, 0]
    heights = boxes[:, 3] - boxes[:, 1]
    portrait_tasks = is_task & (widths < heights)

    vertical = int(np.sum(portrait_tasks))
    # Count only tasks here: comparing against every detection would make the
    # "horizontal" side win whenever the diagram contains many non-task elements.
    horizontal = int(np.sum(is_task)) - vertical

    if vertical < horizontal:
        outliers = portrait_tasks
    elif vertical > horizontal:
        outliers = is_task & (widths > heights)
    else:
        # Tie: no orientation dominates, keep everything.
        outliers = np.zeros(len(labels), dtype=bool)

    selected_boxes = [i for i in selected_boxes if not outliers[i]]

    boxes = boxes[selected_boxes]
    scores = scores[selected_boxes]
    labels = labels[selected_boxes]

    prediction = {
        'boxes': boxes,
        'scores': scores,
        'labels': labels,
    }

    image = image.permute(1, 2, 0).cpu().numpy()
    image = (image * 255).astype(np.uint8)

    return image, prediction
112
+
113
+
114
def arrow_prediction(model, image, score_threshold=0.5, iou_threshold=0.5, distance_treshold=15):
    """
    Run the arrow (keypoint) detector on one image and post-process its detections.

    Labels are shifted by `len(object_dict) - 1` before use. Detections are filtered by
    score, reduced with non-maximum suppression, and their keypoints are passed through
    `keypoint_correction` (using `class_dict`).

    Returns:
    - (image, prediction): the image as a uint8 array (`image * 255`) and a dict with
      'boxes', 'scores', 'labels' and 'keypoints' of the kept detections.
    """
    model.eval()
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    with torch.no_grad():
        output = model(image.unsqueeze(0).to(device))[0]

    boxes = output['boxes'].cpu().numpy()
    # Shift the arrow model's label ids by the number of object classes minus one
    labels = output['labels'].cpu().numpy() + (len(object_dict) - 1)
    scores = output['scores'].cpu().numpy()
    keypoints = output['keypoints'].cpu().numpy()

    # Score filtering first, non-maximum suppression second
    confident = np.where(scores > score_threshold)[0]
    boxes, scores, labels, keypoints = (
        values[confident] for values in (boxes, scores, labels, keypoints)
    )

    survivors = non_maximum_suppression(boxes, scores, iou_threshold=iou_threshold)
    boxes, scores, labels, keypoints = (
        values[survivors] for values in (boxes, scores, labels, keypoints)
    )

    keypoints = keypoint_correction(keypoints, boxes, labels, class_dict, distance_treshold=distance_treshold)

    prediction = {
        'boxes': boxes,
        'scores': scores,
        'labels': labels,
        'keypoints': keypoints,
    }

    rgb = image.permute(1, 2, 0).cpu().numpy()
    rgb = (rgb * 255).astype(np.uint8)

    return rgb, prediction
150
+
151
def mix_predictions(objects_pred, arrow_pred):
    """
    Concatenate the object and arrow predictions into single arrays.

    Objects carry no keypoints, so each object gets two all-zero placeholder keypoints
    (`[0, 0, 0]`) to keep the keypoint array aligned with the boxes.

    Parameters:
    - objects_pred (dict): needs 'boxes', 'labels' and 'scores'.
    - arrow_pred (dict): needs 'boxes', 'labels', 'scores' and 'keypoints'.

    Returns:
    - (boxes, labels, scores, keypoints): objects first, arrows after.
    """
    num_boxes = len(objects_pred['boxes'])

    # Built as an array so the shape stays (n, 2, 3) even when there are no objects;
    # an empty Python list would become a 1-D array and break the concatenation below.
    object_keypoints = np.zeros((num_boxes, 2, 3))

    arrow_keypoints = np.asarray(arrow_pred['keypoints'])
    if arrow_keypoints.size == 0:
        # Same reasoning for an empty arrow prediction.
        arrow_keypoints = arrow_keypoints.reshape(0, 2, 3)

    boxes = np.concatenate((objects_pred['boxes'], arrow_pred['boxes']))
    labels = np.concatenate((objects_pred['labels'], arrow_pred['labels']))
    scores = np.concatenate((objects_pred['scores'], arrow_pred['scores']))
    keypoints = np.concatenate((object_keypoints, arrow_keypoints))

    return boxes, labels, scores, keypoints
171
+
172
def regroup_elements_by_pool(boxes, labels, class_dict):
    """
    Regroups elements by the pool they belong to, and creates a single new pool for elements that are not in any existing pool.

    An element belongs to the first pool whose bounding box fully contains its own box.
    messageFlow elements are never assigned to a pool, and lanes are never put in the
    extra pool. Whenever a pool has to be created, a 'pool' label is appended to
    `labels` and the new pool is keyed by the index of that appended label.

    Parameters:
    - boxes (list): List of bounding boxes (x1, y1, x2, y2).
    - labels (list): List of labels corresponding to each bounding box.
    - class_dict (dict): Dictionary mapping class indices to class names.

    Returns:
    - dict: A dictionary where each key is a pool's index and the value is a list of elements within that pool
      (non-empty pools first, empty pools last).
    - labels: The labels, possibly extended with one 'pool' label.
    """
    pool_label = list(class_dict.values()).index('pool')
    pool_indices = [i for i, label in enumerate(labels) if class_dict[int(label)] == 'pool']

    if not pool_indices:
        # No pool detected: create a single pool holding every element.
        labels = np.append(labels, pool_label)
        return {len(labels) - 1: list(range(len(boxes)))}, labels

    pool_dict = {pool_index: [] for pool_index in pool_indices}
    elements_not_in_pool = []

    for i, box in enumerate(boxes):
        name = class_dict[int(labels[i])]
        if i in pool_dict or name == 'messageFlow':
            continue  # Skip pool boxes themselves and messageFlow elements
        for pool_index in pool_indices:
            pool_box = boxes[pool_index]
            # The element must lie entirely within the pool's bounding box
            if (box[0] >= pool_box[0] and box[1] >= pool_box[1] and
                    box[2] <= pool_box[2] and box[3] <= pool_box[3]):
                pool_dict[pool_index].append(i)
                break
        else:
            if name != 'lane':
                elements_not_in_pool.append(i)

    if elements_not_in_pool:
        labels = np.append(labels, pool_label)
        # Key the new pool by the index of the label just appended, exactly like the
        # no-pool case above. `max(pool_dict) + 1` could collide with the index of an
        # existing, unrelated element.
        pool_dict[len(labels) - 1] = elements_not_in_pool

    # Non-empty pools first, empty pools last
    non_empty_pools = {k: v for k, v in pool_dict.items() if v}
    empty_pools = {k: v for k, v in pool_dict.items() if not v}

    return {**non_empty_pools, **empty_pools}, labels
233
+
234
+
235
def create_links(keypoints, boxes, labels, class_dict):
    """
    Link every arrow to the objects closest to its two keypoints.

    For each element labelled sequenceFlow, messageFlow or dataAssociation,
    `find_closest_object` is queried for both keypoints. When both lookups succeed the
    element's entry becomes `[start_object, end_object]` (and the matching points are
    stored in `best_points`); in every other case the entry stays `[None, None]`.

    Both returned lists always have exactly one entry per label, so entry `i` always
    describes element `i`, even when an arrow could not be attached to two objects.

    Returns:
    - (links, best_points)
    """
    if len(labels) == 0:
        return [], []

    class_names = list(class_dict.values())
    arrow_ids = [class_names.index(name) for name in ('sequenceFlow', 'messageFlow', 'dataAssociation')]

    # Pre-fill with placeholders and write by index to keep both lists aligned with `labels`.
    links = [[None, None] for _ in range(len(labels))]
    best_points = [[None, None] for _ in range(len(labels))]

    for i in range(len(labels)):
        if labels[i] not in arrow_ids:
            continue
        closest1, point_start = find_closest_object(keypoints[i][0], boxes, labels)
        closest2, point_end = find_closest_object(keypoints[i][1], boxes, labels)
        if closest1 is not None and closest2 is not None:
            best_points[i] = [point_start, point_end]
            links[i] = [closest1, closest2]

    return links, best_points
258
+
259
def correction_labels(boxes, labels, class_dict, pool_dict, flow_links):
    """
    Turn sequenceFlows that cross a pool boundary into messageFlows.

    For every pool, a sequenceFlow whose two linked elements are not both inside or
    both outside that pool is relabelled (in place) as messageFlow. Flows with a
    missing end (None) are left untouched.

    Parameters:
    - boxes: Not used by this function.
    - labels: Label per element; modified in place.
    - class_dict (dict): Mapping from class IDs to class names; must contain
      'sequenceFlow' and 'messageFlow'.
    - pool_dict (dict): Pool index -> list of element indices in that pool.
    - flow_links (list): Per element, the pair of linked element indices.

    Returns:
    - (labels, flow_links)
    """
    class_names = list(class_dict.values())
    sequence_flow_id = class_names.index('sequenceFlow')
    message_flow_id = class_names.index('messageFlow')

    for pool_index, elements in pool_dict.items():
        print(f"Pool {pool_index} contains elements: {elements}")
        members = set(elements)
        # Check whether each link stays within the same pool
        for i in range(len(flow_links)):
            if labels[i] != sequence_flow_id:
                continue
            id1, id2 = flow_links[i]
            # Test each end explicitly: `(id1 and id2) is not None` is wrong when id1 == 0.
            if id1 is None or id2 is None:
                continue
            if (id1 in members) != (id2 in members):
                print('change the link from sequenceFlow to messageFlow')
                labels[i] = message_flow_id

    return labels, flow_links
277
+
278
+
279
def last_correction(boxes, labels, scores, keypoints, links, best_points, pool_dict):
    """
    Final clean-up of the merged prediction.

    1. Pools whose (non-empty) element list contains only messageFlows are removed
       from `pool_dict`.
    2. When two sequenceFlows have identical links, the one with the lower score is
       removed from every per-element container and from the pools' element lists.

    Returns:
    - (boxes, labels, scores, keypoints, links, best_points, pool_dict)
    """
    class_names = list(class_dict.values())
    message_flow_id = class_names.index('messageFlow')
    sequence_flow_id = class_names.index('sequenceFlow')

    # Drop the pools that hold nothing but messageFlows
    doomed_pools = []
    for pool_index, elements in pool_dict.items():
        if len(elements) > 0 and all([labels[i] == message_flow_id for i in elements]):
            doomed_pools.append(pool_index)
            print(f"Pool {pool_index} contains only messageFlow elements, deleting it")
    for pool_index in doomed_pools:
        del pool_dict[pool_index]

    # Look for pairs of sequenceFlows connecting the same two elements
    sequence_flows = [i for i in range(len(labels)) if labels[i] == sequence_flow_id]
    delete_elements = []
    for position, i in enumerate(sequence_flows):
        for j in sequence_flows[position + 1:]:
            if links[i] == links[j]:
                print(f'element {i} and {j} have the same links')
                # Keep the better-scored arrow of the pair
                loser = j if scores[i] > scores[j] else i
                print('delete element', loser)
                delete_elements.append(loser)

    # NOTE(review): the remaining entries are not renumbered after the deletion, so
    # indices stored inside `links` and `pool_dict` that are larger than a deleted
    # index keep pointing at the old positions — confirm this is acceptable downstream.
    boxes = np.delete(boxes, delete_elements, axis=0)
    labels = np.delete(labels, delete_elements)
    scores = np.delete(scores, delete_elements)
    keypoints = np.delete(keypoints, delete_elements, axis=0)
    links = np.delete(links, delete_elements, axis=0)
    best_points = [point for i, point in enumerate(best_points) if i not in delete_elements]

    # The deleted elements also leave their pools
    for pool_index in pool_dict:
        pool_dict[pool_index] = [i for i in pool_dict[pool_index] if i not in delete_elements]

    return boxes, labels, scores, keypoints, links, best_points, pool_dict
322
+
323
def give_link_to_element(links, labels):
    """
    Record on each linked element which sequenceFlow leaves it and which one reaches it.

    For every sequenceFlow `i` linking `id1 -> id2`, `links[id1][1]` and `links[id2][0]`
    are set to `i`. This gives events a link so the BPMN id can distinguish start,
    intermediate and end events. Flows with a missing end (None) are skipped.

    Returns:
    - The same `links` container, modified in place.
    """
    sequence_flow_id = list(class_dict.values()).index('sequenceFlow')

    for i in range(len(links)):
        if labels[i] != sequence_flow_id:
            continue
        id1, id2 = links[i]
        # Test each end explicitly: `(id1 and id2) is not None` lets (0, None) through
        # and would then index `links` with None.
        if id1 is None or id2 is None:
            continue
        links[id1][1] = i
        links[id2][0] = i

    return links
332
+
333
def full_prediction(model_object, model_arrow, image, score_threshold=0.5, iou_threshold=0.5, resize=True, distance_treshold=15):
    """
    Run both detectors on an image and assemble the complete prediction.

    The object and arrow predictions are merged, elements are grouped by pool, arrows
    are linked to objects, sequenceFlows crossing pools are relabelled, links are
    propagated to the linked elements, duplicates are removed and finally
    `create_BPMN_id` is applied to the resulting dictionary. `resize` is accepted but
    not used here.

    Returns:
    - (image, data): the image as a uint8 array (`image * 255`) and the prediction dictionary.
    """
    for detector in (model_object, model_arrow):
        detector.eval()  # evaluation mode for both models

    with torch.no_grad():  # no gradients needed for inference
        _, objects_pred = object_prediction(model_object, image, score_threshold=score_threshold, iou_threshold=iou_threshold)
        _, arrow_pred = arrow_prediction(model_arrow, image, score_threshold=score_threshold, iou_threshold=iou_threshold, distance_treshold=distance_treshold)

    boxes, labels, scores, keypoints = mix_predictions(objects_pred, arrow_pred)

    # Group the elements by pool
    pool_dict, labels = regroup_elements_by_pool(boxes, labels, class_dict)
    # Attach the arrows to the elements
    flow_links, best_points = create_links(keypoints, boxes, labels, class_dict)
    # sequenceFlows crossing several pools become messageFlows
    labels, flow_links = correction_labels(boxes, labels, class_dict, pool_dict, flow_links)
    # Give events a link so start / intermediate / end events can be told apart
    flow_links = give_link_to_element(flow_links, labels)

    boxes, labels, scores, keypoints, flow_links, best_points, pool_dict = last_correction(
        boxes, labels, scores, keypoints, flow_links, best_points, pool_dict)

    image = image.permute(1, 2, 0).cpu().numpy()
    image = (image * 255).astype(np.uint8)

    element_count = len(labels)
    data = {
        'image': image,
        'idx': list(range(element_count)),
        'boxes': boxes,
        'labels': labels,
        'scores': scores,
        'keypoints': keypoints,
        'links': flow_links,
        'best_points': best_points,
        'pool_dict': pool_dict,
        'BPMN_id': [class_dict[labels[i]] for i in range(element_count)],
    }

    # Give a unique BPMN id to each element
    data = create_BPMN_id(data)

    return image, data
384
+
385
def evaluate_model_by_class(pred_boxes, true_boxes, pred_labels, true_labels, model_dict, iou_threshold=0.5):
    """
    Compute precision, recall and F1-score for every class in `model_dict`.

    Each prediction is matched with the first still-unmatched ground-truth box that has
    the same label and an IoU of at least `iou_threshold`. Unmatched predictions count
    as false positives, unmatched ground-truth boxes as false negatives.

    Returns:
    - (class_precision, class_recall, class_f1_score): dictionaries keyed by class name.
    """
    class_names = list(model_dict.values())
    class_tp = dict.fromkeys(class_names, 0)
    class_fp = dict.fromkeys(class_names, 0)
    class_fn = dict.fromkeys(class_names, 0)

    # One flag per ground-truth box: has it already been claimed by a prediction?
    matched = [False] * len(true_boxes)

    for pred_box, pred_label in zip(pred_boxes, pred_labels):
        for idx, (true_box, true_label) in enumerate(zip(true_boxes, true_labels)):
            if matched[idx] or not pred_label == true_label:
                continue
            if iou(np.array(pred_box), np.array(true_box)) >= iou_threshold:
                class_tp[model_dict[pred_label]] += 1
                matched[idx] = True
                break
        else:
            # No ground-truth box could be claimed by this prediction
            class_fp[model_dict[pred_label]] += 1

    # Every ground-truth box left unclaimed is a miss
    for was_matched, (true_box, true_label) in zip(matched, zip(true_boxes, true_labels)):
        if not was_matched:
            class_fn[model_dict[true_label]] += 1

    class_precision = {}
    class_recall = {}
    class_f1_score = {}

    for cls in class_names:
        found = class_tp[cls]
        predicted = found + class_fp[cls]
        expected = found + class_fn[cls]

        precision = found / predicted if predicted > 0 else 0
        recall = found / expected if expected > 0 else 0
        if precision + recall > 0:
            f1_score = 2 * (precision * recall) / (precision + recall)
        else:
            f1_score = 0

        class_precision[cls] = precision
        class_recall[cls] = recall
        class_f1_score[cls] = f1_score

    return class_precision, class_recall, class_f1_score
427
+
428
+
429
def keypoints_mesure(pred_boxes, pred_box, true_boxes, true_box, pred_keypoints, true_keypoints, distance_threshold=5):
    """
    Compare the two keypoints of a predicted arrow with those of its ground-truth arrow.

    The arrows are located by looking up `pred_box` in `pred_boxes` and `true_box` in
    `true_boxes`; the keypoints at the same positions are then compared.

    Returns:
    - result (int): number of keypoints (0, 1 or 2) closer than `distance_threshold`
      to the ground-truth keypoint of the same rank.
    - reverted (bool): True when a predicted keypoint is closer than
      `distance_threshold` to the ground-truth keypoint of the other rank.

    Raises:
    - IndexError: when `pred_box` / `true_box` is not present in its list.
    """
    def row_of(all_boxes, box):
        # The whole row must match; an element-wise `==` alone would also select
        # boxes that merely share one coordinate with `box`.
        same_row = np.all(np.asarray(all_boxes) == np.asarray(box), axis=1)
        return np.where(same_row)[0][0]

    idx = row_of(pred_boxes, pred_box)
    idx2 = row_of(true_boxes, true_box)

    keypoint1_pred = pred_keypoints[idx][0]
    keypoint1_true = true_keypoints[idx2][0]
    keypoint2_pred = pred_keypoints[idx][1]
    keypoint2_true = true_keypoints[idx2][1]

    # Same-rank distances
    distance1 = np.linalg.norm(keypoint1_pred[:2] - keypoint1_true[:2])
    distance2 = np.linalg.norm(keypoint2_pred[:2] - keypoint2_true[:2])
    # Cross-rank distances, used to detect swapped keypoints
    distance3 = np.linalg.norm(keypoint1_pred[:2] - keypoint2_true[:2])
    distance4 = np.linalg.norm(keypoint2_pred[:2] - keypoint1_true[:2])

    result = int(distance1 < distance_threshold) + int(distance2 < distance_threshold)
    reverted = bool(distance3 < distance_threshold or distance4 < distance_threshold)

    return result, reverted
454
+
455
def evaluate_single_image(pred_boxes, true_boxes, pred_labels, true_labels, pred_keypoints, true_keypoints, iou_threshold=0.5, distance_threshold=5):
    """
    Match the predictions of one image against its ground truth and count the outcomes.

    Each prediction claims the first still-unclaimed ground-truth box with an IoU of at
    least `iou_threshold`, regardless of label; label agreement and, when both keypoint
    sets are given, keypoint agreement are counted separately for each match.

    Returns:
    - (tp, fp, fn, labels_t, labels_f, key_t, key_f, reverted_tot)
    """
    tp = fp = 0
    key_t = key_f = 0
    labels_t = labels_f = 0
    reverted_tot = 0

    with_keypoints = true_keypoints is not None and pred_keypoints is not None
    claimed = set()

    for pred_box, pred_label in zip(pred_boxes, pred_labels):
        for true_idx, true_box in enumerate(true_boxes):
            if true_idx in claimed:
                continue
            if not iou(pred_box, true_box) >= iou_threshold:
                continue

            if with_keypoints:
                key_result, reverted = keypoints_mesure(pred_boxes, pred_box, true_boxes, true_box, pred_keypoints, true_keypoints, distance_threshold)
                key_t += key_result
                key_f += 2 - key_result  # each arrow has two keypoints
                if reverted:
                    reverted_tot += 1

            claimed.add(true_idx)
            if pred_label == true_labels[true_idx]:
                labels_t += 1
            else:
                labels_f += 1
            tp += 1
            break
        else:
            # Nothing in the ground truth overlaps this prediction enough
            fp += 1

    # Ground-truth boxes that no prediction claimed
    fn = len(true_boxes) - tp

    return tp, fp, fn, labels_t, labels_f, key_t, key_f, reverted_tot
490
+
491
+
492
def pred_4_evaluation(model, loader, score_threshold=0.5, iou_threshold=0.5, distance_threshold=5, key_correction=True, model_type='object'):
    """
    Run the model over a data loader and accumulate the evaluation counters.

    For every image the predictions go through non-maximum suppression and score
    filtering (`score >= score_threshold`), optionally through `keypoint_correction`,
    and are then compared with the targets by `evaluate_single_image`. Keypoints are
    only evaluated when the target contains them. `model_type` is accepted but not
    used here.

    Returns:
    - (tp, fp, fn, labels_t, labels_f, key_t, key_f, reverted): counters summed over all images.
    """
    model.eval()
    totals = [0] * 8  # tp, fp, fn, labels_t, labels_f, key_t, key_f, reverted

    with torch.no_grad():
        for images, targets_im in tqdm(loader, desc="Testing... "):
            devices = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
            images = [image.to(devices) for image in images]
            targets = [{k: v.clone().detach().to(devices) for k, v in t.items()} for t in targets_im]

            predictions = model(images)

            for target, prediction in zip(targets, predictions):
                target_has_kp = 'keypoints' in target
                pred_has_kp = 'keypoints' in prediction

                true_boxes = target['boxes'].cpu().numpy()
                true_labels = target['labels'].cpu().numpy()
                true_keypoints = target['keypoints'].cpu().numpy() if target_has_kp else None

                pred_boxes = prediction['boxes'].cpu().numpy()
                scores = prediction['scores'].cpu().numpy()
                pred_labels = prediction['labels'].cpu().numpy()
                if pred_has_kp:
                    pred_keypoints = prediction['keypoints'].cpu().numpy()

                survivors = non_maximum_suppression(pred_boxes, scores, iou_threshold=iou_threshold)
                pred_boxes = pred_boxes[survivors]
                scores = scores[survivors]
                pred_labels = pred_labels[survivors]
                if pred_has_kp:
                    pred_keypoints = pred_keypoints[survivors]
                else:
                    # Placeholder keypoints, one (2, 3) block of zeros per box
                    pred_keypoints = [np.zeros((2, 3)) for _ in range(len(pred_boxes))]

                filtered_boxes = []
                filtered_labels = []
                filtered_keypoints = []
                for box, score, label, keypoints in zip(pred_boxes, scores, pred_labels, pred_keypoints):
                    if score >= score_threshold:
                        filtered_boxes.append(box)
                        filtered_labels.append(label)
                        if pred_has_kp:
                            filtered_keypoints.append(keypoints)

                if key_correction and pred_has_kp:
                    filtered_keypoints = keypoint_correction(filtered_keypoints, filtered_boxes, filtered_labels)

                if not target_has_kp:
                    # Without ground-truth keypoints there is nothing to compare against
                    filtered_keypoints = None
                    true_keypoints = None

                image_counts = evaluate_single_image(
                    filtered_boxes, true_boxes, filtered_labels, true_labels, filtered_keypoints, true_keypoints, iou_threshold, distance_threshold)

                for position, count in enumerate(image_counts):
                    totals[position] += count

    tp, fp, fn, labels_t, labels_f, key_t, key_f, reverted = totals
    return tp, fp, fn, labels_t, labels_f, key_t, key_f, reverted
559
+
560
def main_evaluation(model, test_loader, score_threshold=0.5, iou_threshold=0.5, distance_threshold=5, key_correction=True, model_type = 'object'):
    """Compute summary detection metrics for ``model`` on ``test_loader``.

    The raw counters come from ``pred_4_evaluation``; every ratio falls back
    to 0 when its denominator is not positive.

    Returns:
    - tuple: (labels_precision, precision, recall, f1_score, key_accuracy,
      reverted_accuracy). The two keypoint metrics are 0 unless
      ``model_type == 'arrow'``.
    """
    def ratio(numerator, denominator):
        # Guarded division shared by all metrics below.
        return numerator / denominator if denominator > 0 else 0

    counters = pred_4_evaluation(model, test_loader, score_threshold, iou_threshold, distance_threshold, key_correction, model_type)
    tp, fp, fn, labels_t, labels_f, key_t, key_f, reverted = counters

    labels_precision = ratio(labels_t, labels_t + labels_f)
    precision = ratio(tp, tp + fp)
    recall = ratio(tp, tp + fn)
    f1_score = ratio(2 * (precision * recall), precision + recall)

    key_accuracy = 0
    reverted_accuracy = 0
    if model_type == 'arrow':
        keypoint_total = key_t + key_f
        key_accuracy = ratio(key_t, keypoint_total)
        reverted_accuracy = ratio(reverted, keypoint_total)

    return labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy
576
+
577
+
578
+
579
def evaluate_model_by_class_single_image(pred_boxes, true_boxes, pred_labels, true_labels, class_tp, class_fp, class_fn, model_dict, iou_threshold=0.5):
    """Update per-class TP/FP/FN counters for one image, in place.

    Each prediction is matched to the first still-unmatched ground-truth box
    that has the same label and an IoU of at least ``iou_threshold``.

    Parameters:
    - pred_boxes, pred_labels: predicted boxes and their labels.
    - true_boxes, true_labels: ground-truth boxes and their labels.
    - class_tp, class_fp, class_fn (dict): counters keyed by class name;
      mutated by this function.
    - model_dict (dict): maps a label to the class name used as counter key.
    - iou_threshold (float): minimum IoU for a match.
    """
    claimed = set()

    for pred_box, pred_label in zip(pred_boxes, pred_labels):
        for gt_idx, (gt_box, gt_label) in enumerate(zip(true_boxes, true_labels)):
            if gt_idx in claimed:
                continue
            if pred_label == gt_label and iou(np.array(pred_box), np.array(gt_box)) >= iou_threshold:
                class_tp[model_dict[pred_label]] += 1
                claimed.add(gt_idx)
                break
        else:
            # No ground-truth box accepted this prediction.
            class_fp[model_dict[pred_label]] += 1

    # Every ground-truth box left unclaimed is a miss for its class.
    for gt_idx, gt_label in enumerate(true_labels):
        if gt_idx in claimed:
            continue
        class_fn[model_dict[gt_label]] += 1
597
+
598
def pred_4_evaluation_per_class(model, loader, score_threshold=0.5, iou_threshold=0.5):
    """Yield filtered predictions and ground truth, one image at a time.

    Predictions are first restricted to scores strictly above
    ``score_threshold`` and then passed through non-maximum suppression.

    Yields:
    - tuple: (pred_boxes, true_boxes, pred_labels, true_labels) as NumPy
      arrays for each image in ``loader``.
    """
    model.eval()
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    with torch.no_grad():
        for batch_images, batch_targets in tqdm(loader, desc="Testing... "):
            batch_images = [img.to(device) for img in batch_images]
            batch_targets = [
                {name: tensor.clone().detach().to(device) for name, tensor in entry.items()}
                for entry in batch_targets
            ]

            outputs = model(batch_images)

            for target, prediction in zip(batch_targets, outputs):
                true_boxes = target['boxes'].cpu().numpy()
                true_labels = target['labels'].cpu().numpy()

                pred_boxes = prediction['boxes'].cpu().numpy()
                scores = prediction['scores'].cpu().numpy()
                pred_labels = prediction['labels'].cpu().numpy()

                # Score filter comes first, so NMS only sees confident boxes.
                confident = np.where(scores > score_threshold)[0]
                pred_boxes, scores, pred_labels = pred_boxes[confident], scores[confident], pred_labels[confident]

                survivors = non_maximum_suppression(pred_boxes, scores, iou_threshold=iou_threshold)
                pred_boxes, scores, pred_labels = pred_boxes[survivors], scores[survivors], pred_labels[survivors]

                yield pred_boxes, true_boxes, pred_labels, true_labels
627
+
628
def evaluate_model_by_class(model, test_loader, model_dict, score_threshold=0.5, iou_threshold=0.5):
    """Compute precision, recall and F1 for each class name in ``model_dict``.

    Counters are gathered image by image from ``pred_4_evaluation_per_class``
    and ``evaluate_model_by_class_single_image``. A metric is 0 when its
    denominator is not positive.

    Returns:
    - tuple: (class_precision, class_recall, class_f1_score), three dicts
      keyed by class name.
    """
    class_names = list(model_dict.values())
    class_tp = dict.fromkeys(class_names, 0)
    class_fp = dict.fromkeys(class_names, 0)
    class_fn = dict.fromkeys(class_names, 0)

    per_image = pred_4_evaluation_per_class(model, test_loader, score_threshold, iou_threshold)
    for pred_boxes, true_boxes, pred_labels, true_labels in per_image:
        evaluate_model_by_class_single_image(pred_boxes, true_boxes, pred_labels, true_labels, class_tp, class_fp, class_fn, model_dict, iou_threshold)

    class_precision, class_recall, class_f1_score = {}, {}, {}

    for name in model_dict.values():
        hits = class_tp[name]
        predicted = hits + class_fp[name]
        expected = hits + class_fn[name]

        precision = hits / predicted if predicted > 0 else 0
        recall = hits / expected if expected > 0 else 0
        if precision + recall > 0:
            f1_score = 2 * (precision * recall) / (precision + recall)
        else:
            f1_score = 0

        class_precision[name] = precision
        class_recall[name] = recall
        class_f1_score[name] = f1_score

    return class_precision, class_recall, class_f1_score
flask.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from flask import Flask
2
+ app = Flask(__name__)
3
+
4
@app.route("/")
def hello():
    """Return the plain-text greeting served for the root URL."""
    greeting = "Hello World!\n"
    return greeting
htlm_webpage.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import streamlit.components.v1 as components
3
+
4
def display_bpmn_xml(bpmn_xml):
    """Render ``bpmn_xml`` in an embedded bpmn-js modeler inside Streamlit.

    The page offers a "Save as BPMN" button (exports the current, possibly
    edited, diagram) and a "Save as XML" button (downloads the XML exactly as
    it was passed in).

    Parameters:
    - bpmn_xml (str): BPMN 2.0 XML document to display.

    Notes:
    - The XML is embedded as a JSON-encoded JavaScript string literal instead
      of being pasted into a template literal, so backticks, ``${``,
      backslashes or a ``</script>`` sequence in the XML cannot break or
      inject into the page script.
    - The "Save as Vizi" button has its own id (element ids must be unique).
      No handler is attached to it, as was effectively the case before.
    """
    import json  # local import so the block does not depend on file-level imports

    # JSON string literals are valid JavaScript string literals; "</" is
    # additionally escaped so the HTML parser cannot end the <script> early.
    xml_literal = json.dumps(bpmn_xml).replace("</", "<\\/")

    html_template = f"""
    <!DOCTYPE html>
    <html>
    <head>
        <meta charset="UTF-8">
        <title>BPMN Modeler</title>
        <link rel="stylesheet" href="https://unpkg.com/bpmn-js/dist/assets/diagram-js.css">
        <link rel="stylesheet" href="https://unpkg.com/bpmn-js/dist/assets/bpmn-font/css/bpmn-embedded.css">
        <script src="https://unpkg.com/bpmn-js/dist/bpmn-modeler.development.js"></script>
        <style>
            html, body {{
                height: 100%;
                padding: 0;
                margin: 0;
                font-family: Arial, sans-serif;
                display: flex;
                flex-direction: column;
                overflow: hidden;
            }}
            #button-container {{
                padding: 10px;
                background-color: #ffffff;
                border-bottom: 1px solid #ddd;
                display: flex;
                justify-content: flex-start;
                gap: 10px;
            }}
            #save-button, #download-button, #vizi-button {{
                background-color: #4CAF50;
                color: white;
                border: none;
                padding: 10px 20px;
                text-align: center;
                text-decoration: none;
                display: inline-block;
                font-size: 16px;
                margin: 4px 2px;
                cursor: pointer;
                border-radius: 8px;
            }}
            #download-button, #vizi-button {{
                background-color: #008CBA;
            }}
            #canvas-container {{
                flex: 1;
                position: relative;
                background-color: #FBFBFB;
                overflow: hidden; /* Prevent scrolling */
                display: flex;
                justify-content: center;
                align-items: center;
            }}
            #canvas {{
                height: 100%;
                width: 100%;
                position: relative;
            }}
        </style>
    </head>
    <body>
        <div id="button-container">
            <button id="save-button">Save as BPMN</button>
            <button id="download-button">Save as XML</button>
            <button id="vizi-button">Save as Vizi</button>
        </div>
        <div id="canvas-container">
            <div id="canvas"></div>
        </div>
        <script>
            var initialXML = {xml_literal};

            var bpmnModeler = new BpmnJS({{
                container: '#canvas'
            }});

            async function openDiagram(bpmnXML) {{
                try {{
                    await bpmnModeler.importXML(bpmnXML);
                    bpmnModeler.get('canvas').zoom('fit-viewport');
                    bpmnModeler.get('canvas').zoom(0.8); // Adjust this value for zooming out
                }} catch (err) {{
                    console.error('Error rendering BPMN diagram', err);
                }}
            }}

            async function saveDiagram() {{
                try {{
                    const result = await bpmnModeler.saveXML({{ format: true }});
                    const xml = result.xml;
                    const blob = new Blob([xml], {{ type: 'text/xml' }});
                    const url = URL.createObjectURL(blob);
                    const a = document.createElement('a');
                    a.href = url;
                    a.download = 'diagram.bpmn';
                    document.body.appendChild(a);
                    a.click();
                    document.body.removeChild(a);
                }} catch (err) {{
                    console.error('Error saving BPMN diagram', err);
                }}
            }}

            async function downloadXML() {{
                const xml = initialXML;
                const blob = new Blob([xml], {{ type: 'text/xml' }});
                const url = URL.createObjectURL(blob);
                const a = document.createElement('a');
                a.href = url;
                a.download = 'diagram.xml';
                document.body.appendChild(a);
                a.click();
                document.body.removeChild(a);
            }}

            document.getElementById('save-button').addEventListener('click', saveDiagram);
            document.getElementById('download-button').addEventListener('click', downloadXML);

            // Ensure the canvas is focused to capture keyboard events
            document.getElementById('canvas').focus();

            // Add event listeners for keyboard shortcuts
            document.addEventListener('keydown', function(event) {{
                if (event.ctrlKey && event.key === 'z') {{
                    bpmnModeler.get('commandStack').undo();
                }} else if (event.key === 'Delete' || event.key === 'Backspace') {{
                    bpmnModeler.get('selection').get().forEach(function(element) {{
                        bpmnModeler.get('modeling').removeElements([element]);
                    }});
                }}
            }});

            openDiagram(initialXML);
        </script>
    </body>
    </html>
    """

    components.html(html_template, height=1000, width=1500)
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ libgl1-mesa-glx
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ yamlu==0.0.17
2
+ tqdm==4.66.4
3
+ torchvision==0.18.0
4
+ azure-ai-vision-imageanalysis==1.0.0b2
5
+ streamlit==1.35.0
6
+ streamlit-image-comparison==0.0.4
7
+ streamlit-cropper==0.2.2
8
+ streamlit-drawable-canvas==0.9.3
9
+ opencv-python
10
+ gdown
toXML.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import xml.etree.ElementTree as ET
2
+ from utils import class_dict
3
+
4
def rescale(scale, boxes):
    """Multiply the first four coordinates of every box by ``scale``.

    ``boxes`` is modified in place (each entry is replaced by a new
    four-element list) and also returned.
    """
    for position, box in enumerate(boxes):
        scaled = [box[0] * scale, box[1] * scale, box[2] * scale, box[3] * scale]
        boxes[position] = scaled
    return boxes
11
+
12
def create_BPMN_id(data):
    """Assign a numbered BPMN identifier to every recognised element.

    Identifiers have the form ``<key>_<n>`` where ``n`` counts, from 1, the
    elements that share the same key. They are written into
    ``data['BPMN_id']`` in place.

    Generic ``event`` elements are classified from their link pair in
    ``data['links']``: first entry set and second ``None`` gives ``end_event``;
    first ``None`` and second set gives ``start_event``. An event that matches
    neither pattern is left untouched. Previously such an event silently
    reused the key of the preceding element, or raised ``UnboundLocalError``
    when it was the first element.

    Parameters:
    - data (dict): reads ``'labels'`` and ``'links'``, writes ``'BPMN_id'``.

    Returns:
    - dict: the same ``data`` object.
    """
    # Next free number for each identifier family.
    counters = {
        'end_event': 1,
        'start_event': 1,
        'task': 1,
        'sequenceFlow': 1,
        'messageFlow': 1,
        'message_event': 1,
        'exclusiveGateway': 1,
        'parallelGateway': 1,
        'dataAssociation': 1,
        'pool': 1,
        'dataObject': 1,
        'timerEvent': 1
    }

    # Class name -> identifier family for everything except generic events.
    class_to_key = {
        'task': 'task',
        'dataObject': 'dataObject',
        'sequenceFlow': 'sequenceFlow',
        'messageFlow': 'messageFlow',
        'messageEvent': 'message_event',
        'exclusiveGateway': 'exclusiveGateway',
        'parallelGateway': 'parallelGateway',
        'dataAssociation': 'dataAssociation',
        'pool': 'pool',
        'timerEvent': 'timerEvent'
    }

    class_names = [class_dict[label] for label in data['labels']]

    for idx, class_name in enumerate(class_names):
        key = None  # reset per element so nothing leaks from the previous one
        if class_name == 'event':
            link = data['links'][idx]
            if link[0] is not None and link[1] is None:
                key = 'end_event'
            elif link[0] is None and link[1] is not None:
                key = 'start_event'
        else:
            key = class_to_key.get(class_name)

        if key:
            data['BPMN_id'][idx] = f'{key}_{counters[key]}'
            counters[key] += 1

    return data
55
+
56
+
57
+
58
def add_diagram_elements(parent, element_id, x, y, width, height):
    """Append a ``bpmndi:BPMNShape`` with its ``dc:Bounds`` to ``parent``.

    The shape references ``element_id`` and gets the id ``<element_id>_di``;
    the four geometry values are stored as strings.
    """
    shape_attributes = {
        'bpmnElement': element_id,
        'id': element_id + '_di'
    }
    shape_node = ET.SubElement(parent, 'bpmndi:BPMNShape', attrib=shape_attributes)

    geometry = {
        'x': str(x),
        'y': str(y),
        'width': str(width),
        'height': str(height)
    }
    ET.SubElement(shape_node, 'dc:Bounds', attrib=geometry)
70
+
71
def add_diagram_edge(parent, element_id, waypoints):
    """Append a ``bpmndi:BPMNEdge`` with one ``di:waypoint`` per point.

    ``waypoints`` is an iterable of ``(x, y)`` pairs; coordinates are stored
    as strings in the order given.
    """
    edge_attributes = {
        'bpmnElement': element_id,
        'id': element_id + '_di'
    }
    edge_node = ET.SubElement(parent, 'bpmndi:BPMNEdge', attrib=edge_attributes)

    for point_x, point_y in waypoints:
        coordinates = {'x': str(point_x), 'y': str(point_y)}
        ET.SubElement(edge_node, 'di:waypoint', attrib=coordinates)
82
+
83
+
84
def check_status(link, keep_elements):
    """Classify an element as ``'start'``, ``'middle'`` or ``'end'``.

    ``link`` is a pair whose entries are looked up in ``keep_elements``:
    - both entries kept -> ``'middle'``
    - first is ``None`` and second kept -> ``'start'``
    - first kept and second is ``None`` -> ``'end'``
    - anything else -> ``'middle'``
    """
    if link[0] in keep_elements:
        if link[1] in keep_elements:
            return 'middle'
        return 'end' if link[1] is None else 'middle'

    if link[0] is None and link[1] in keep_elements:
        return 'start'

    return 'middle'
93
+
94
def check_data_association(i, links, labels, keep_elements):
    """Find the first link labelled 14 that touches element ``i``.

    Returns:
    - ``('output', j)`` when element ``i`` is the first end of link ``j``.
    - ``('input', j)`` when element ``i`` is the second end of link ``j``.
    - ``('no association', None)`` when no such link exists.

    ``keep_elements`` is accepted but not used. Label 14 is presumably the
    data-association class - confirm against ``class_dict``.
    """
    for link_idx, (first_end, second_end) in enumerate(links):
        if not labels[link_idx] == 14:
            continue
        if first_end == i:
            return 'output', link_idx
        if second_end == i:
            return 'input', link_idx

    return 'no association', None
103
+
104
def create_data_Association(bpmn,data,size,element_id,source_id,target_id):
    """Draw the diagram edge of a data association.

    Waypoints between ``source_id`` and ``target_id`` come from
    ``calculate_waypoints``; the edge is added under ``bpmn`` for
    ``element_id``.
    """
    add_diagram_edge(bpmn, element_id, calculate_waypoints(data, size, source_id, target_id))
107
+
108
+ # Function to dynamically create and layout BPMN elements
109
def create_bpmn_object(process, bpmnplane, text_mapping, definitions, size, data, keep_elements):
    """Create the BPMN nodes and diagram shapes for the kept elements.

    For every index in ``keep_elements`` with a non-``None`` BPMN id, the
    semantic element is appended to ``process`` and its shape to
    ``bpmnplane``. The element kind is the part of the id before the first
    underscore. Tasks additionally receive a data input/output association
    when ``check_data_association`` reports one.

    Parameters:
    - process, bpmnplane: ElementTree parents for semantic and diagram nodes.
    - text_mapping (dict): BPMN id -> display name.
    - definitions: accepted but not used here.
    - size (dict): element kind -> (width, height).
    - data (dict): reads ``'BPMN_id'``, ``'boxes'``, ``'links'``, ``'labels'``.
    - keep_elements: indices of the elements to emit.
    """
    def draw_shape(kind, shape_ref, left, top):
        # Shapes always use the canonical size registered for their kind.
        add_diagram_elements(bpmnplane, shape_ref, left, top, size[kind][0], size[kind][1])

    bpmn_ids = data['BPMN_id']
    boxes = data['boxes']
    links = data['links']

    for i in keep_elements:
        element_id = bpmn_ids[i]
        if element_id is None:
            continue

        element_type = element_id.split('_')[0]
        x, y = boxes[i][:2]

        if element_type == 'start':
            # Start Event
            ET.SubElement(process, 'bpmn:startEvent', id=element_id, name=text_mapping[element_id])
            draw_shape('start', element_id, x, y)

        elif element_type == 'task':
            task_node = ET.SubElement(process, 'bpmn:task', id=element_id, name=text_mapping[element_id])
            status, association_idx = check_data_association(i, data['links'], data['labels'], keep_elements)

            if status == 'input':
                # The data object is the first end of the association link.
                data_object_id = bpmn_ids[links[association_idx][0]]
                suffix = data_object_id.split("_")[1]
                reference = f'DataObjectReference_{suffix}'
                association = ET.SubElement(task_node, 'bpmn:dataInputAssociation', id=f'dataInputAssociation_{suffix}')
                ET.SubElement(association, 'bpmn:sourceRef').text = reference
                create_data_Association(bpmnplane, data, size, association.attrib['id'], data_object_id, element_id)

            elif status == 'output':
                # The data object is the second end of the association link.
                data_object_id = bpmn_ids[links[association_idx][1]]
                suffix = data_object_id.split("_")[1]
                reference = f'DataObjectReference_{suffix}'
                association = ET.SubElement(task_node, 'bpmn:dataOutputAssociation', id=f'dataOutputAssociation_{suffix}')
                ET.SubElement(association, 'bpmn:targetRef').text = reference
                create_data_Association(bpmnplane, data, size, association.attrib['id'], element_id, data_object_id)

            draw_shape('task', element_id, x, y)

        elif element_type == 'message':
            # Message events become start, intermediate or end events.
            status = check_status(links[i], keep_elements)
            if status == 'start':
                element = ET.SubElement(process, 'bpmn:startEvent', id=element_id, name=text_mapping[element_id])
            elif status == 'middle':
                element = ET.SubElement(process, 'bpmn:intermediateCatchEvent', id=element_id, name=text_mapping[element_id])
            elif status == 'end':
                element = ET.SubElement(process, 'bpmn:endEvent', id=element_id, name=text_mapping[element_id])
            ET.SubElement(element, 'bpmn:messageEventDefinition', id=f'MessageEventDefinition_{i+1}')
            draw_shape('message', element_id, x, y)

        elif element_type == 'end':
            # End Event
            ET.SubElement(process, 'bpmn:endEvent', id=element_id, name=text_mapping[element_id])
            draw_shape('end', element_id, x, y)

        elif element_type in ['exclusiveGateway', 'parallelGateway']:
            # The tag name is the element kind itself; gateways carry no name.
            ET.SubElement(process, f'bpmn:{element_type}', id=element_id)
            draw_shape(element_type, element_id, x, y)

        elif element_type == 'dataObject':
            # The shape is attached to the reference, not to the data object.
            reference = f"DataObjectReference_{element_id.split('_')[1]}"
            ET.SubElement(process, 'bpmn:dataObjectReference', id=reference, dataObjectRef=element_id, name=text_mapping[element_id])
            ET.SubElement(process, 'bpmn:dataObject', id=element_id)
            draw_shape('dataObject', reference, x, y)

        elif element_type == 'timerEvent':
            timer_node = ET.SubElement(process, 'bpmn:intermediateCatchEvent', id=element_id, name=text_mapping[element_id])
            ET.SubElement(timer_node, 'bpmn:timerEventDefinition', id=f'TimerEventDefinition_{i+1}')
            draw_shape('timerEvent', element_id, x, y)
188
+
189
+
190
+
191
+ # Calculate simple waypoints between two elements (this function assumes direct horizontal links for simplicity)
192
def calculate_waypoints(data, size, source_id, target_id):
    """Return the two waypoints of a straight edge between two elements.

    Both elements are looked up by BPMN id in ``data['BPMN_id']``; their
    top-left corners come from ``data['boxes']`` and their dimensions from
    ``size`` (keyed by the id prefix before the first underscore).

    The anchor sides depend on where the target lies relative to the source:
    - target clearly to the right: the edge leaves the source's right side
      (or its top/bottom centre when the target is also clearly above/below)
      and enters the target's left side;
    - target clearly to the left: mirrored;
    - otherwise (roughly same column): bottom-to-top or top-to-bottom
      centres, or side-to-side when the vertical offset is small as well.

    The left branch now includes the exact boundary
    ``relative_x == -source_width``. Previously that value matched no branch
    and the raw top-left corners were returned.

    Parameters:
    - data (dict): reads ``'BPMN_id'`` and ``'boxes'``.
    - size (dict): element kind -> (width, height).
    - source_id, target_id (str): BPMN ids of the two elements.

    Returns:
    - list: ``[(source_x, source_y), (target_x, target_y)]``.

    Raises:
    - ValueError: if an id is not present in ``data['BPMN_id']``.
    """
    source_idx = data['BPMN_id'].index(source_id)
    target_idx = data['BPMN_id'].index(target_id)
    name_source = source_id.split('_')[0]
    name_target = target_id.split('_')[0]

    # Top-left corners of both elements.
    source_x, source_y = data['boxes'][source_idx][:2]
    target_x, target_y = data['boxes'][target_idx][:2]

    size_x_source, size_y_source = size[name_source][0], size[name_source][1]
    size_x_target, size_y_target = size[name_target][0], size[name_target][1]

    # Offset measure between the elements (half the distance between their
    # far edges); compared against the source dimensions below.
    relative_x = (target_x + size_x_target) / 2 - (source_x + size_x_source) / 2
    relative_y = (target_y + size_y_target) / 2 - (source_y + size_y_source) / 2

    if relative_x >= size_x_source:
        # Target to the right: right-middle of source -> left-middle of target.
        source_x += size_x_source
        source_y += size_y_source / 2
        target_y += size_y_target / 2
        if relative_y < -size_y_source:
            # Target also above: leave from the source's top centre.
            source_x -= size_x_source / 2
            source_y -= size_y_source / 2
        elif relative_y > size_y_source:
            # Target also below: leave from the source's bottom centre.
            source_x -= size_x_source / 2
            source_y += size_y_source / 2
    elif relative_x <= -size_x_source:
        # Target to the left: left-middle of source -> right-middle of target.
        source_y += size_y_source / 2
        target_x += size_x_target
        target_y += size_y_target / 2
        if relative_y < -size_y_source:
            source_x += size_x_source / 2
            source_y -= size_y_source / 2
        elif relative_y > size_y_source:
            source_x += size_x_source / 2
            source_y += size_y_source / 2
    else:
        # Roughly the same column: start from the horizontal centres.
        source_x += size_x_source / 2
        target_x += size_x_target / 2
        if relative_y >= size_y_source / 2:
            # Target below: bottom of source -> top of target.
            source_y += size_y_source
        elif relative_y < -size_y_source / 2:
            # Target above: top of source -> bottom of target.
            target_y += size_y_target
        else:
            # Small vertical offset too: connect facing sides at mid-height.
            if relative_x >= 0:
                source_x += size_x_source / 2
                source_y += size_y_source / 2
                target_x -= size_x_target / 2
                target_y += size_y_target / 2
            else:
                source_x -= size_x_source / 2
                source_y += size_y_source / 2
                target_x += size_x_target / 2
                target_y += size_y_target / 2

    return [(source_x, source_y), (target_x, target_y)]
264
+
265
+
266
def calculate_pool_bounds(data, keep_elements, size):
    """Return the bounding rectangle of the kept, drawable elements.

    Elements without a BPMN id and elements labelled 13, 14, 15 or 7 are
    ignored. Each remaining element contributes its top-left corner from
    ``data['boxes']`` and the canonical ``size`` of its kind.

    Returns:
    - tuple: ``(min_x, min_y, max_x, max_y)``. The minima start at 10000.0
      and the maxima at 0.0, so those values are returned unchanged when
      nothing contributes.
    """
    skipped_labels = (13, 14, 15, 7)
    min_x = min_y = float('10000')
    max_x = max_y = float('0')

    known_count = len(data['BPMN_id'])
    for i in keep_elements:
        if i >= known_count:
            print("Problem with the index")
            continue

        element = data['BPMN_id'][i]
        if element is None or data['labels'][i] in skipped_labels:
            continue

        kind = element.split('_')[0]
        left, top = data['boxes'][i][:2]
        width, height = size[kind]

        min_x, min_y = min(min_x, left), min(min_y, top)
        max_x, max_y = max(max_x, left + width), max(max_y, top + height)

    return min_x, min_y, max_x, max_y
288
+
289
+
290
def calculate_pool_waypoints(idx, data, size, source_idx, target_idx, source_element, target_element):
    """Return the two waypoints of a vertical edge between a pool and an element.

    The edge runs along the horizontal centre of the non-pool element. The
    element end sits on its top side when the pool is above it, otherwise on
    its bottom side. The pool end sits 50 units above the pool's bottom edge
    (pool above) or 50 units above its top edge (pool below).

    Waypoints are always ordered source first, target second. Previously,
    when the pool was the source and lay below the element, the pair was
    returned in element->pool order, i.e. reversed.

    Parameters:
    - idx: accepted but not used.
    - data (dict): reads ``'boxes'``.
    - size (dict): element kind -> (width, height); used for the non-pool
      element only.
    - source_idx, target_idx (int): indices into ``data['boxes']``.
    - source_element, target_element (str): element kinds; the pool side is
      the source when ``source_element == 'pool'``, otherwise the target.

    Returns:
    - list: ``[(x, y_source_side), (x, y_target_side)]``.
    """
    source_box = data['boxes'][source_idx]
    target_box = data['boxes'][target_idx]

    pool_is_source = source_element == 'pool'
    if pool_is_source:
        pool_box, other_box, other_kind = source_box, target_box, target_element
    else:
        pool_box, other_box, other_kind = target_box, source_box, source_element

    # The element's extent uses its canonical size, not the detected box.
    element_left, element_top = other_box[0], other_box[1]
    element_right = element_left + size[other_kind][0]
    element_bottom = element_top + size[other_kind][1]
    element_mid_x = (element_left + element_right) / 2

    if pool_box[3] < element_top:
        # Pool is above the element.
        element_point = (element_mid_x, element_top)
        pool_point = (element_mid_x, pool_box[3] - 50)
    else:
        # Pool is below (or overlapping) the element.
        element_point = (element_mid_x, element_bottom)
        pool_point = (element_mid_x, pool_box[1] - 50)

    if pool_is_source:
        return [pool_point, element_point]
    return [element_point, pool_point]
323
+
324
+
325
+
326
def create_flow_element(bpmn, text_mapping, idx, size, data, parent, message=False):
    """Add a sequence or message flow, plus its diagram edge, for link ``idx``.

    The flow id is built from the BPMN ids of both ends before any renaming.
    An end whose id starts with ``pool`` is then referenced as
    ``participant_<n>``, and its waypoints come from
    ``calculate_pool_waypoints`` instead of ``calculate_waypoints``.

    Parameters:
    - bpmn: diagram parent that receives the edge.
    - text_mapping (dict): BPMN id -> display name; read for the flow itself.
    - idx (int): index of the flow in ``data['links']`` / ``data['BPMN_id']``.
    - size (dict): element kind -> (width, height).
    - data (dict): reads ``'links'``, ``'BPMN_id'`` and (indirectly) ``'boxes'``.
    - parent: semantic parent that receives the flow element.
    - message (bool): create a ``bpmn:messageFlow`` instead of a
      ``bpmn:sequenceFlow``.
    """
    source_idx, target_idx = data['links'][idx]
    source_id = data['BPMN_id'][source_idx]
    target_id = data['BPMN_id'][target_idx]

    flow_prefix = 'messageflow' if message else 'sequenceflow'
    element_id = f'{flow_prefix}_{source_id}_{target_id}'

    source_kind = source_id.split('_')[0]
    target_kind = target_id.split('_')[0]

    if source_kind == 'pool' or target_kind == 'pool':
        waypoints = calculate_pool_waypoints(idx, data, size, source_idx, target_idx, source_kind, target_kind)
        # Pools are referenced through their participant in the flow.
        if source_kind == 'pool':
            source_id = f"participant_{source_id.split('_')[1]}"
        if target_kind == 'pool':
            target_id = f"participant_{target_id.split('_')[1]}"
    else:
        waypoints = calculate_waypoints(data, size, source_id, target_id)

    flow_tag = 'bpmn:messageFlow' if message else 'bpmn:sequenceFlow'
    flow_name = text_mapping[data['BPMN_id'][idx]]
    ET.SubElement(parent, flow_tag, id=element_id, sourceRef=source_id, targetRef=target_id, name=flow_name)
    add_diagram_edge(bpmn, element_id, waypoints)
351
+
train.py ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import cv2
3
+ import numpy as np
4
+ import random
5
+ import time
6
+ import torch
7
+ import torchvision.transforms.functional as F
8
+ import matplotlib.pyplot as plt
9
+
10
+ from eval import main_evaluation
11
+ from torch.optim import SGD, AdamW
12
+ from torch.utils.data import DataLoader, Dataset, Subset, ConcatDataset
13
+ from torch.utils.data.dataloader import default_collate
14
+ from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
15
+ from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
16
+ from torchvision.models.detection.keypoint_rcnn import KeypointRCNNPredictor
17
+ from tqdm import tqdm
18
+ from utils import write_results
19
+
20
+
21
+
22
+
23
def get_arrow_model(num_classes, num_keypoints=2):
    """
    Build a Keypoint R-CNN (ResNet-50 + FPN) with custom box and keypoint heads.

    Parameters:
    - num_classes (int): Number of outputs of the new box predictor; passed
      unchanged to ``FastRCNNPredictor``.
    - num_keypoints (int): Number of keypoints predicted for each detected object.

    Returns:
    - model (torch.nn.Module): The modified Keypoint R-CNN model.

    Steps:
    1. Create the Keypoint R-CNN. COCO pre-trained detection weights are
       requested only when CUDA is available; otherwise ``weights=None`` is
       passed, which replaces the deprecated legacy value ``weights=False``.
    2. Replace the box predictor so that it outputs ``num_classes`` classes.
    3. Replace the keypoint predictor so that it outputs ``num_keypoints``
       keypoints per object.
    """
    # COCO detection weights on GPU machines only; none requested otherwise.
    if torch.cuda.is_available():
        weights = KeypointRCNN_ResNet50_FPN_Weights.COCO_V1
    else:
        weights = None
    model = keypointrcnn_resnet50_fpn(weights=weights)

    # Get the number of input features for the classifier in the box predictor.
    in_features = model.roi_heads.box_predictor.cls_score.in_features

    # Replace the box predictor in the ROI heads with a new one, tailored to the number of classes.
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Replace the keypoint predictor in the ROI heads with a new one, specifically designed for the desired number of keypoints.
    # NOTE(review): 512 is presumably the channel count produced by the keypoint head - confirm.
    model.roi_heads.keypoint_predictor = KeypointRCNNPredictor(512, num_keypoints)

    return model
59
+
60
+ from torchvision.models.detection import fasterrcnn_resnet50_fpn
61
+ from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
62
def get_faster_rcnn_model(num_classes):
    """
    Build a Faster R-CNN (ResNet-50 + FPN) whose box head predicts ``num_classes`` classes.

    Parameters:
    - num_classes (int): Number of classes for the model to detect, including the background class.

    Returns:
    - model (torch.nn.Module): The modified Faster R-CNN model.
    """
    # Start from the COCO pre-trained detector.
    detector = fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.COCO_V1)

    # Swap the classification/regression head for one sized to num_classes,
    # reusing the input width of the head it replaces.
    roi_heads = detector.roi_heads
    head_width = roi_heads.box_predictor.cls_score.in_features
    roi_heads.box_predictor = FastRCNNPredictor(head_width, num_classes)

    return detector
82
+
83
def prepare_model(dict,opti,learning_rate= 0.0003,model_to_load=None, model_type = 'object'):
    """Build the detection model, optionally load weights, and create its optimizer.

    Parameters:
    - dict: class dictionary; only its length is used, as the number of
      classes. (The name shadows the builtin but is kept for compatibility
      with keyword callers.)
    - opti (str): ``'SGD'`` or ``'Adam'`` (the latter creates an ``AdamW``).
    - learning_rate (float): learning rate given to the optimizer.
    - model_to_load (str | None): when set, weights are loaded from
      ``./models/<model_to_load>.pth``.
    - model_type (str): ``'object'`` for Faster R-CNN, ``'arrow'`` for
      Keypoint R-CNN with 2 keypoints.

    Returns:
    - tuple: ``(model, optimizer, device)``; the model is already on ``device``.

    Raises:
    - ValueError: for an unknown ``model_type`` or ``opti``. Previously these
      cases ended in an ``UnboundLocalError`` (for the optimizer, after
      printing 'Optimizer not found').
    """
    # Validate first so a bad argument fails before any model is built.
    if model_type not in ('object', 'arrow'):
        raise ValueError(f"Unknown model_type {model_type!r}; expected 'object' or 'arrow'")
    if opti not in ('SGD', 'Adam'):
        raise ValueError(f"Optimizer not found: {opti!r}; expected 'SGD' or 'Adam'")

    num_classes = len(dict)
    if model_type == 'object':
        model = get_faster_rcnn_model(num_classes)
    else:
        model = get_arrow_model(num_classes, 2)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Load the model weights
    if model_to_load:
        model.load_state_dict(torch.load('./models/'+ model_to_load +'.pth', map_location=device))
        print(f"Model '{model_to_load}' loaded")

    model.to(device)

    if opti == 'SGD':
        optimizer = SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=0.0001)
    else:
        optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0.00056, eps=1e-08, betas=(0.9, 0.999))

    return model, optimizer, device
109
+
110
+
111
+
112
+
113
def evaluate_loss(model, data_loader, device, loss_config=None, print_losses=False):
    """
    Computes the average loss of the model over a data loader without updating the weights.

    The model is deliberately switched to training mode: the torchvision detection
    models only return their loss dictionary in that mode. Gradients are disabled,
    so no parameter is modified. The model is left in training mode on return.

    Parameters:
    - model: Detection model called as model(images, targets) and returning a dict of losses.
    - data_loader: Iterable of (images, targets) batches; must support len().
    - device: Device the images and target tensors are moved to.
    - loss_config (dict | None): Maps a loss name to a bool; only the enabled losses are summed.
      When None, every loss returned by the model is summed.
    - print_losses (bool): If True, prints the average of the total and of each tracked loss.

    Returns:
    - float: Sum of the per-batch total losses divided by the number of batches.
    """
    model.train()  # Training mode is required to get the losses back from the model
    total_loss = 0

    # Individual losses that are tracked; a loss missing from the model output counts as 0
    tracked_losses = ('loss_classifier', 'loss_box_reg', 'loss_objectness', 'loss_rpn_box_reg', 'loss_keypoint')
    collected = {name: [] for name in tracked_losses}

    with torch.no_grad():  # Disable gradient computation
        for images, targets_im in tqdm(data_loader, desc="Evaluating"):
            images = [image.to(device) for image in images]
            targets = [{k: v.clone().detach().to(device) for k, v in t.items()} for t in targets_im]

            loss_dict = model(images, targets)

            # Calculate the total loss for the current batch
            if loss_config is not None:
                losses = sum(loss for key, loss in loss_dict.items() if loss_config.get(key, False))
            else:
                losses = sum(loss_dict.values())

            # float() accepts both a scalar tensor and the plain 0 obtained when no loss is enabled
            total_loss += float(losses)

            # Collect individual losses
            for name in tracked_losses:
                value = loss_dict.get(name)
                collected[name].append(value.item() if value is not None else 0)

    # Calculate average loss
    avg_loss = total_loss / len(data_loader)

    if print_losses:
        print(f"Average Loss: {avg_loss:.4f}")
        print(f"Average Classifier Loss: {np.mean(collected['loss_classifier']):.4f}")
        print(f"Average Box Regression Loss: {np.mean(collected['loss_box_reg']):.4f}")
        print(f"Average Objectness Loss: {np.mean(collected['loss_objectness']):.4f}")
        print(f"Average RPN Box Regression Loss: {np.mean(collected['loss_rpn_box_reg']):.4f}")
        print(f"Average Keypoints Loss: {np.mean(collected['loss_keypoint']):.4f}")

    return avg_loss
186
+
187
+
188
def training_model(num_epochs, model, data_loader, subset_test_loader,
                   optimizer, model_to_load=None, change_learning_rate=5, start_key=30,
                   batch_size=4, crop_prob=0.2, h_flip_prob=0.3, v_flip_prob=0.3,
                   max_rotate_deg=20, rotate_proba=0.2, blur_prob=0.2,
                   score_threshold=0.7, iou_threshold=0.5, early_stop_f1_score=0.97,
                   information_training='training', start_epoch=0, loss_config=None, model_type = 'object',
                   eval_metric='f1_score', device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')):
    """
    Trains a detection model, evaluates it after every epoch and keeps the best weights.

    Parameters:
    - num_epochs (int): Number of epochs to run.
    - model: Detection model called as model(images, targets) and returning a dict of losses.
    - data_loader: Training batches of (images, targets); must support len().
    - subset_test_loader: Batches used for the per-epoch evaluation.
    - optimizer: Optimizer used until the first learning-rate change (see notes).
    - change_learning_rate (int): The learning rate is multiplied by 0.7 every this many epochs
      (0 disables the schedule), and also after two epochs whose test loss increased.
    - start_key (int): Epoch at which the keypoint loss is enabled and the parameters whose
      name contains 'keypoint' are unfrozen.
    - batch_size, crop_prob, h_flip_prob, rotate_proba, blur_prob, information_training:
      Only used to build the checkpoint name and the start-up message.
    - early_stop_f1_score (float): Training stops once the F1 score exceeds this value.
    - start_epoch (int): Offset added to the epoch number in messages and checkpoint names.
    - loss_config (dict | None): Maps a loss name to a bool; only enabled losses are optimized,
      the classifier loss being weighted by 3. None means the plain sum of all losses.
    - model_type (str): Forwarded to main_evaluation.
    - eval_metric (str): 'f1_score', 'precision', 'recall', or 'loss'/'all' (negative test loss)
      selects the value used to pick the best epoch.
    - device: Device the batches are moved to.
    - model_to_load, v_flip_prob, max_rotate_deg, score_threshold, iou_threshold:
      Accepted for interface compatibility; not used by this function.

    Returns:
    - tuple: (model, metrics_list) where metrics_list holds the per-epoch lists
      [loss, classifier, box_reg, objectness, rpn_box_reg, keypoints, precision, recall, f1, test_loss].

    Notes:
    - NOTE(review): every learning-rate change replaces the optimizer with a fresh AdamW whose
      weight_decay equals the new learning rate, whatever optimizer was passed in - confirm intended.
    - NOTE(review): main_evaluation is called with hard-coded score/IoU thresholds of 0.5.
    """
    import os  # local import: only needed to make sure the checkpoint folder exists

    if loss_config is None:
        print('No loss config found, all losses will be used.')
    else:
        # Private copy: enabling the keypoint loss later must not modify the caller's dict
        loss_config = dict(loss_config)
        # Print the list of the losses that will be used
        print('The following losses will be used: ', end='')
        for key, value in loss_config.items():
            if value:
                print(key, end=", ")
        print()

    # Epoch-wise histories; metrics_list references these lists, so it is always up to date
    epoch_avg_losses = []
    epoch_avg_loss_classifier = []
    epoch_avg_loss_box_reg = []
    epoch_avg_loss_objectness = []
    epoch_avg_loss_rpn_box_reg = []
    epoch_avg_loss_keypoints = []
    epoch_precision = []
    epoch_recall = []
    epoch_f1_score = []
    epoch_test_loss = []
    metrics_list = [epoch_avg_losses, epoch_avg_loss_classifier, epoch_avg_loss_box_reg,
                    epoch_avg_loss_objectness, epoch_avg_loss_rpn_box_reg, epoch_avg_loss_keypoints,
                    epoch_precision, epoch_recall, epoch_f1_score, epoch_test_loss]

    tracked_losses = ('loss_classifier', 'loss_box_reg', 'loss_objectness', 'loss_rpn_box_reg', 'loss_keypoint')
    classifier_loss_weight = 3

    start_tot = time.time()
    best_metrics = -1000
    best_epoch = 0
    best_model_state = None
    same = 0
    learning_rate = optimizer.param_groups[0]['lr']
    bad_test_loss = 0
    previous_test_loss = 1000
    name_model = None

    def save_best(reload_best):
        # Saves the best weights seen so far (if any) and writes the metric histories
        if best_model_state is not None:
            os.makedirs('./models', exist_ok=True)
            torch.save(best_model_state, './models/' + name_model + '.pth')
            if reload_best:
                model.load_state_dict(best_model_state)
        write_results(name_model, metrics_list, start_epoch)

    print(f"Let's go training {model_type} model with {num_epochs} epochs!")
    print(f"Learning rate: {learning_rate}, Batch size: {batch_size}, Crop prob: {crop_prob}, Flip prob: {h_flip_prob}, Rotate prob: {rotate_proba}, Blur prob: {blur_prob}")

    for epoch in range(num_epochs):
        current_epoch = epoch + 1 + start_epoch

        scheduled_decay = epoch > 0 and change_learning_rate and epoch % change_learning_rate == 0
        if scheduled_decay or bad_test_loss > 1:
            learning_rate = 0.7 * learning_rate
            optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=learning_rate, eps=1e-08, betas=(0.9, 0.999))
            print(f'Learning rate changed to {learning_rate:.4} and the best epoch for now is {best_epoch}')
            bad_test_loss = 0

        if epoch > 0 and epoch == start_key:
            print("Now it's training Keypoints also")
            # Without a loss config every loss is already summed, so there is nothing to enable
            if loss_config is not None:
                loss_config['loss_keypoint'] = True
            for name, param in model.named_parameters():
                if 'keypoint' in name:
                    param.requires_grad = True

        model.train()
        start = time.time()
        total_loss = 0

        # Individual losses of every batch of this epoch
        batch_losses = {name: [] for name in tracked_losses}

        # Create a tqdm progress bar
        progress_bar = tqdm(data_loader, desc=f'Epoch {current_epoch}')

        for images, targets_im in progress_bar:
            images = [image.to(device) for image in images]
            targets = [{k: v.clone().detach().to(device) for k, v in t.items()} for t in targets_im]

            optimizer.zero_grad()

            loss_dict = model(images, targets)

            if loss_config is not None:
                losses = 0
                for key, loss in loss_dict.items():
                    if loss_config.get(key, False):
                        # Weighted out of place, so the value logged below stays the raw loss
                        if key == 'loss_classifier':
                            loss = classifier_loss_weight * loss
                        losses = losses + loss
                if not torch.is_tensor(losses):
                    raise ValueError('loss_config does not enable any loss returned by the model')
            else:
                losses = sum(loss_dict.values())

            # Collect individual losses; a loss the model does not return counts as 0
            for name in tracked_losses:
                value = loss_dict.get(name)
                batch_losses[name].append(value.item() if value is not None else 0)

            losses.backward()
            optimizer.step()

            total_loss += losses.item()

            # Update the description with the current loss
            progress_bar.set_description(f'Epoch {current_epoch}, Loss: {losses.item():.4f}')

        # Calculate average loss
        avg_loss = total_loss / len(data_loader)

        epoch_avg_losses.append(avg_loss)
        epoch_avg_loss_classifier.append(np.mean(batch_losses['loss_classifier']))
        epoch_avg_loss_box_reg.append(np.mean(batch_losses['loss_box_reg']))
        epoch_avg_loss_objectness.append(np.mean(batch_losses['loss_objectness']))
        epoch_avg_loss_rpn_box_reg.append(np.mean(batch_losses['loss_rpn_box_reg']))
        epoch_avg_loss_keypoints.append(np.mean(batch_losses['loss_keypoint']))

        # Evaluate the model on the test set
        if eval_metric != 'loss':
            avg_test_loss = 0
            labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy = main_evaluation(model, subset_test_loader, score_threshold=0.5, iou_threshold=0.5, distance_threshold=10, key_correction=False, model_type=model_type)
            print(f"Epoch {current_epoch}, Average Loss: {avg_loss:.4f}, Labels_precision: {labels_precision:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1_score:.4f} ", end=", ")
        if eval_metric == 'all':
            avg_test_loss = evaluate_loss(model, subset_test_loader, device, loss_config)
            print(f"Epoch {current_epoch}, Average Test Loss: {avg_test_loss:.4f}", end=", ")
        if eval_metric == 'loss':
            labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy = 0, 0, 0, 0, 0, 0
            avg_test_loss = evaluate_loss(model, subset_test_loader, device, loss_config)
            print(f"Epoch {current_epoch}, Average Training Loss: {avg_loss:.4f}, Average Test Loss: {avg_test_loss:.4f}", end=", ")

        print(f"Time: {time.time() - start:.2f} [s]")

        if epoch > 0 and start_key and epoch % start_key == 0:
            print(f"Keypoints Accuracy: {key_accuracy:.4f}", end=", ")

        if eval_metric == 'f1_score':
            metric_used = f1_score
        elif eval_metric == 'precision':
            metric_used = precision
        elif eval_metric == 'recall':
            metric_used = recall
        else:
            # Higher is better for every metric, hence the negated loss
            metric_used = -avg_test_loss

        # Keep a copy of the weights when this epoch has the best metric so far
        if metric_used > best_metrics:
            best_metrics = metric_used
            best_epoch = current_epoch
            best_model_state = copy.deepcopy(model.state_dict())

        if epoch > 0 and f1_score > early_stop_f1_score:
            same += 1

        epoch_precision.append(precision)
        epoch_recall.append(recall)
        epoch_f1_score.append(f1_score)
        epoch_test_loss.append(avg_test_loss)

        name_model = f"model_{type(optimizer).__name__}_{current_epoch}ep_{batch_size}batch_trainval_blur0{int(blur_prob*10)}_crop0{int(crop_prob*10)}_flip0{int(h_flip_prob*10)}_rotate0{int(rotate_proba*10)}_{information_training}"

        # Early stop: the F1 score went above the requested threshold
        if same >= 1:
            save_best(reload_best=False)
            break

        # Periodic checkpoint; training continues from the best weights seen so far
        if current_epoch % 5 == 0:
            save_best(reload_best=True)

        if avg_test_loss > previous_test_loss:
            bad_test_loss += 1
        previous_test_loss = avg_test_loss

    print(f"\n Total time: {(time.time() - start_tot)/60} minutes, Best Epoch is {best_epoch} with an f1_score of {best_metrics:.4f}")
    if best_model_state is not None:
        save_best(reload_best=True)
        print(f"Name of the best model: {name_model}")

    return model, metrics_list
utils.py ADDED
@@ -0,0 +1,936 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torchvision.models.detection import keypointrcnn_resnet50_fpn
2
+ from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
3
+ from torchvision.models.detection.keypoint_rcnn import KeypointRCNNPredictor
4
+ from torchvision.models.detection import KeypointRCNN_ResNet50_FPN_Weights
5
+ import random
6
+ import torch
7
+ from torch.utils.data import Dataset
8
+ import torchvision.transforms.functional as F
9
+ import numpy as np
10
+ from torch.utils.data.dataloader import default_collate
11
+ import cv2
12
+ import matplotlib.pyplot as plt
13
+ from torch.utils.data import DataLoader, Subset, ConcatDataset
14
+ from tqdm import tqdm
15
+ from torch.optim import SGD
16
+ import time
17
+ from torch.optim import AdamW
18
+ import copy
19
+ from torchvision import transforms
20
+
21
+
22
# Class names of the object detection model; the position in the tuple is the class id.
_OBJECT_CLASS_NAMES = (
    'background',
    'task',
    'exclusiveGateway',
    'event',
    'parallelGateway',
    'messageEvent',
    'pool',
    'lane',
    'dataObject',
    'dataStore',
    'subProcess',
    'eventBasedGateway',
    'timerEvent',
)

# Class names of the arrow detection model, without the background class.
_ARROW_CLASS_NAMES = (
    'sequenceFlow',
    'dataAssociation',
    'messageFlow',
)

# id -> name for the object model (0 is the background).
object_dict = dict(enumerate(_OBJECT_CLASS_NAMES))

# id -> name for the arrow model (0 is the background).
arrow_dict = dict(enumerate(('background',) + _ARROW_CLASS_NAMES))

# id -> name for every class: the object classes first, followed by the arrow classes.
class_dict = dict(enumerate(_OBJECT_CLASS_NAMES + _ARROW_CLASS_NAMES))
63
+
64
def rescale_boxes(scale, boxes):
    """
    Multiplies the four coordinates of every box by a common factor, in place.

    Parameters:
    - scale (float): Factor applied to each coordinate.
    - boxes: Indexable collection of boxes, each exposing its coordinates at indices 0 to 3.

    Returns:
    - The same `boxes` object, whose entries have been replaced by the scaled coordinates.
    """
    for index, box in enumerate(boxes):
        left, top, right, bottom = box[0], box[1], box[2], box[3]
        boxes[index] = [left * scale, top * scale, right * scale, bottom * scale]
    return boxes
71
+
72
def iou(box1, box2):
    """
    Computes the intersection over union of two boxes given as [x1, y1, x2, y2].

    Parameters:
    - box1: First box.
    - box2: Second box.

    Returns:
    - Intersection area divided by union area, or 0.0 when the union area is zero
      (both boxes are degenerate), instead of dividing by zero.
    """
    # Compute the intersection of the two bounding boxes
    inter_box = [max(box1[0], box2[0]), max(box1[1], box2[1]), min(box1[2], box2[2]), min(box1[3], box2[3])]
    inter_area = max(0, inter_box[2] - inter_box[0]) * max(0, inter_box[3] - inter_box[1])

    # Compute the union of the two bounding boxes
    box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
    box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union_area = box1_area + box2_area - inter_area

    # Same convention as proportion_inside: an empty area yields 0 rather than an error
    if union_area == 0:
        return 0.0

    return inter_area / union_area
83
+
84
def proportion_inside(box1, box2):
    """
    Returns the fraction of box1's area that is covered by box2.

    Parameters:
    - box1: Box whose coverage is measured, as [x1, y1, x2, y2].
    - box2: Covering box, as [x1, y1, x2, y2].

    Returns:
    - 0 when box1 has no area, otherwise the covered fraction capped at 1.0.
    """
    # Overlap extents along each axis, floored at zero when the boxes do not meet
    overlap_width = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0]))
    overlap_height = max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
    overlap_area = overlap_width * overlap_height

    own_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
    if own_area == 0:
        return 0

    # Never report more than 100%
    return min(overlap_area / own_area, 1.0)
99
+
100
def resize_boxes(boxes, original_size, target_size):
    """
    Scales bounding boxes, in place, from one image size to another.

    Parameters:
    - boxes: 2-D array of boxes with columns [x1, y1, x2, y2]; it is modified in place.
    - original_size (tuple): Current image size as (width, height).
    - target_size (tuple): Image size to scale to, as (width, height).

    Returns:
    - The same `boxes` object, with x columns scaled by the width ratio and
      y columns scaled by the height ratio.
    """
    source_width, source_height = original_size
    dest_width, dest_height = target_size

    scale_x = dest_width / source_width
    scale_y = dest_height / source_height

    # Columns alternate x, y, x, y
    for column, factor in enumerate((scale_x, scale_y, scale_x, scale_y)):
        boxes[:, column] *= factor

    return boxes
126
+
127
def resize_keypoints(keypoints: np.ndarray, original_size: tuple, target_size: tuple) -> np.ndarray:
    """
    Scales keypoint coordinates, in place, from one image size to another.

    Parameters:
    - keypoints (np.ndarray): 2-D array whose first two columns are the x and y coordinates;
      any further column is left untouched. The array is modified in place.
    - original_size (tuple): Current image size as (width, height).
    - target_size (tuple): Image size to scale to, as (width, height).

    Returns:
    - np.ndarray: The same `keypoints` object, with x scaled by the width ratio and
      y scaled by the height ratio.
    """
    source_width, source_height = original_size
    dest_width, dest_height = target_size

    scale_x = dest_width / source_width
    scale_y = dest_height / source_height

    # Column 0 holds x, column 1 holds y
    for column, factor in enumerate((scale_x, scale_y)):
        keypoints[:, column] *= factor

    return keypoints
157
+
158
+
159
+
160
class RandomCrop:
    """
    Crops a random window of the image that fully contains enough annotated objects.

    The window has the aspect ratio of `new_size` and a width of `crop_fraction` times
    the image width. The image and target are returned unchanged when no suitable
    window is found.
    """

    # Per-object entries of the target that must be filtered together with the boxes
    _PER_OBJECT_KEYS = ('labels', 'area', 'iscrowd')

    def __init__(self, new_size=(1333,800),crop_fraction=0.5, min_objects=4):
        """
        Parameters:
        - new_size (tuple): (width, height) giving the aspect ratio of the crop window.
        - crop_fraction (float): Width of the crop window as a fraction of the image width.
        - min_objects (int): Minimum number of boxes that must lie fully inside the window.
        """
        self.crop_fraction = crop_fraction
        self.min_objects = min_objects
        self.new_size = new_size

    def __call__(self, image, target):
        """
        Parameters:
        - image (PIL Image): The image to crop.
        - target (dict): Target dictionary with 'boxes' and optionally 'labels' and 'keypoints'.

        Returns:
        - PIL Image, dict: The cropped image and its target restricted to the objects fully
          inside the crop, or the inputs unchanged when the crop is not possible.
        """
        aspect_w, aspect_h = self.new_size
        w, h = image.size
        new_w = int(w * self.crop_fraction)
        new_h = int(new_w * aspect_h / aspect_w)

        # Shrink the window by 5% of the image width per attempt (at most 4 times) while it
        # is too tall. The original reused the range variable, so its second attempt
        # computed a negative crop size.
        for attempt in range(1, 5):
            if new_h < h:
                break
            new_w = int(w * (self.crop_fraction - 0.05 * attempt))
            new_h = int(new_w * aspect_h / aspect_w)

        if new_h >= h or new_w <= 0 or new_h <= 0:
            return image, target

        boxes = target["boxes"]

        # Attempt to find a suitable crop region (max 100 attempts)
        for _ in range(100):
            top = random.randint(0, h - new_h)
            left = random.randint(0, w - new_w)
            right, bottom = left + new_w, top + new_h

            # Indices of the objects fully contained in this region
            kept = [
                i for i, box in enumerate(boxes)
                if box[0] >= left and box[1] >= top and box[2] <= right and box[3] <= bottom
            ]
            if len(kept) >= self.min_objects:
                break
        else:
            # No valid region found
            return image, target

        # Perform the actual crop
        image = F.crop(image, top, left, new_h, new_w)
        keep = torch.tensor(kept, dtype=torch.long)
        num_objects = len(boxes)

        # Shift the kept boxes into the coordinate system of the crop
        if kept:
            target["boxes"] = boxes[keep] - torch.tensor([left, top, left, top])
        else:
            target["boxes"] = torch.zeros((0, 4))

        if 'keypoints' in target:
            if kept:
                # The visibility flag (third value) is not shifted
                target["keypoints"] = target["keypoints"][keep] - torch.tensor([left, top, 0])
            else:
                target["keypoints"] = torch.zeros((0, 2, 3))

        # Keep the other per-object entries aligned with the remaining boxes
        for key in self._PER_OBJECT_KEYS:
            value = target.get(key)
            if torch.is_tensor(value) and value.dim() > 0 and value.shape[0] == num_objects:
                target[key] = value[keep]

        return image, target
223
+
224
+
225
class RandomFlip:
    def __init__(self, h_flip_prob=0.5, v_flip_prob=0.5):
        """
        Initializes the RandomFlip with probabilities for flipping.

        Parameters:
        - h_flip_prob (float): Probability of applying a horizontal flip to the image.
        - v_flip_prob (float): Probability of applying a vertical flip to the image.
        """
        self.h_flip_prob = h_flip_prob
        self.v_flip_prob = v_flip_prob

    def __call__(self, image, target):
        """
        Applies random horizontal and/or vertical flip to the image and updates target data accordingly.

        Parameters:
        - image (PIL Image): The image to be flipped.
        - target (dict): The target dictionary containing 'boxes' and optionally 'keypoints'.

        Returns:
        - PIL Image, dict: The flipped image and its updated target dictionary.
        """
        if random.random() < self.h_flip_prob:
            image = F.hflip(image)
            w, _ = image.size
            self._mirror_target(target, 0, w)

        if random.random() < self.v_flip_prob:
            image = F.vflip(image)
            _, h = image.size
            self._mirror_target(target, 1, h)

        return image, target

    @staticmethod
    def _mirror_target(target, axis, extent):
        """
        Mirrors the boxes and keypoints of a target along one axis.

        Parameters:
        - target (dict): The target dictionary; 'boxes' is updated in place, 'keypoints' is replaced.
        - axis (int): 0 to mirror the x coordinates, 1 to mirror the y coordinates.
        - extent (int): Image size along that axis (width for x, height for y).

        Targets without any object are left untouched, instead of failing on an empty stack.
        """
        boxes = target['boxes']
        if len(boxes) > 0:
            # The mirrored maximum becomes the new minimum and vice versa
            new_min = extent - boxes[:, axis + 2]
            new_max = extent - boxes[:, axis]
            boxes[:, axis] = new_min
            boxes[:, axis + 2] = new_max

        if 'keypoints' in target and len(target['keypoints']) > 0:
            # Only the coordinate on the flipped axis changes; visibility is preserved
            keypoints = target['keypoints'].clone()
            keypoints[..., axis] = extent - keypoints[..., axis]
            target['keypoints'] = keypoints
289
+
290
+
291
class RandomRotate:
    def __init__(self, max_rotate_deg=20, rotate_proba=0.3):
        """
        Initializes the RandomRotate with a maximum rotation angle and probability of rotating.

        Parameters:
        - max_rotate_deg (int): Maximum degree to rotate the image.
        - rotate_proba (float): Probability of applying rotation to the image.
        """
        self.max_rotate_deg = max_rotate_deg
        self.rotate_proba = rotate_proba

    def __call__(self, image, target):
        """
        Randomly rotates the image and updates the target data accordingly.

        Parameters:
        - image (PIL Image): The image to be rotated.
        - target (dict): The target dictionary containing 'boxes' and optionally 'keypoints'.

        Returns:
        - PIL Image, dict: The rotated image and its updated target dictionary.
        """
        if random.random() < self.rotate_proba:
            angle = random.uniform(-self.max_rotate_deg, self.max_rotate_deg)
            image = F.rotate(image, angle, expand=False, fill=200)

            # The targets are rotated around the image center
            w, h = image.size
            cx, cy = w / 2, h / 2

            # Rotate bounding boxes (skipped when there is none: stacking nothing would fail)
            boxes = target["boxes"]
            if len(boxes) > 0:
                target["boxes"] = torch.stack([self.rotate_box(box, angle, cx, cy) for box in boxes])

            # Rotate keypoints
            if 'keypoints' in target and len(target["keypoints"]) > 0:
                target["keypoints"] = torch.stack(
                    [self.rotate_keypoints(keypoints, angle, cx, cy) for keypoints in target["keypoints"]]
                )

        return image, target

    def rotate_box(self, box, angle, cx, cy):
        """
        Rotates a bounding box by a given angle around the point (cx, cy).

        Returns:
        - Tensor: The axis-aligned box [x_min, y_min, x_max, y_max] enclosing the four rotated corners.
        """
        x1, y1, x2, y2 = box
        corners = torch.tensor([
            [x1, y1],
            [x2, y1],
            [x2, y2],
            [x1, y2]
        ])
        # Homogeneous coordinates, so the 2x3 affine matrix applies in a single product
        corners = torch.cat((corners, torch.ones(corners.shape[0], 1)), dim=1)
        M = cv2.getRotationMatrix2D((cx, cy), angle, 1)
        corners = torch.matmul(torch.tensor(M, dtype=torch.float32), corners.T).T
        x_ = corners[:, 0]
        y_ = corners[:, 1]
        x_min, x_max = torch.min(x_), torch.max(x_)
        y_min, y_max = torch.min(y_), torch.max(y_)
        return torch.tensor([x_min, y_min, x_max, y_max], dtype=torch.float32)

    def rotate_keypoints(self, keypoints, angle, cx, cy):
        """
        Rotates keypoints by a given angle around the point (cx, cy).

        Parameters:
        - keypoints: Iterable of (x, y, visibility) keypoints; the visibility is kept as is.

        Returns:
        - Tensor: The rotated keypoints, one (x, y, visibility) row per input keypoint.
        """
        # The matrix is the same for every keypoint, so it is built once
        rotation = torch.tensor(cv2.getRotationMatrix2D((cx, cy), angle, 1), dtype=torch.float32)
        new_keypoints = []
        for kp in keypoints:
            x, y, v = kp
            new_point = torch.matmul(rotation, torch.tensor([x, y, 1]))
            new_keypoints.append(torch.tensor([new_point[0], new_point[1], v], dtype=torch.float32))
        return torch.stack(new_keypoints)
370
+
371
def rotate_90_box(box, angle, w, h):
    """
    Maps a bounding box into the coordinates of an image rotated by a quarter turn.

    Parameters:
    - box: The box as [x1, y1, x2, y2] in the image before rotation.
    - angle (int): 90, or 270 / -90 for the opposite direction.
    - w (int): Width of the image after rotation.
    - h (int): Height of the image after rotation.

    Returns:
    - Tensor: The box as [x1, y1, x2, y2] in the rotated image.

    Raises:
    - ValueError: If the angle is not supported (the original printed a message and returned None).
    """
    x1, y1, x2, y2 = box
    if angle == 90:
        return torch.tensor([y1, h - x2, y2, h - x1])
    if angle == 270 or angle == -90:
        return torch.tensor([w - y2, x1, w - y1, x2])
    raise ValueError(f"angle not supported: {angle}")
379
+
380
def rotate_90_keypoints(kp, angle, w, h):
    """
    Maps the two keypoints of an object into the coordinates of an image rotated by a quarter turn.

    Parameters:
    - kp: Two keypoints, each as (x, y, visibility), in the image before rotation.
    - angle (int): 90, or 270 / -90 for the opposite direction.
    - w (int): Width of the image after rotation.
    - h (int): Height of the image after rotation.

    Returns:
    - Tensor: A float32 tensor of shape (2, 3) with the rotated keypoints; visibility is unchanged.

    Raises:
    - ValueError: If the angle is not supported (the original failed on an unbound variable).
    """
    # Extract coordinates and visibility from each keypoint
    x1, y1, v1 = kp[0][0], kp[0][1], kp[0][2]
    x2, y2, v2 = kp[1][0], kp[1][1], kp[1][2]

    # A quarter turn swaps the axes and mirrors one of them
    if angle == 90:
        new = [[y1, h - x1, v1], [y2, h - x2, v2]]
    elif angle == 270 or angle == -90:
        new = [[w - y1, x1, v1], [w - y2, x2, v2]]
    else:
        raise ValueError(f"angle not supported: {angle}")

    return torch.tensor(new, dtype=torch.float32)
391
+
392
+
393
def rotate_vertical(image, target):
    """
    Rotates the image by a randomly chosen quarter turn and updates the target accordingly.

    Parameters:
    - image (PIL Image): The image to rotate; the canvas is expanded to fit the rotated image.
    - target (dict): The target dictionary containing 'boxes' and optionally 'keypoints'.

    Returns:
    - PIL Image, dict: The rotated image and its updated target dictionary.
    """
    quarter_turn = random.choice([-90,90])
    image = F.rotate(image, quarter_turn, expand=True, fill=200)

    # The helpers expect the size of the image after rotation
    rotated_boxes = [
        rotate_90_box(box, quarter_turn, image.size[0], image.size[1])
        for box in target["boxes"]
    ]
    target["boxes"] = torch.stack(rotated_boxes)

    if 'keypoints' in target:
        rotated_keypoints = [
            rotate_90_keypoints(kp, quarter_turn, image.size[0], image.size[1])
            for kp in target['keypoints']
        ]
        target['keypoints'] = torch.stack(rotated_keypoints)

    return image, target
410
+
411
class BPMN_Dataset(Dataset):
    """
    Dataset of annotated BPMN diagrams for the object or the arrow detection model.

    Each annotation is expected to expose `img` (PIL image), `boxes_ltrb`, `categories`
    and `annotations` (items with `category` and `keypoints`), as read in __getitem__.
    """

    def __init__(self, annotations, transform=None, crop_transform=None, crop_prob=0.3, rotate_90_proba=0.2, flip_transform=None, rotate_transform=None, new_size=(1333,800),keep_ratio=False,resize=True, model_type='object', rotate_vertical=False):
        """
        Parameters:
        - annotations: Indexable collection of image annotations.
        - transform (callable | None): Applied to the final image; None returns the image as is.
        - crop_transform (callable | None): Called as crop_transform(image, target) with probability crop_prob.
        - crop_prob (float): Probability of applying crop_transform.
        - rotate_90_proba (float): Probability of applying a quarter-turn rotation.
        - flip_transform (callable | None): Called as flip_transform(image, target).
        - rotate_transform (callable | None): Called as rotate_transform(image, target).
        - new_size (tuple): Output image size as (width, height).
        - keep_ratio (bool): If True, keeps the aspect ratio and pads the image up to new_size.
        - resize (bool): If False, the image and target keep their size.
        - model_type (str): 'object' or 'arrow'; selects the class dictionary and the target content.
        - rotate_vertical (bool): Enables the random quarter-turn rotation.

        Raises:
        - ValueError: If model_type is neither 'object' nor 'arrow'.
        """
        self.annotations = annotations
        print(f"Loaded {len(self.annotations)} annotations.")
        self.transform = transform
        self.crop_transform = crop_transform
        self.crop_prob = crop_prob
        self.flip_transform = flip_transform
        self.rotate_transform = rotate_transform
        self.resize = resize
        self.rotate_vertical = rotate_vertical
        self.new_size = new_size
        self.keep_ratio = keep_ratio
        self.model_type = model_type
        if model_type == 'object':
            self.dict = object_dict
        elif model_type == 'arrow':
            self.dict = arrow_dict
        else:
            # The original left self.dict undefined and only failed at the first item access
            raise ValueError(f"Unknown model_type '{model_type}', expected 'object' or 'arrow'")
        self.rotate_90_proba = rotate_90_proba

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        """
        Returns the (image, target) pair of one annotation after augmentation and resizing.

        The target holds 'boxes' and 'labels', plus 'keypoints' for the arrow model.
        """
        annotation = self.annotations[idx]
        image = annotation.img.convert("RGB")
        # reshape keeps an (N, 4) layout even when the image has no box at all
        boxes = torch.tensor(np.array(annotation.boxes_ltrb), dtype=torch.float32).reshape(-1, 4)
        categories = list(annotation.categories)

        # Class name -> position in the class dictionary (first occurrence wins)
        class_ids = {}
        for position, class_name in enumerate(self.dict.values()):
            class_ids.setdefault(class_name, position)

        # Only keep the labels, boxes and keypoints whose category is in the class dictionary
        kept_indices = [i for i, category in enumerate(categories) if category in class_ids]
        boxes = boxes[kept_indices]
        labels_id = torch.tensor([class_ids[categories[i]] for i in kept_indices], dtype=torch.int64)

        target = {
            "boxes": boxes,
            "labels": labels_id,
        }

        if self.model_type != 'object':
            # Keypoints tensor: one row per kept object, zeros (not visible) by default
            max_keypoints = 2
            keypoints = torch.zeros((len(labels_id), max_keypoints, 3), dtype=torch.float32)
            kept_lookup = set(kept_indices)

            row = 0
            for i, ann in enumerate(annotation.annotations):
                # Only the kept annotations own a row in the keypoints tensor
                if i not in kept_lookup:
                    continue
                if ann.category in ["sequenceFlow", "messageFlow", "dataAssociation"]:
                    # Fill the keypoints of this annotation and mark them as visible (1)
                    kp = np.array(ann.keypoints, dtype=np.float32).reshape(-1, 3)
                    kp = kp[:, :2]
                    visible = np.ones((kp.shape[0], 1), dtype=np.float32)
                    kp = np.hstack([kp, visible])
                    keypoints[row, :kp.shape[0], :] = torch.tensor(kp, dtype=torch.float32)
                row += 1

            target["keypoints"] = keypoints

        # Randomly apply flip transform
        if self.flip_transform:
            image, target = self.flip_transform(image, target)

        # Randomly apply rotate transform
        if self.rotate_transform:
            image, target = self.rotate_transform(image, target)

        # Randomly apply the custom cropping transform
        if self.crop_transform and random.random() < self.crop_prob:
            image, target = self.crop_transform(image, target)

        # Randomly apply a quarter-turn rotation
        if self.rotate_vertical and random.random() < self.rotate_90_proba:
            image, target = rotate_vertical(image, target)

        if self.resize:
            original_size = (image.size[0], image.size[1])
            if self.keep_ratio:
                # Largest size that fits in new_size while keeping the aspect ratio
                scale = min(self.new_size[0] / original_size[0], self.new_size[1] / original_size[1])
                resized_size = (int(original_size[0] * scale), int(original_size[1] * scale))
            else:
                resized_size = self.new_size

            target['boxes'] = resize_boxes(target['boxes'], original_size, resized_size)
            if 'area' in target:
                target['area'] = (target['boxes'][:, 3] - target['boxes'][:, 1]) * (target['boxes'][:, 2] - target['boxes'][:, 0])
            if 'keypoints' in target:
                for i in range(len(target['keypoints'])):
                    target['keypoints'][i] = resize_keypoints(target['keypoints'][i], original_size, resized_size)

            # F.resize expects (height, width)
            image = F.resize(image, (resized_size[1], resized_size[0]))

            if self.keep_ratio:
                # Pad on the right and at the bottom to reach exactly the desired size
                padding = [0, 0, self.new_size[0] - resized_size[0], self.new_size[1] - resized_size[1]]
                image = F.pad(image, padding, fill=200, padding_mode='constant')

        # transform defaults to None, in which case the image is returned as is
        if self.transform is not None:
            image = self.transform(image)

        return image, target
529
+
530
def collate_fn(batch):
    """
    Collate (image, target) samples into a batch for a DataLoader.

    The images are merged with PyTorch's default collation, while the targets
    are left as one entry per image, since each image may contain a different
    number of annotated objects.

    Parameters:
    - batch (list): (image, target) tuples, one per sample.

    Returns:
    - Tuple containing:
        - Batched images, as produced by default_collate.
        - Tuple of target dicts, in the same order as the images.
    """
    # Transpose the list of pairs into one sequence of images and one of targets.
    image_seq, target_seq = zip(*batch)

    # Only the images are stacked; targets stay per-sample.
    return default_collate(image_seq), target_seq
552
+
553
+
554
+
555
def create_loader(new_size, transformation, annotations1, annotations2=None,
                  batch_size=4, crop_prob=0.2, crop_fraction=0.7, min_objects=3,
                  h_flip_prob=0.3, v_flip_prob=0.3, max_rotate_deg=20, rotate_90_proba=0.2, rotate_proba=0.3,
                  seed=42, resize=True, rotate_vertical=False, keep_ratio=False, model_type='object'):
    """
    Creates a DataLoader for BPMN datasets with optional transformations and concatenation of two datasets.

    Both datasets are built with exactly the same augmentation configuration.

    Parameters:
    - new_size (tuple): Target size forwarded to RandomCrop and BPMN_Dataset.
    - transformation (callable): Transformation function to apply to each image (e.g., normalization).
    - annotations1 (list): Primary list of annotations.
    - annotations2 (list, optional): Secondary list of annotations to concatenate with the first.
    - batch_size (int): Number of images per batch.
    - crop_prob (float): Probability of applying the crop transformation.
    - crop_fraction (float): Fraction of the original width to use when cropping.
    - min_objects (int): Minimum number of objects required to be within the crop.
    - h_flip_prob (float): Probability of applying horizontal flip.
    - v_flip_prob (float): Probability of applying vertical flip.
    - max_rotate_deg (float): Forwarded to RandomRotate as its first argument.
    - rotate_90_proba (float): Forwarded to BPMN_Dataset as `rotate_90_proba`.
    - rotate_proba (float): Forwarded to RandomRotate as its second argument.
    - seed (int): Seed for random number generators for reproducibility.
    - resize (bool): Flag indicating whether to resize images after transformations.
    - rotate_vertical (bool): Forwarded to BPMN_Dataset as `rotate_vertical`.
    - keep_ratio (bool): Forwarded to BPMN_Dataset as `keep_ratio`.
    - model_type (str): Forwarded to BPMN_Dataset as `model_type` ('object' by default).

    Returns:
    - DataLoader: Configured data loader for the dataset.
    """

    # Initialize custom transformations for cropping, flipping and rotating
    custom_crop_transform = RandomCrop(new_size, crop_fraction, min_objects)
    custom_flip_transform = RandomFlip(h_flip_prob, v_flip_prob)
    custom_rotate_transform = RandomRotate(max_rotate_deg, rotate_proba)

    def _build_dataset(annotations):
        """Build one BPMN_Dataset with the shared augmentation settings."""
        return BPMN_Dataset(
            annotations=annotations,
            transform=transformation,
            crop_transform=custom_crop_transform,
            crop_prob=crop_prob,
            rotate_90_proba=rotate_90_proba,
            flip_transform=custom_flip_transform,
            # Previously omitted for the secondary dataset, which therefore
            # was never rotated; both datasets now share the rotate transform.
            rotate_transform=custom_rotate_transform,
            rotate_vertical=rotate_vertical,
            new_size=new_size,
            keep_ratio=keep_ratio,
            model_type=model_type,
            resize=resize
        )

    # Create the primary dataset
    dataset = _build_dataset(annotations1)

    # Optionally concatenate a second dataset
    if annotations2:
        dataset = ConcatDataset([dataset, _build_dataset(annotations2)])

    # Set the seed for reproducibility in random operations within transformations and data loading
    random.seed(seed)
    torch.manual_seed(seed)

    # Create the DataLoader with the dataset
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

    return data_loader
625
+
626
+
627
+
628
def write_results(name_model, metrics_list, start_epoch):
    """
    Write per-epoch metrics to './results/<name_model>.txt' as comma-separated rows.

    Each row has the form "<epoch>,<metric_0>,...,<metric_k> \\n" (the space
    before the newline is kept so the output matches previously written files).
    All metric series present in `metrics_list` are written, so the function
    works for any number of metrics rather than exactly ten.

    Parameters:
    - name_model (str): Base name of the output file (without extension).
    - metrics_list (list): List of metric series; series j holds one value per epoch.
      The number of rows is the length of the first series.
    - start_epoch (int): Offset added to the 1-based row index to get the epoch number.
    """
    import os

    results_dir = './results/'
    # Make sure the target directory exists instead of failing on open().
    os.makedirs(results_dir, exist_ok=True)

    n_rows = len(metrics_list[0]) if metrics_list else 0
    with open(results_dir + name_model + '.txt', 'w') as f:
        for i in range(n_rows):
            # f-string formatting is used (not str()) to keep the exact value
            # formatting of the previous implementation.
            values = ",".join(f"{series[i]}" for series in metrics_list)
            f.write(f"{i+1+start_epoch},{values} \n")
632
+
633
+
634
def find_other_keypoint(idx, keypoints, boxes):
    """
    Mirror the mean of an element's two keypoints through the centre of its box.

    Parameters:
    - idx (int): Index of the element in `keypoints` and `boxes`.
    - keypoints: Indexable collection whose entry `idx` unpacks into exactly two keypoints.
    - boxes: Indexable collection whose entry `idx` unpacks into (x1, y1, x2, y2).

    Returns:
    - tuple: (x, y, mean_x, mean_y) where (mean_x, mean_y) is the floor-divided
      mean of the two keypoints and (x, y) is that point reflected about the
      floor-divided box centre.
    """
    current_box = boxes[idx]
    first_kp, second_kp = keypoints[idx]
    left, top, right, bottom = current_box

    centre_x = (left + right) // 2
    centre_y = (top + bottom) // 2
    mean_kp = (first_kp + second_kp) // 2

    # Reflect each coordinate to the opposite side of the box centre.
    offset_x = abs(centre_x - mean_kp[0])
    mirrored_x = centre_x + offset_x if mean_kp[0] < centre_x else centre_x - offset_x

    offset_y = abs(centre_y - mean_kp[1])
    mirrored_y = centre_y + offset_y if mean_kp[1] < centre_y else centre_y - offset_y

    return mirrored_x, mirrored_y, mean_kp[0], mean_kp[1]
650
+
651
+
652
def filter_overlap_boxes(boxes, scores, labels, keypoints, iou_threshold=0.5):
    """
    Greedy suppression of overlapping boxes: the highest-scoring box is kept and
    every remaining box whose IoU with it exceeds the threshold is discarded.

    Parameters:
    - boxes (np.ndarray): Array of bounding boxes with shape (N, 4), rows are [x_min, y_min, x_max, y_max].
    - scores (np.ndarray): Confidence score of each box.
    - labels (np.ndarray): Label of each box.
    - keypoints (np.ndarray): Keypoints associated with each box.
    - iou_threshold (float): Boxes with IoU strictly above this value are treated as duplicates.

    Returns:
    - tuple: Filtered boxes, scores, labels, and keypoints, ordered by decreasing score.
    """
    # Per-box area, reused for every union computation.
    box_areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

    # Candidate indices, best score first.
    remaining = scores.argsort()[::-1]
    selected = []

    while remaining.size > 0:
        best, rest = remaining[0], remaining[1:]
        selected.append(best)

        # Corners of the intersection between the best box and each other candidate.
        inter_x1 = np.maximum(boxes[best, 0], boxes[rest, 0])
        inter_y1 = np.maximum(boxes[best, 1], boxes[rest, 1])
        inter_x2 = np.minimum(boxes[best, 2], boxes[rest, 2])
        inter_y2 = np.minimum(boxes[best, 3], boxes[rest, 3])

        # Negative extents mean no overlap, hence the clamp at zero.
        inter_w = np.maximum(0.0, inter_x2 - inter_x1)
        inter_h = np.maximum(0.0, inter_y2 - inter_y1)
        overlap = inter_w * inter_h

        ratios = overlap / (box_areas[best] + box_areas[rest] - overlap)

        # Only candidates that do not overlap too much survive to the next round.
        survivors = np.where(ratios <= iou_threshold)[0]
        remaining = rest[survivors]

    return boxes[selected], scores[selected], labels[selected], keypoints[selected]
704
+
705
+
706
+
707
def draw_annotations(image,
                     target=None,
                     prediction=None,
                     full_prediction=None,
                     text_predictions=None,
                     model_dict=class_dict,
                     draw_keypoints=False,
                     draw_boxes=False,
                     draw_text=False,
                     draw_links=False,
                     draw_twins=False,
                     write_class=False,
                     write_score=False,
                     write_text=False,
                     write_idx=False,
                     score_threshold=0.4,
                     keypoints_correction=False,
                     only_print=None,
                     axis=False,
                     return_image=False,
                     new_size=(1333,800),
                     resize=False):
    """
    Draws annotations on images including bounding boxes, keypoints, links, and text,
    then displays the result with matplotlib.

    Parameters:
    - image: The image on which annotations will be drawn. When `prediction` is None it is
      converted with squeeze(0)/permute/cpu/numpy, i.e. it is handled as a tensor.
    - target (dict): Ground truth data containing boxes, labels, etc.
    - prediction (dict): Prediction data from a model.
    - full_prediction (dict): Additional detailed prediction data, potentially including relationships.
    - text_predictions (tuple): OCR text predictions containing bounding boxes and texts.
    - model_dict (dict): Mapping from class IDs to class names.
    - draw_keypoints (bool): Flag to draw keypoints.
    - draw_boxes (bool): Flag to draw bounding boxes.
    - draw_text (bool): Flag to draw text annotations.
    - draw_links (bool): Flag to draw links between annotations.
    - draw_twins (bool): Flag to draw twins keypoints.
    - write_class (bool): Flag to write class names near the annotations.
    - write_score (bool): Flag to write scores near the annotations.
    - write_text (bool): Flag to write OCR recognized text.
    - write_idx (bool): Flag to write the index of each box next to it.
    - score_threshold (float): Threshold for scores above which annotations will be drawn.
    - keypoints_correction (bool): For predictions, replace the two keypoints of an arrow
      with the result of find_other_keypoint when they are less than 5 apart.
    - only_print (str): Specific class name to filter annotations by.
    - axis (bool): Whether to keep the matplotlib axes visible.
    - return_image (bool): Whether to return the annotated image.
    - new_size (tuple): Size the annotations are expressed in when `resize` is True.
    - resize (bool): Whether to resize annotations to fit the image size.

    Returns:
    - The annotated image if `return_image` is True, otherwise None.
    """

    # Without a prediction the image is handled as a tensor and moved to a
    # channel-last numpy array.
    if prediction is None:
        image = image.squeeze(0).permute(1, 2, 0).cpu().numpy()

    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image_copy = image.copy()
    # Line widths, radii and font sizes are proportional to the largest image side.
    scale = max(image.shape[0], image.shape[1]) / 1000

    def is_arrow(label):
        """ Tell whether a label id is one of the three arrow classes. """
        class_names = list(model_dict.values())
        # A list (not a set) is used on purpose: labels may be tensors, which
        # compare by value with `==` but do not hash by value.
        return label in [class_names.index('sequenceFlow'),
                         class_names.index('messageFlow'),
                         class_names.index('dataAssociation')]

    # Function to draw bounding boxes and keypoints
    def draw(data, is_prediction=False):
        """ Helper function to draw annotations based on provided data. """

        for i in range(len(data['boxes'])):
            box = data['boxes'][i].tolist()
            x1, y1, x2, y2 = box
            if is_prediction:
                if resize:
                    x1, y1, x2, y2 = resize_boxes(np.array([box]), new_size, (image_copy.shape[1],image_copy.shape[0]))[0]
                score = data['scores'][i].item()
                if score < score_threshold:
                    continue
            if draw_boxes:
                if only_print is not None:
                    if data['labels'][i] != list(model_dict.values()).index(only_print):
                        continue
                cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 0), int(2*scale))
            if is_prediction and write_score:
                cv2.putText(image_copy, str(round(score, 2)), (int(x1), int(y1) + int(15*scale)), cv2.FONT_HERSHEY_SIMPLEX, scale/2, (100,100, 255), 2)

            if write_class and 'labels' in data:
                class_id = data['labels'][i].item()
                cv2.putText(image_copy, model_dict[class_id], (int(x1), int(y1) - int(2*scale)), cv2.FONT_HERSHEY_SIMPLEX, scale/2, (255, 100, 100), 2)

            if write_idx:
                cv2.putText(image_copy, str(i), (int(x1) + int(15*scale), int(y1) + int(15*scale)), cv2.FONT_HERSHEY_SIMPLEX, 2*scale, (0,0, 0), 2)

        # Draw keypoints if available
        if draw_keypoints and 'keypoints' in data:
            if is_prediction and keypoints_correction:
                for idx, (key1, key2) in enumerate(data['keypoints']):
                    if not is_arrow(data['labels'][idx]):
                        continue
                    # Calculate the Euclidean distance between the two keypoints
                    distance = np.linalg.norm(key1[:2] - key2[:2])

                    if distance < 5:
                        x_new, y_new, x, y = find_other_keypoint(idx, data['keypoints'], data['boxes'])
                        data['keypoints'][idx][0] = torch.tensor([x_new, y_new,1])
                        data['keypoints'][idx][1] = torch.tensor([x, y,1])
                        print("keypoint has been changed")
            for i in range(len(data['keypoints'])):
                kp = data['keypoints'][i]
                if is_prediction:
                    # For predictions only arrows above the score threshold get keypoints drawn.
                    if not is_arrow(data['labels'][i]):
                        continue
                    if data['scores'][i] < score_threshold:
                        continue
                for j in range(kp.shape[0]):
                    x, y, v = np.array(kp[j])
                    if resize:
                        x, y, v = resize_keypoints(np.array([kp[j]]), new_size, (image_copy.shape[1],image_copy.shape[0]))[0]
                    # The first keypoint and the following ones use different colours.
                    if j == 0:
                        cv2.circle(image_copy, (int(x), int(y)), int(5*scale), (0, 0, 255), -1)
                    else:
                        cv2.circle(image_copy, (int(x), int(y)), int(5*scale), (255, 0, 0), -1)

        # Draw text predictions if available
        if (draw_text or write_text) and text_predictions is not None:
            for i in range(len(text_predictions[0])):
                x1, y1, x2, y2 = text_predictions[0][i]
                text = text_predictions[1][i]
                if resize:
                    x1, y1, x2, y2 = resize_boxes(np.array([[float(x1), float(y1), float(x2), float(y2)]]), new_size, (image_copy.shape[1],image_copy.shape[0]))[0]
                if draw_text:
                    cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), int(2*scale))
                if write_text:
                    cv2.putText(image_copy, text, (int(x1 + int(2*scale)), int((y1+y2)/2) ), cv2.FONT_HERSHEY_SIMPLEX, scale/2, (0,0, 0), 2)

    def draw_with_links(full_prediction):
        '''Draws links between objects based on the full prediction data.'''
        # Check if keypoints detected are the same
        if draw_twins and full_prediction is not None:
            circle_color = (0, 255, 0)  # Green color for the circle
            circle_radius = int(10 * scale)  # Circle radius scaled by image scale

            for idx, (key1, key2) in enumerate(full_prediction['keypoints']):
                if not is_arrow(full_prediction['labels'][idx]):
                    continue
                # Calculate the Euclidean distance between the two keypoints
                distance = np.linalg.norm(key1[:2] - key2[:2])
                if distance < 10:
                    # find_other_keypoint takes the keypoints and the boxes
                    # separately, not the whole prediction dict.
                    x_new, y_new, x, y = find_other_keypoint(idx, full_prediction['keypoints'], full_prediction['boxes'])
                    cv2.circle(image_copy, (int(x), int(y)), circle_radius, circle_color, -1)
                    cv2.circle(image_copy, (int(x_new), int(y_new)), circle_radius, (0,0,0), -1)

        # Draw links between objects
        if draw_links and full_prediction is not None:
            for i, (start_idx, end_idx) in enumerate(full_prediction['links']):
                if start_idx is None or end_idx is None:
                    continue
                start_box = full_prediction['boxes'][start_idx]
                end_box = full_prediction['boxes'][end_idx]
                current_box = full_prediction['boxes'][i]
                # Calculate the center of each bounding box
                start_center = ((start_box[0] + start_box[2]) // 2, (start_box[1] + start_box[3]) // 2)
                end_center = ((end_box[0] + end_box[2]) // 2, (end_box[1] + end_box[3]) // 2)
                current_center = ((current_box[0] + current_box[2]) // 2, (current_box[1] + current_box[3]) // 2)
                # Draw a line from the start object to the current element, then on to the end object
                cv2.line(image_copy, (int(start_center[0]), int(start_center[1])), (int(current_center[0]), int(current_center[1])), (0, 0, 255), int(2*scale))
                cv2.line(image_copy, (int(current_center[0]), int(current_center[1])), (int(end_center[0]), int(end_center[1])), (255, 0, 0), int(2*scale))

    # Draw GT annotations
    if target is not None:
        draw(target, is_prediction=False)
    # Draw predictions
    if prediction is not None:
        draw(prediction, is_prediction=True)
    # Draw links with full predictions
    if full_prediction is not None:
        draw_with_links(full_prediction)

    # Display the image
    image_copy = cv2.cvtColor(image_copy, cv2.COLOR_BGR2RGB)
    plt.figure(figsize=(12, 12))
    plt.imshow(image_copy)
    if not axis:
        plt.axis('off')
    plt.show()

    if return_image:
        return image_copy
896
+
897
def find_closest_object(keypoint, boxes, labels):
    """
    Find the object whose box edge is closest to a keypoint.

    For every candidate box the distance is measured from the keypoint to the
    midpoint of each of its four edges (left, top, right, bottom). Boxes
    labelled sequenceFlow, messageFlow, dataAssociation or lane are ignored.

    Parameters:
    - keypoint (numpy.ndarray): The coordinates of the keypoint; only the first two values are used.
    - boxes (numpy.ndarray): The bounding boxes of the objects as (x1, y1, x2, y2).
    - labels: Label id of each box, compared against positions in `class_dict`.

    Returns:
    - tuple: (index of the closest object, closest edge midpoint), or (None, None)
      if there is no candidate object.
    """
    closest_object_idx = None
    # Initialised so that the "no candidate" case returns None instead of
    # failing on an unbound local.
    best_point = None

    if len(boxes) == 0:
        return closest_object_idx, best_point

    min_distance = float('inf')

    # Label ids that can never be the end point of a connection; computed once.
    class_names = list(class_dict.values())
    ignored_labels = [class_names.index('sequenceFlow'),
                      class_names.index('messageFlow'),
                      class_names.index('dataAssociation'),
                      #class_names.index('pool'),
                      class_names.index('lane')]

    # Iterate over each bounding box
    for i, box in enumerate(boxes):
        if labels[i] in ignored_labels:
            continue
        x1, y1, x2, y2 = box

        top = ((x1+x2)/2, y1)
        bottom = ((x1+x2)/2, y2)
        left = (x1, (y1+y2)/2)
        right = (x2, (y1+y2)/2)
        points = [left, top, right, bottom]

        # Calculate the distance between the keypoint and each edge midpoint of the bounding box
        for point in points:
            distance = np.linalg.norm(keypoint[:2] - point)
            # Update the closest object index if this object is closer
            if distance < min_distance:
                min_distance = distance
                closest_object_idx = i
                best_point = point

    return closest_object_idx, best_point
936
+