|
from typing import List, Union |
|
|
|
import torch |
|
from loguru import logger |
|
from PIL import Image, ImageDraw, ImageFont |
|
from torchvision.transforms.functional import to_pil_image |
|
|
|
|
|
def draw_bboxes(
    img: Union[Image.Image, torch.Tensor],
    bboxes: Union[List[List[Union[int, float]]], torch.Tensor],
    save_path: str = "visualize.jpg",
) -> Image.Image:
    """
    Draw bounding boxes on an image and save the annotated result to disk.

    Args:
        - img (PIL Image or torch.Tensor): Image on which to draw the bounding boxes.
            A tensor with more than 3 dimensions is treated as a batch: only its
            first image, and the first entry of ``bboxes``, are drawn.
        - bboxes (List of Lists/Tensors): Bounding boxes with [class_id, x_min, y_min, x_max, y_max],
            where coordinates are normalized [0, 1].
        - save_path (str): File the annotated image is written to. Defaults to "visualize.jpg".

    Returns:
        The annotated PIL image. When ``img`` is already a PIL image it is drawn on
        in place, so the caller's object is modified.
    """
    if isinstance(img, torch.Tensor):
        if img.dim() > 3:
            logger.info("Multi-frame tensor detected, using the first image.")
            img = img[0]
            bboxes = bboxes[0]
        img = to_pil_image(img)

    draw = ImageDraw.Draw(img)
    width, height = img.size
    try:
        font = ImageFont.load_default(30)
    except TypeError:
        # Pillow < 10.1 does not accept a size argument for the default font.
        font = ImageFont.load_default()

    for bbox in bboxes:
        class_id, x_min, y_min, x_max, y_max = bbox
        # float() turns 0-dim tensor elements into plain Python numbers for PIL;
        # coordinates are normalized, so scale them to pixel units.
        x_min, x_max = float(x_min) * width, float(x_max) * width
        y_min, y_max = float(y_min) * height, float(y_max) * height
        draw.rectangle([(x_min, y_min), (x_max, y_max)], outline="red", width=3)
        draw.text((x_min, y_min), str(int(class_id)), font=font, fill="blue")

    # JPEG cannot store alpha/palette/high-bit-depth modes (e.g. RGBA from a
    # 4-channel tensor); save an RGB copy instead of raising OSError.
    to_save = img
    if save_path.lower().endswith((".jpg", ".jpeg")) and img.mode not in ("1", "L", "RGB", "CMYK"):
        to_save = img.convert("RGB")
    to_save.save(save_path)
    logger.info(f"Saved visualize image at {save_path}")
    return img
|
|