from utils import draw_annotations, create_loader, class_dict, resize_boxes, resize_keypoints, find_other_keypoint
import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt
from OCR import group_texts


def draw_stream(image,
                prediction=None,
                text_predictions=None,
                class_dict=class_dict,
                draw_keypoints=False,
                draw_boxes=False,
                draw_text=False,
                draw_links=False,
                draw_twins=False,
                draw_grouped_text=False,
                write_class=False,
                write_score=False,
                write_text=False,
                score_threshold=0.4,
                write_idx=False,
                keypoints_correction=False,
                new_size=(1333, 1333),
                only_print=None,
                axis=False,
                return_image=False,
                resize=False):
""" | |
Draws annotations on images including bounding boxes, keypoints, links, and text. | |
Parameters: | |
- image (np.array): The image on which annotations will be drawn. | |
- target (dict): Ground truth data containing boxes, labels, etc. | |
- prediction (dict): Prediction data from a model. | |
- full_prediction (dict): Additional detailed prediction data, potentially including relationships. | |
- text_predictions (tuple): OCR text predictions containing bounding boxes and texts. | |
- class_dict (dict): Mapping from class IDs to class names. | |
- draw_keypoints (bool): Flag to draw keypoints. | |
- draw_boxes (bool): Flag to draw bounding boxes. | |
- draw_text (bool): Flag to draw text annotations. | |
- draw_links (bool): Flag to draw links between annotations. | |
- draw_twins (bool): Flag to draw twins keypoints. | |
- write_class (bool): Flag to write class names near the annotations. | |
- write_score (bool): Flag to write scores near the annotations. | |
- write_text (bool): Flag to write OCR recognized text. | |
- score_threshold (float): Threshold for scores above which annotations will be drawn. | |
- only_print (str): Specific class name to filter annotations by. | |
- resize (bool): Whether to resize annotations to fit the image size. | |
""" | |
    # If no prediction is given, the input is assumed to be a tensor (C, H, W):
    # move it to the CPU and convert it to a (H, W, C) numpy array.
    if prediction is None:
        image = image.squeeze(0).permute(1, 2, 0).cpu().numpy()

    image_copy = image.copy()
    scale = max(image.shape[0], image.shape[1]) / 1000

    original_size = (image.shape[0], image.shape[1])
    # Calculate the scale that fits the image into new_size while keeping the aspect ratio
    scale_ = min(new_size[0] / original_size[0], new_size[1] / original_size[1])
    new_scaled_size = (int(original_size[0] * scale_), int(original_size[1] * scale_))
    # Draw predicted bounding boxes and their optional score/index/class labels
    for i in range(len(prediction['boxes'])):
        box = prediction['boxes'][i]
        x1, y1, x2, y2 = box
        if resize:
            x1, y1, x2, y2 = resize_boxes(np.array([box]), (new_scaled_size[1], new_scaled_size[0]), (image_copy.shape[1], image_copy.shape[0]))[0]

        score = prediction['scores'][i]
        if score < score_threshold:
            continue

        if draw_boxes:
            if only_print is not None and only_print != 'all':
                if prediction['labels'][i] != list(class_dict.values()).index(only_print):
                    continue
            cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 0), int(2*scale))

        if write_score:
            cv2.putText(image_copy, str(round(score, 2)), (int(x1), int(y1) + int(15*scale)), cv2.FONT_HERSHEY_SIMPLEX, scale/2, (100, 100, 255), 2)

        if write_idx:
            cv2.putText(image_copy, str(i), (int(x1) + int(15*scale), int(y1) + int(15*scale)), cv2.FONT_HERSHEY_SIMPLEX, 2*scale, (0, 0, 0), 2)

        if write_class and 'labels' in prediction:
            class_id = prediction['labels'][i]
            cv2.putText(image_copy, class_dict[class_id], (int(x1), int(y1) - int(2*scale)), cv2.FONT_HERSHEY_SIMPLEX, scale/2, (255, 100, 100), 2)
    # Draw keypoints if available (only for sequence flows, message flows and data associations)
    if draw_keypoints and 'keypoints' in prediction:
        for i in range(len(prediction['keypoints'])):
            kp = prediction['keypoints'][i]
            for j in range(kp.shape[0]):
                if prediction['labels'][i] != list(class_dict.values()).index('sequenceFlow') and prediction['labels'][i] != list(class_dict.values()).index('messageFlow') and prediction['labels'][i] != list(class_dict.values()).index('dataAssociation'):
                    continue
                score = prediction['scores'][i]
                if score < score_threshold:
                    continue
                # Map the keypoint from the scaled input space back to the image coordinates
                x, y, v = resize_keypoints(np.array([kp[j]]), (new_scaled_size[1], new_scaled_size[0]), (image_copy.shape[1], image_copy.shape[0]))[0]
                # Start keypoint (j == 0) and end keypoint are drawn in different colors
                if j == 0:
                    cv2.circle(image_copy, (int(x), int(y)), int(5*scale), (0, 0, 255), -1)
                else:
                    cv2.circle(image_copy, (int(x), int(y)), int(5*scale), (255, 0, 0), -1)
    # Draw text predictions if available
    if (draw_text or write_text) and text_predictions is not None:
        for i in range(len(text_predictions[0])):
            x1, y1, x2, y2 = text_predictions[0][i]
            text = text_predictions[1][i]
            if resize:
                x1, y1, x2, y2 = resize_boxes(np.array([[float(x1), float(y1), float(x2), float(y2)]]), (new_scaled_size[1], new_scaled_size[0]), (image_copy.shape[1], image_copy.shape[0]))[0]
            if draw_text:
                cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), int(2*scale))
            if write_text:
                cv2.putText(image_copy, text, (int(x1 + int(2*scale)), int((y1 + y2) / 2)), cv2.FONT_HERSHEY_SIMPLEX, scale/2, (0, 0, 0), 2)
    # Check whether the two keypoints detected for a flow element are (almost) the same
    if draw_twins and prediction is not None:
        circle_color = (0, 255, 0)       # Color of the circle drawn at the recovered keypoint
        circle_radius = int(10 * scale)  # Circle radius scaled with the image size
        for idx, (key1, key2) in enumerate(prediction['keypoints']):
            if prediction['labels'][idx] not in [list(class_dict.values()).index('sequenceFlow'),
                                                 list(class_dict.values()).index('messageFlow'),
                                                 list(class_dict.values()).index('dataAssociation')]:
                continue
            # Calculate the Euclidean distance between the two keypoints
            distance = np.linalg.norm(key1[:2] - key2[:2])
            if distance < 10:
                x_new, y_new, x, y = find_other_keypoint(idx, prediction)
                cv2.circle(image_copy, (int(x), int(y)), circle_radius, circle_color, -1)
                cv2.circle(image_copy, (int(x_new), int(y_new)), circle_radius, (0, 0, 0), -1)
    # Draw links between connected objects, passing through the linking element itself
    if draw_links and prediction is not None:
        for i, (start_idx, end_idx) in enumerate(prediction['links']):
            if start_idx is None or end_idx is None:
                continue
            start_box = prediction['boxes'][start_idx]
            start_box = resize_boxes(np.array([start_box]), (new_scaled_size[1], new_scaled_size[0]), (image_copy.shape[1], image_copy.shape[0]))[0]
            end_box = prediction['boxes'][end_idx]
            end_box = resize_boxes(np.array([end_box]), (new_scaled_size[1], new_scaled_size[0]), (image_copy.shape[1], image_copy.shape[0]))[0]
            current_box = prediction['boxes'][i]
            current_box = resize_boxes(np.array([current_box]), (new_scaled_size[1], new_scaled_size[0]), (image_copy.shape[1], image_copy.shape[0]))[0]

            # Calculate the center of each bounding box
            start_center = ((start_box[0] + start_box[2]) // 2, (start_box[1] + start_box[3]) // 2)
            end_center = ((end_box[0] + end_box[2]) // 2, (end_box[1] + end_box[3]) // 2)
            current_center = ((current_box[0] + current_box[2]) // 2, (current_box[1] + current_box[3]) // 2)

            # Draw a line from the start object to the current element, then to the end object
            cv2.line(image_copy, (int(start_center[0]), int(start_center[1])), (int(current_center[0]), int(current_center[1])), (0, 0, 255), int(2*scale))
            cv2.line(image_copy, (int(current_center[0]), int(current_center[1])), (int(end_center[0]), int(end_center[1])), (255, 0, 0), int(2*scale))
    # Group OCR texts relative to the task boxes and draw the grouped bounding boxes
    if draw_grouped_text and prediction is not None:
        task_boxes = [box for i, box in enumerate(prediction['boxes']) if prediction['labels'][i] == list(class_dict.values()).index('task')]
        grouped_sentences, sentence_bounding_boxes, info_texts, info_boxes = group_texts(task_boxes, text_predictions[0], text_predictions[1], percentage_thresh=1)
        for i in range(len(info_boxes)):
            x1, y1, x2, y2 = info_boxes[i]
            x1, y1, x2, y2 = resize_boxes(np.array([[float(x1), float(y1), float(x2), float(y2)]]), (new_scaled_size[1], new_scaled_size[0]), (image_copy.shape[1], image_copy.shape[0]))[0]
            cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), int(2*scale))
        for i in range(len(sentence_bounding_boxes)):
            x1, y1, x2, y2 = sentence_bounding_boxes[i]
            x1, y1, x2, y2 = resize_boxes(np.array([[float(x1), float(y1), float(x2), float(y2)]]), (new_scaled_size[1], new_scaled_size[0]), (image_copy.shape[1], image_copy.shape[0]))[0]
            cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), int(2*scale))
    if return_image:
        return image_copy
    else:
        # Display the annotated image
        plt.figure(figsize=(12, 12))
        plt.imshow(image_copy)
        if not axis:
            plt.axis('off')
        plt.show()
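

# ---------------------------------------------------------------------------
# Minimal usage sketch (hypothetical): the image path, box coordinates, label
# id, and score below are made-up placeholders, and the dummy prediction only
# mimics the keys this function actually reads for box drawing ('boxes',
# 'labels', 'scores'). It illustrates the calling convention, not an official
# entry point of the application.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    img = cv2.imread("example_diagram.png")  # hypothetical input image
    if img is None:
        raise FileNotFoundError("example_diagram.png not found (placeholder path)")

    dummy_prediction = {
        'boxes': np.array([[50.0, 60.0, 200.0, 140.0]]),  # one fake box (x1, y1, x2, y2)
        'labels': np.array([1]),                           # fake class id
        'scores': np.array([0.95]),                        # above the default threshold
    }

    annotated = draw_stream(img,
                            prediction=dummy_prediction,
                            draw_boxes=True,
                            write_score=True,
                            resize=False,
                            return_image=True)
    cv2.imwrite("annotated_example.png", annotated)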