|
import cv2 |
|
import matplotlib.pyplot as plt |
|
from PIL import ImageColor |
|
from pathlib import Path |
|
import os |
|
|
|
def _to_rgb(color):
    """Convert a color specification to an ``(R, G, B)`` tuple.

    Strings (e.g. ``'#FF0000'``) are parsed with ``PIL.ImageColor.getcolor``;
    any other value is treated as an RGB sequence and returned as a tuple.
    """
    if isinstance(color, str):
        return ImageColor.getcolor(color, 'RGB')
    return tuple(color)


def _resolve_color_map(class_dic, hex_class_colors):
    """Map every class id (key of ``class_dic``) to the RGB color of its label.

    When ``hex_class_colors`` is empty or None, every label is drawn in red.
    Raises KeyError if a label of ``class_dic`` has no entry in ``hex_class_colors``.
    """
    if not hex_class_colors:
        hex_class_colors = {class_name: (255, 0, 0) for class_name in class_dic.values()}
    return {key: _to_rgb(hex_class_colors[label]) for key, label in class_dic.items()}


def _build_title(predicted_count, true_count):
    """Build the count title; ground-truth info is appended only when ``true_count`` is given.

    ``None`` and ``False`` mean "no ground truth". A true count of ``0`` is a real value
    and is displayed.
    """
    if true_count is None or true_count is False:
        return f'Predicted count: {predicted_count}'
    return (
        f'Predicted count: {predicted_count}, true count: {true_count}, '
        f'delta: {predicted_count - true_count}'
    )


def annotate_image_prediction(image_path, yolo_boxes, class_dic, saving_folder, hex_class_colors=None, show=False, true_count=False, saving_image_name=None, put_title=True, box_thickness=3, font_scale=1, font_thickness=5):
    """Draw YOLO predicted boxes (and optionally a count title) on an image and save it.

    Args:
        image_path (str | os.PathLike): path to the image to annotate.
        yolo_boxes: iterable of predicted boxes. Each box must expose ``xyxy[0]``
            (pixel corners x1, y1, x2, y2) and ``cls[0]`` (class id), and the
            collection must support ``len()``. This presumably matches Ultralytics
            ``Boxes`` (TODO confirm).
        class_dic (dict): maps predicted class id to its label.
        saving_folder (str): folder where the annotated image is saved (created if missing).
        hex_class_colors (dict, optional): maps a label to its color, given either as a
            color string understood by PIL (e.g. ``'#FF0000'``) or as an RGB tuple.
            Defaults to None, which draws every class in red.
        show (bool, optional): display the image with its boxes (without the title)
            in a matplotlib window. Defaults to False.
        true_count (int | None | bool, optional): ground-truth count to display
            alongside the predicted count and their difference. ``None``/``False``
            disables it. Defaults to False.
        saving_image_name (str, optional): file name of the saved image. Defaults to
            ``'annotated_<original file name>'``.
        put_title (bool, optional): write the count title on the saved image. Defaults to True.
        box_thickness (int, optional): bounding box line thickness. Defaults to 3.
        font_scale (int, optional): font scale of the title. Defaults to 1.
        font_thickness (int, optional): font thickness of the title. Defaults to 5.

    Returns:
        str | None: path of the saved annotated image, or None if ``image_path`` does not
        exist or cannot be decoded as an image (a warning is printed in both cases).
    """
    if not os.path.isfile(image_path):
        print(f'WARNING: {image_path} does not exists')
        return None

    color_map = _resolve_color_map(class_dic, hex_class_colors)

    img = cv2.imread(str(image_path))
    if img is None:
        # cv2.imread signals unreadable/unsupported files by returning None, not by raising.
        print(f'WARNING: {image_path} could not be read as an image')
        return None
    # OpenCV loads BGR; matplotlib (imshow/imsave) expects RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    dh, dw = img.shape[:2]

    for yolo_box in yolo_boxes:
        x1, y1, x2, y2 = (int(v) for v in yolo_box.xyxy[0])
        class_id = int(yolo_box.cls[0])
        cv2.rectangle(img, (x1, y1), (x2, y2), color_map[class_id], box_thickness)

    if show:
        plt.imshow(img)
        plt.show()

    # The title is drawn on a copy so that `img` itself keeps only the boxes.
    annotated = img.copy()
    if put_title:
        cv2.putText(
            img=annotated,
            text=_build_title(len(yolo_boxes), true_count),
            org=(int(0.1 * dw), int(0.1 * dh)),
            fontFace=cv2.FONT_HERSHEY_SIMPLEX,
            fontScale=font_scale,
            thickness=font_thickness,
            color=(255, 251, 5),
        )

    if not saving_image_name:
        # Path(...).name is separator-aware, unlike splitting on '/'.
        saving_image_name = f'annotated_{Path(image_path).name}'
    Path(saving_folder).mkdir(parents=True, exist_ok=True)
    full_saving_path = os.path.join(saving_folder, saving_image_name)
    plt.imsave(full_saving_path, annotated)
    return full_saving_path
|
|