# Source: Spaces file view, commit 615e9f1 (file size: 10,071 bytes).
from utils import draw_annotations, create_loader, class_dict, resize_boxes, resize_keypoints, find_other_keypoint
import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt
from OCR import group_texts
# Classes whose keypoints are drawn / checked for "twins".
_FLOW_CLASSES = ('sequenceFlow', 'messageFlow', 'dataAssociation')


def _label_index(class_dict, name):
    """Return the position of class ``name`` among ``class_dict``'s values.

    This is how the module compares against ``prediction['labels']``.
    NOTE(review): it equals the class ID only if the dict's keys are
    0..n-1 in insertion order -- confirm against ``class_dict``.
    Raises ValueError if ``name`` is not one of the values.
    """
    return list(class_dict.values()).index(name)


def _flow_label_indices(class_dict):
    """Return the label indices of the flow classes as a tuple.

    A tuple (not a set) is used on purpose: membership then falls back to
    ``==``, which also works if labels are tensor scalars (hash-based set
    lookup would not match those).
    """
    return tuple(_label_index(class_dict, name) for name in _FLOW_CLASSES)


def _rescale_box(box, from_wh, to_wh):
    """Map one ``(x1, y1, x2, y2)`` box from ``from_wh`` to ``to_wh`` via ``resize_boxes``."""
    coords = np.array([[float(c) for c in box]])
    return resize_boxes(coords, from_wh, to_wh)[0]


def _box_center(box):
    """Return the integer pixel center of an ``(x1, y1, x2, y2)`` box."""
    return (int((box[0] + box[2]) // 2), int((box[1] + box[3]) // 2))


def draw_stream(image,
                prediction=None,
                text_predictions=None,
                class_dict=class_dict,
                draw_keypoints=False,
                draw_boxes=False,
                draw_text=False,
                draw_links=False,
                draw_twins=False,
                draw_grouped_text=False,
                write_class=False,
                write_score=False,
                write_text=False,
                score_threshold=0.4,
                write_idx=False,
                keypoints_correction=False,
                new_size=(1333, 1333),
                only_print=None,
                axis=False,
                return_image=False,
                resize=False):
    """
    Draw detections, keypoints, links and OCR text onto a copy of an image.

    Parameters:
    - image: Image to annotate. When ``prediction`` is None it is squeezed on
      dim 0, permuted to channels-last and converted to numpy (presumably a
      (1, C, H, W) tensor -- confirm with callers); otherwise it must support
      ``.copy()`` and ``.shape``.
    - prediction (dict | None): Model output. Reads 'boxes', 'scores',
      'labels', and optionally 'keypoints' and 'links'.
    - text_predictions (tuple | None): OCR output ``(boxes, texts)``.
    - class_dict (dict): Mapping from class IDs to class names.
    - draw_keypoints (bool): Draw keypoints of flow objects
      (sequenceFlow, messageFlow, dataAssociation).
    - draw_boxes (bool): Draw detection boxes.
    - draw_text (bool): Draw OCR text boxes.
    - draw_links (bool): For each entry of ``prediction['links']``, draw lines
      start box -> current box -> end box (between box centers).
    - draw_twins (bool): Mark flow objects whose two keypoints are closer
      than 10 (in keypoint coordinates).
    - draw_grouped_text (bool): Draw the text boxes produced by
      ``group_texts`` for 'task' boxes (requires ``text_predictions``).
    - write_class (bool): Write the class name above each box.
    - write_score (bool): Write the score inside each box.
    - write_text (bool): Write the OCR text.
    - score_threshold (float): Objects scoring below this are skipped.
    - write_idx (bool): Write the prediction index inside each box.
    - keypoints_correction (bool): Currently unused; kept for compatibility.
    - new_size (tuple): Bounding size; the image size fitted into it (aspect
      ratio kept) is the coordinate space predictions are mapped from.
    - only_print (str | None): If set and not 'all', only per-box
      annotations of this class name are drawn.
    - axis (bool): Keep matplotlib axes when displaying.
    - return_image (bool): Return the annotated image instead of showing it.
    - resize (bool): Map detection and OCR boxes from the fitted size to the
      image size before drawing.

    Returns:
    - The annotated image copy if ``return_image`` is True; otherwise None
      (the image is displayed with matplotlib).
    """
    # Convert a tensor image (used when there are no predictions) to numpy.
    if prediction is None:
        image = image.squeeze(0).permute(1, 2, 0).cpu().numpy()

    image_copy = image.copy()
    # Drawing sizes scale with the image: 1.0 for a 1000-pixel longest side.
    scale = max(image.shape[0], image.shape[1]) / 1000
    # cv2.line rejects thickness 0 (longest side < 500 px); rectangles render
    # 0 and 1 identically, so clamping only removes the crash.
    thickness = max(1, int(2 * scale))

    original_size = (image.shape[0], image.shape[1])
    # Calculate scale to fit the new size while maintaining aspect ratio
    scale_ = min(new_size[0] / original_size[0], new_size[1] / original_size[1])
    new_scaled_size = (int(original_size[0] * scale_), int(original_size[1] * scale_))
    # Sizes are passed to the resize helpers as (shape[1], shape[0]),
    # presumably (width, height).
    scaled_wh = (new_scaled_size[1], new_scaled_size[0])
    image_wh = (image_copy.shape[1], image_copy.shape[0])

    # Per-object boxes, scores, indices and class names.
    if prediction is not None:
        only_label = None
        if only_print is not None and only_print != 'all':
            only_label = _label_index(class_dict, only_print)
        for i, box in enumerate(prediction['boxes']):
            score = prediction['scores'][i]
            if score < score_threshold:
                continue
            if only_label is not None and prediction['labels'][i] != only_label:
                continue
            x1, y1, x2, y2 = _rescale_box(box, scaled_wh, image_wh) if resize else box
            if draw_boxes:
                cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 0), thickness)
            if write_score:
                cv2.putText(image_copy, str(round(score, 2)), (int(x1), int(y1) + int(15 * scale)),
                            cv2.FONT_HERSHEY_SIMPLEX, scale / 2, (100, 100, 255), 2)
            if write_idx:
                cv2.putText(image_copy, str(i), (int(x1) + int(15 * scale), int(y1) + int(15 * scale)),
                            cv2.FONT_HERSHEY_SIMPLEX, 2 * scale, (0, 0, 0), 2)
            if write_class and 'labels' in prediction:
                class_id = prediction['labels'][i]
                cv2.putText(image_copy, class_dict[class_id], (int(x1), int(y1) - int(2 * scale)),
                            cv2.FONT_HERSHEY_SIMPLEX, scale / 2, (255, 100, 100), 2)

    # Keypoints of flow objects: the first keypoint and the rest use different colors.
    if draw_keypoints and prediction is not None and 'keypoints' in prediction:
        flow_labels = _flow_label_indices(class_dict)
        for i, kp in enumerate(prediction['keypoints']):
            if prediction['labels'][i] not in flow_labels or prediction['scores'][i] < score_threshold:
                continue
            for j in range(kp.shape[0]):
                # NOTE(review): keypoints are always rescaled, whereas boxes
                # only are when `resize` is set -- confirm this is intended.
                x, y, _ = resize_keypoints(np.array([kp[j]]), scaled_wh, image_wh)[0]
                color = (0, 0, 255) if j == 0 else (255, 0, 0)
                cv2.circle(image_copy, (int(x), int(y)), int(5 * scale), color, -1)

    # OCR text boxes and/or text.
    if (draw_text or write_text) and text_predictions is not None:
        for box, text in zip(text_predictions[0], text_predictions[1]):
            x1, y1, x2, y2 = _rescale_box(box, scaled_wh, image_wh) if resize else box
            if draw_text:
                cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), thickness)
            if write_text:
                cv2.putText(image_copy, text, (int(x1 + int(2 * scale)), int((y1 + y2) / 2)),
                            cv2.FONT_HERSHEY_SIMPLEX, scale / 2, (0, 0, 0), 2)

    # "Twins": flow objects whose two keypoints (nearly) coincide.
    if draw_twins and prediction is not None and 'keypoints' in prediction:
        flow_labels = _flow_label_indices(class_dict)
        twin_radius = int(10 * scale)
        # Unpacking assumes exactly two keypoints per object.
        for idx, (key1, key2) in enumerate(prediction['keypoints']):
            if prediction['labels'][idx] not in flow_labels:
                continue
            if np.linalg.norm(key1[:2] - key2[:2]) < 10:
                x_new, y_new, x, y = find_other_keypoint(idx, prediction)
                cv2.circle(image_copy, (int(x), int(y)), twin_radius, (0, 255, 0), -1)
                cv2.circle(image_copy, (int(x_new), int(y_new)), twin_radius, (0, 0, 0), -1)

    # Links: prediction['links'][i] is read as the (start, end) box indices for box i.
    if draw_links and prediction is not None and 'links' in prediction:
        boxes = prediction['boxes']
        for i, (start_idx, end_idx) in enumerate(prediction['links']):
            if start_idx is None or end_idx is None:
                continue
            start_center = _box_center(_rescale_box(boxes[start_idx], scaled_wh, image_wh))
            current_center = _box_center(_rescale_box(boxes[i], scaled_wh, image_wh))
            end_center = _box_center(_rescale_box(boxes[end_idx], scaled_wh, image_wh))
            cv2.line(image_copy, start_center, current_center, (0, 0, 255), thickness)
            cv2.line(image_copy, current_center, end_center, (255, 0, 0), thickness)

    # Text grouped around 'task' boxes by group_texts.
    if draw_grouped_text and prediction is not None and text_predictions is not None:
        task_label = _label_index(class_dict, 'task')
        task_boxes = [box for box, label in zip(prediction['boxes'], prediction['labels'])
                      if label == task_label]
        _, sentence_boxes, _, info_boxes = group_texts(
            task_boxes, text_predictions[0], text_predictions[1], percentage_thresh=1)
        for box in (*info_boxes, *sentence_boxes):
            x1, y1, x2, y2 = _rescale_box(box, scaled_wh, image_wh)
            cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), thickness)

    if return_image:
        return image_copy

    # Display the image
    plt.figure(figsize=(12, 12))
    plt.imshow(image_copy)
    if not axis:
        plt.axis('off')
    plt.show()