# sketch-to-BPMN / display.py
# Author: BenjiELCA -- commit 615e9f1 ("put online demo")
from utils import draw_annotations, create_loader, class_dict, resize_boxes, resize_keypoints, find_other_keypoint
import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt
from OCR import group_texts
def draw_stream(image,
                prediction=None,
                text_predictions=None,
                class_dict=class_dict,
                draw_keypoints=False,
                draw_boxes=False,
                draw_text=False,
                draw_links=False,
                draw_twins=False,
                draw_grouped_text=False,
                write_class=False,
                write_score=False,
                write_text=False,
                score_threshold=0.4,
                write_idx=False,
                keypoints_correction=False,
                new_size=(1333, 1333),
                only_print=None,
                axis=False,
                return_image=False,
                resize=False):
    """
    Draw predicted boxes, keypoints, links and OCR text on a copy of an image,
    then either return the annotated copy or display it with matplotlib.

    Parameters:
    - image (np.array | torch.Tensor): The image to annotate. When
      ``prediction`` is None it is converted with
      ``image.squeeze(0).permute(1, 2, 0).cpu().numpy()`` (presumably a
      (1, C, H, W) tensor -- TODO confirm); otherwise it is used as-is and must
      support ``.copy()`` and ``.shape``. The input itself is not drawn on.
    - prediction (dict | None): Model prediction. Reads 'boxes', 'scores' and
      'labels', plus 'keypoints' (draw_keypoints / draw_twins) and 'links'
      (draw_links). When None, only OCR text can be drawn.
    - text_predictions (tuple | None): OCR output ``(boxes, texts)`` where each
      box unpacks to ``x1, y1, x2, y2``.
    - class_dict (dict): Mapping from class IDs to class names. A class name is
      turned into a label ID by its position in ``class_dict.values()``.
    - draw_keypoints (bool): Draw the keypoints of sequenceFlow, messageFlow and
      dataAssociation predictions (the first keypoint in a different colour).
    - draw_boxes (bool): Draw prediction bounding boxes.
    - draw_text (bool): Draw OCR bounding boxes.
    - draw_links (bool): For each ``(start_idx, end_idx)`` at position ``i`` of
      ``prediction['links']``, draw lines between box centres: start box ->
      box ``i`` -> end box. Entries containing None are skipped.
    - draw_twins (bool): For flow predictions whose two keypoints are within 10
      units of each other (prediction coordinates), mark the points returned by
      ``find_other_keypoint``.
    - draw_grouped_text (bool): Draw the sentence and info boxes returned by
      ``group_texts`` for the 'task' boxes and the OCR text. Skipped when
      ``text_predictions`` is None.
    - write_class (bool): Write class names just above the boxes.
    - write_score (bool): Write rounded scores near the boxes' top-left corner.
    - write_text (bool): Write the OCR recognised text.
    - score_threshold (float): Predictions scoring below this are not drawn.
    - write_idx (bool): Write each prediction's index.
    - keypoints_correction (bool): Currently unused; kept for backward
      compatibility.
    - new_size (tuple): Bound the image is fitted into with its aspect ratio
      kept (``new_size[0]`` bounds ``image.shape[0]``, ``new_size[1]`` bounds
      ``image.shape[1]``); predictions are rescaled from that fitted space.
    - only_print (str | None): When ``draw_boxes`` is set and this is a class
      name other than 'all', predictions of other classes are skipped.
    - axis (bool): Keep the matplotlib axes visible when displaying.
    - return_image (bool): Return the annotated image instead of showing it.
    - resize (bool): Rescale prediction and OCR boxes from the fitted space to
      the image size. Keypoints, links and grouped-text boxes are always
      rescaled; twin markers never are.

    Returns:
    - The annotated copy of the image when ``return_image`` is True, else None.

    Raises:
    - ValueError: if ``only_print`` (or a class name looked up by an enabled
      option, e.g. 'sequenceFlow' or 'task') is not among ``class_dict`` values.
    """
    # Without a prediction the image arrives as a tensor; convert it to a
    # channels-last numpy array.
    if prediction is None:
        image = image.squeeze(0).permute(1, 2, 0).cpu().numpy()

    image_copy = image.copy()
    # Marker sizes, line widths and font sizes grow with the longest side.
    scale = max(image.shape[0], image.shape[1]) / 1000
    # cv2.line rejects a thickness of 0, which int(2 * scale) yields when the
    # longest side is under 500 px; for rectangles 1 renders the same as 0.
    line_width = max(1, int(2 * scale))

    # Size of the image once fitted inside new_size with its aspect ratio kept;
    # this is the coordinate space predictions are rescaled from.
    original_size = (image.shape[0], image.shape[1])
    scale_ = min(new_size[0] / original_size[0], new_size[1] / original_size[1])
    new_scaled_size = (int(original_size[0] * scale_), int(original_size[1] * scale_))
    # The resize helpers receive both sizes as (shape[1], shape[0]) pairs.
    fitted_size = (new_scaled_size[1], new_scaled_size[0])
    image_size = (image_copy.shape[1], image_copy.shape[0])

    class_names = list(class_dict.values())

    def to_image_box(box):
        # Rescale one x1, y1, x2, y2 box from the fitted space to the image.
        return resize_boxes(np.array([box]), fitted_size, image_size)[0]

    def to_image_float_box(box):
        # Like to_image_box, first casting each coordinate to float.
        x1, y1, x2, y2 = box
        return to_image_box([float(x1), float(y1), float(x2), float(y2)])

    def to_image_point(point):
        # Rescale one keypoint from the fitted space to the image.
        return resize_keypoints(np.array([point]), fitted_size, image_size)[0]

    def flow_labels():
        # Label IDs of the flow classes; a tuple so `in` compares with ==.
        return (class_names.index('sequenceFlow'),
                class_names.index('messageFlow'),
                class_names.index('dataAssociation'))

    def box_center(box):
        # Integer centre point of an x1, y1, x2, y2 box, as cv2 expects.
        return (int((box[0] + box[2]) // 2), int((box[1] + box[3]) // 2))

    if prediction is not None:
        # Resolve the only_print class once rather than for every box.
        only_label = None
        if draw_boxes and only_print is not None and only_print != 'all':
            only_label = class_names.index(only_print)

        for i, box in enumerate(prediction['boxes']):
            score = prediction['scores'][i]
            if score < score_threshold:
                continue
            x1, y1, x2, y2 = to_image_box(box) if resize else box
            if draw_boxes:
                if only_label is not None and prediction['labels'][i] != only_label:
                    continue
                cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 0), line_width)
            if write_score:
                cv2.putText(image_copy, str(round(score, 2)), (int(x1), int(y1) + int(15 * scale)),
                            cv2.FONT_HERSHEY_SIMPLEX, scale / 2, (100, 100, 255), 2)
            if write_idx:
                cv2.putText(image_copy, str(i), (int(x1) + int(15 * scale), int(y1) + int(15 * scale)),
                            cv2.FONT_HERSHEY_SIMPLEX, 2 * scale, (0, 0, 0), 2)
            if write_class and 'labels' in prediction:
                class_id = prediction['labels'][i]
                cv2.putText(image_copy, class_dict[class_id], (int(x1), int(y1) - int(2 * scale)),
                            cv2.FONT_HERSHEY_SIMPLEX, scale / 2, (255, 100, 100), 2)

        # Draw keypoints of flow predictions (always rescaled to the image).
        if draw_keypoints and 'keypoints' in prediction:
            keypoint_labels = flow_labels()
            radius = int(5 * scale)
            for i, kp in enumerate(prediction['keypoints']):
                if prediction['labels'][i] not in keypoint_labels:
                    continue
                if prediction['scores'][i] < score_threshold:
                    continue
                for j in range(kp.shape[0]):
                    x, y, _ = to_image_point(kp[j])
                    colour = (0, 0, 255) if j == 0 else (255, 0, 0)
                    cv2.circle(image_copy, (int(x), int(y)), radius, colour, -1)

    # Draw OCR text boxes and/or recognised text.
    if (draw_text or write_text) and text_predictions is not None:
        for i, text_box in enumerate(text_predictions[0]):
            x1, y1, x2, y2 = text_box
            text = text_predictions[1][i]
            if resize:
                x1, y1, x2, y2 = to_image_float_box((x1, y1, x2, y2))
            if draw_text:
                cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), line_width)
            if write_text:
                cv2.putText(image_copy, text, (int(x1 + int(2 * scale)), int((y1 + y2) / 2)),
                            cv2.FONT_HERSHEY_SIMPLEX, scale / 2, (0, 0, 0), 2)

    # Mark flow predictions whose two keypoints (x, y) lie within 10 units of
    # each other. These markers are drawn in prediction coordinates (no rescale).
    if draw_twins and prediction is not None:
        twin_labels = flow_labels()
        circle_color = (0, 255, 0)
        circle_radius = int(10 * scale)
        for idx, (key1, key2) in enumerate(prediction['keypoints']):
            if prediction['labels'][idx] not in twin_labels:
                continue
            if np.linalg.norm(key1[:2] - key2[:2]) < 10:
                x_new, y_new, x, y = find_other_keypoint(idx, prediction)
                cv2.circle(image_copy, (int(x), int(y)), circle_radius, circle_color, -1)
                cv2.circle(image_copy, (int(x_new), int(y_new)), circle_radius, (0, 0, 0), -1)

    # Draw links as two segments: start box -> linking box -> end box.
    if draw_links and prediction is not None:
        for i, (start_idx, end_idx) in enumerate(prediction['links']):
            if start_idx is None or end_idx is None:
                continue
            start_center = box_center(to_image_box(prediction['boxes'][start_idx]))
            end_center = box_center(to_image_box(prediction['boxes'][end_idx]))
            current_center = box_center(to_image_box(prediction['boxes'][i]))
            cv2.line(image_copy, start_center, current_center, (0, 0, 255), line_width)
            cv2.line(image_copy, current_center, end_center, (255, 0, 0), line_width)

    # Draw OCR text grouped around 'task' boxes; needs OCR output to group.
    if draw_grouped_text and prediction is not None and text_predictions is not None:
        task_label = class_names.index('task')
        task_boxes = [box for i, box in enumerate(prediction['boxes'])
                      if prediction['labels'][i] == task_label]
        _, sentence_bounding_boxes, _, info_boxes = group_texts(
            task_boxes, text_predictions[0], text_predictions[1], percentage_thresh=1)
        # Info boxes first, then sentence boxes, drawn in the same style.
        for grouped_box in [*info_boxes, *sentence_bounding_boxes]:
            x1, y1, x2, y2 = to_image_float_box(grouped_box)
            cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), line_width)

    if return_image:
        return image_copy

    # Display the image
    plt.figure(figsize=(12, 12))
    plt.imshow(image_copy)
    if not axis:
        plt.axis('off')
    plt.show()