File size: 3,874 Bytes
7967aab
1197f7d
 
e94b3ff
1197f7d
 
 
 
 
 
b23f927
7967aab
 
 
 
 
a33e03b
b23f927
1197f7d
 
 
 
 
 
 
 
 
 
 
7967aab
1197f7d
 
 
 
 
 
 
 
 
 
b23f927
 
 
 
 
1197f7d
 
 
 
a33e03b
7967aab
 
b23f927
e94b3ff
 
2dcf548
e94b3ff
 
 
a33e03b
e94b3ff
a33e03b
e94b3ff
 
 
2dcf548
e94b3ff
 
2dcf548
 
e94b3ff
2dcf548
 
e94b3ff
 
 
 
 
 
 
 
 
2dcf548
 
 
 
 
 
e94b3ff
 
 
 
 
 
 
 
 
 
 
2dcf548
 
e94b3ff
2dcf548
e94b3ff
b18186e
 
 
 
e94b3ff
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
from typing import List, Union

import numpy as np
import torch
from loguru import logger
from PIL import Image, ImageDraw, ImageFont
from torchvision.transforms.functional import to_pil_image


def draw_bboxes(
    img: Union[Image.Image, torch.Tensor],
    bboxes: List[List[Union[int, float]]],
    *,
    scaled_bbox: bool = True,
    save_path: str = "",
    save_name: str = "visualize.png",
):
    """
    Draw bounding boxes on an image and save the annotated result.

    Args:
    - img (PIL Image or torch.Tensor): Image on which to draw the bounding boxes. A tensor is converted
      with ``to_pil_image``; a tensor with more than 3 dimensions is treated as a batch, and only its
      first image (together with ``bboxes[0]``) is drawn. A PIL image is drawn on in place.
    - bboxes (List of Lists/Tensors): Bounding boxes as [class_id, x_min, y_min, x_max, y_max, ...].
      Any values after the first five (e.g. a confidence score) are ignored.
    - scaled_bbox (bool): If True, coordinates are normalized to [0, 1] and are multiplied by the image
      width/height; if False, they are used as pixel coordinates as-is.
    - save_path (str): Directory to write the image into ("" means the current working directory).
    - save_name (str): File name of the saved image.

    Returns:
    - PIL Image: The annotated image (the same object that was saved).
    """
    # Convert tensor image to PIL Image if necessary
    if isinstance(img, torch.Tensor):
        if img.dim() > 3:
            logger.warning("🔍 Multi-frame tensor detected, using the first image.")
            img = img[0]
            bboxes = bboxes[0]
        img = to_pil_image(img)

    draw = ImageDraw.Draw(img)
    width, height = img.size
    try:
        font = ImageFont.load_default(30)
    except TypeError:
        # Pillow < 10.1 does not accept a size for the built-in default font.
        font = ImageFont.load_default()

    for bbox in bboxes:
        # float() unwraps 0-d tensors / numpy scalars so PIL receives plain numbers;
        # trailing extras (such as a score) are tolerated and ignored.
        class_id, x_min, y_min, x_max, y_max, *_ = (float(value) for value in bbox)
        if scaled_bbox:
            x_min, x_max = x_min * width, x_max * width
            y_min, y_max = y_min * height, y_max * height
        draw.rectangle([(x_min, y_min), (x_max, y_max)], outline="red", width=3)
        draw.text((x_min, y_min), str(int(class_id)), font=font, fill="blue")

    save_image_path = os.path.join(save_path, save_name)
    img.save(save_image_path)  # Save the image with annotations
    logger.info(f"💾 Saved visualize image at {save_image_path}")
    return img


def draw_model(*, model_cfg=None, model=None, v7_base=False):
    """
    Render a flow chart of a model's layer graph to ``Model-arch.png`` using graphviz.

    Args:
    - model_cfg: Model config. If truthy, a model is built from it with ``create_model`` and ``model``
      is ignored.
    - model: An already-built model, used when ``model_cfg`` is not given. It must expose ``model.model``,
      an iterable of layers, each carrying ``tags`` and ``source`` attributes (``source`` is an int or an
      iterable of ints; negative values are relative to the layer's own position).
    - v7_base (bool): If True, sub-graphs matching the connectivity found at hard-coded positions of this
      model (labelled "ELAN" / "MP") are collapsed into single named nodes. The positions presumably
      assume a YOLOv7-base layout — TODO confirm.

    Rendering is best-effort: if graphviz cannot render (e.g. the ``dot`` executable is missing),
    a warning is logged and the function returns normally.

    Raises:
    - ValueError: If neither ``model_cfg`` nor ``model`` is provided.
    """
    from graphviz import Digraph

    if model_cfg:
        from yolo.model.yolo import create_model

        model = create_model(model_cfg)
    elif model is None:
        raise ValueError("Drawing Object is None")

    # Node 0 is the network input; the i-th layer (1-based) is node i.
    model_size = len(model.model) + 1
    # Adjacency matrix: model_mat[src, dst] is True when node dst consumes the output of node src.
    model_mat = np.zeros((model_size, model_size), dtype=bool)

    layer_name = ["INPUT"]
    for idx, layer in enumerate(model.model, start=1):
        name = type(layer).__name__
        if layer.tags is not None:
            name = f"{layer.tags}-{name}"
        layer_name.append(name)
        sources = [layer.source] if isinstance(layer.source, int) else layer.source
        for source in sources:
            # Negative sources are relative to the current layer (-1 = previous node).
            if source < 0:
                source += idx
            model_mat[source, idx] = True

    pattern_mat = []
    if v7_base:
        pattern_list = [("ELAN", 8, 3), ("ELAN", 8, 55), ("MP", 5, 11)]
        for name, size, position in pattern_list:
            pattern_mat.append(
                (name, size, model_mat[position : position + size, position + 1 : position + 1 + size].copy())
            )

    # Stop the column search early enough that the largest pattern window fits in the matrix.
    max_size = max((size for _, size, _ in pattern_mat), default=1)

    dot = Digraph(comment="Model Flow Chart")

    for idx in range(model_size):
        if pattern_mat:
            for jdx in range(idx, model_size - (max_size - 1)):
                for name, size, pattern in pattern_mat:
                    if (model_mat[idx : idx + size, jdx : jdx + size] == pattern).all():
                        # Collapse the matched sub-graph into node idx and link it to the node after it.
                        layer_name[idx] = name
                        model_mat[idx : idx + size, jdx : jdx + size] = False
                        if idx + size < model_size:
                            model_mat[idx, idx + size] = True
        dot.node(str(idx), layer_name[idx])
        for jdx in range(idx, model_size):
            if model_mat[idx, jdx]:
                dot.edge(str(idx), str(jdx))

    try:
        dot.render("Model-arch", format="png", cleanup=True)
    except Exception as err:  # best-effort: a rendering failure must not abort the caller
        logger.warning(
            f"Could not find graphviz backend ({err}), continue without drawing the model architecture"
        )
    else:
        logger.info("🎨 Drawing Model Architecture at Model-arch.png")