File size: 1,449 Bytes
23db031 7c11918 23db031 7c11918 23db031 7c11918 23db031 7c11918 23db031 7c11918 23db031 7c11918 23db031 7c11918 23db031 2007b83 7c11918 23db031 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
from typing import List, Union
import torch
from loguru import logger
from PIL import Image, ImageDraw, ImageFont
from torchvision.transforms.functional import to_pil_image
def draw_bboxes(
    img: Union[Image.Image, torch.Tensor],
    bboxes: List[List[Union[int, float]]],
    save_path: str = "visualize.jpg",
):
    """
    Draw bounding boxes on an image and save the annotated result to disk.

    Args:
        - img (PIL Image or torch.Tensor): Image on which to draw the bounding boxes.
          A tensor with more than 3 dimensions is treated as a batch and only its
          first element is drawn. A PIL image is annotated in place.
        - bboxes (List of Lists/Tensors): Bounding boxes with [class_id, x_min, y_min, x_max, y_max],
          where coordinates are normalized [0, 1]. When `img` is a batched tensor,
          `bboxes[0]` is taken as the box list of the first image.
        - save_path (str): Destination file for the annotated image.
          Defaults to "visualize.jpg" in the current working directory.

    Returns:
        None. The annotated image is written to `save_path`.
    """
    # Convert tensor image to PIL Image if necessary
    if isinstance(img, torch.Tensor):
        if img.dim() > 3:
            logger.info("Multi-frame tensor detected, using the first image.")
            img = img[0]
            bboxes = bboxes[0]
        img = to_pil_image(img)

    draw = ImageDraw.Draw(img)
    width, height = img.size

    try:
        font = ImageFont.load_default(30)
    except TypeError:
        # Pillow releases before 10.1 do not accept a size argument;
        # fall back to the fixed-size built-in bitmap font.
        font = ImageFont.load_default()

    for bbox in bboxes:
        class_id, x_min, y_min, x_max, y_max = bbox
        # Scale normalized coordinates to pixel positions.
        left, right = x_min * width, x_max * width
        top, bottom = y_min * height, y_max * height
        draw.rectangle([(left, top), (right, bottom)], outline="red", width=3)
        draw.text((left, top), str(int(class_id)), font=font, fill="blue")

    img.save(save_path)  # Save the image with annotations
    # Log the path actually written (the old message said .png while saving a .jpg).
    logger.info("Saved visualize image at {}", save_path)
|