File size: 4,355 Bytes
7967aab
2c1f270
 
1197f7d
e94b3ff
1197f7d
 
 
 
 
 
b23f927
7967aab
 
 
200b5c1
b23f927
1197f7d
 
 
 
 
 
 
 
 
 
 
3b24b35
64acfd1
1197f7d
 
64acfd1
 
2c1f270
3b24b35
2c1f270
64acfd1
2c1f270
64acfd1
1197f7d
 
2c1f270
3b24b35
 
2c1f270
9000502
2c1f270
9000502
 
3b24b35
200b5c1
3b24b35
2c1f270
 
 
64acfd1
2c1f270
 
 
 
1197f7d
b23f927
e94b3ff
 
2dcf548
e94b3ff
 
 
a33e03b
e94b3ff
a33e03b
e94b3ff
 
 
2dcf548
e94b3ff
 
2dcf548
 
e94b3ff
2dcf548
 
e94b3ff
 
 
 
 
 
 
 
 
2dcf548
 
 
 
 
 
e94b3ff
 
 
 
 
 
 
 
 
 
 
2dcf548
 
e94b3ff
2dcf548
e94b3ff
b18186e
 
635f41a
b18186e
635f41a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import random
from typing import List, Optional, Union

import numpy as np
import torch
from loguru import logger
from PIL import Image, ImageDraw, ImageFont
from torchvision.transforms.functional import to_pil_image


def draw_bboxes(
    img: Union[Image.Image, torch.Tensor],
    bboxes: List[List[Union[int, float]]],
    *,
    idx2label: Optional[list] = None,
):
    """
    Draw labelled bounding boxes on a copy of an image.

    Args:
    - img (PIL Image or torch.Tensor): Image to annotate. A tensor with more than three
      dimensions is reduced to its 0-index element before being converted to a PIL image.
    - bboxes (List of Lists/Tensors): Batch of box lists; only `bboxes[0]` is drawn. Each box is
      [class_id, x_min, y_min, x_max, y_max], optionally followed by a confidence score.
      Coordinates are handed to the drawing calls unchanged (no rescaling happens here),
      so they must already be expressed in pixels of `img`.
    - idx2label (optional list): Maps an integer class id to the text shown in the label.
      When omitted or empty, the integer class id itself is shown.

    Returns:
    - PIL Image: The annotated copy. The image passed in is never drawn on.
    """
    # Convert tensor image to PIL Image if necessary
    if isinstance(img, torch.Tensor):
        if img.dim() > 3:
            logger.warning("🔍 >3 dimension tensor detected, using the 0-idx image.")
            img = img[0]
        img = to_pil_image(img)

    img = img.copy()
    if len(bboxes) == 0:
        # Empty batch: nothing to draw, and `bboxes[0]` below would raise IndexError.
        return img
    bboxes = bboxes[0]

    # Label font scales with the image height.
    label_size = img.size[1] / 30
    draw = ImageDraw.Draw(img, "RGBA")

    try:
        font = ImageFont.truetype("arial.ttf", label_size)
    except IOError:
        font = ImageFont.load_default(label_size)

    for bbox in bboxes:
        class_id, x_min, y_min, x_max, y_max, *conf = [float(val) for val in bbox]
        corners = [(x_min, y_min), (x_max, y_max)]

        # A private generator seeded with the class id yields the same colour for a class on
        # every call, without reseeding the global `random` module that callers may rely on.
        rng = random.Random(int(class_id))
        color_map = (rng.randint(0, 200), rng.randint(0, 200), rng.randint(0, 200))

        # Two passes: an opaque-ish outline, then a translucent fill over the same rectangle.
        draw.rounded_rectangle(corners, outline=(*color_map, 200), radius=5, width=2)
        draw.rounded_rectangle(corners, fill=(*color_map, 100), radius=5)

        class_text = str(idx2label[int(class_id)] if idx2label else int(class_id))
        label_text = f"{class_text}" + (f" {conf[0]: .0%}" if conf else "")

        text_bbox = font.getbbox(label_text)
        text_width = text_bbox[2] - text_bbox[0]
        # 1.5x the glyph height leaves padding below the text inside the label background.
        text_height = (text_bbox[3] - text_bbox[1]) * 1.5

        text_background = [(x_min, y_min), (x_min + text_width, y_min + text_height)]
        draw.rounded_rectangle(text_background, fill=(*color_map, 175), radius=2)
        draw.text((x_min, y_min), label_text, fill="white", font=font)

    return img


def draw_model(*, model_cfg=None, model=None, v7_base=False):
    """
    Render the layer graph of a model to `Model-arch.png` using graphviz.

    Args:
    - model_cfg: Model configuration. When truthy, a model is built from it with
      `create_model` and any `model` argument is ignored.
    - model: An already-built model; used only when `model_cfg` is falsy. Its `model`
      attribute is iterated, and each layer must expose `tags` and `source`.
    - v7_base (bool): When True, sub-graphs matching three hard-coded templates
      ("ELAN" twice and "MP") are collapsed into a single node. The templates are sampled
      from this model's own adjacency matrix at fixed layer positions.

    Raises:
    - ValueError: If neither `model_cfg` nor `model` is provided.

    Rendering is best effort: any failure inside `dot.render` is logged as a warning and
    the function returns normally.
    """
    # Validate before importing the optional dependency so bad arguments are always reported
    # as ValueError, whether or not graphviz is installed.
    if not model_cfg and model is None:
        raise ValueError("Drawing Object is None")

    from graphviz import Digraph

    if model_cfg:
        from yolo.model.yolo import create_model

        model = create_model(model_cfg)

    # Node 0 is the synthetic "INPUT" node; layer i of `model.model` is node i (1-based).
    model_size = len(model.model) + 1
    model_mat = np.zeros((model_size, model_size), dtype=bool)

    layer_name = ["INPUT"]
    for idx, layer in enumerate(model.model, start=1):
        class_name = type(layer).__name__
        if layer.tags is not None:
            class_name = f"{layer.tags}-{class_name}"
        layer_name.append(class_name)

        sources = [layer.source] if isinstance(layer.source, int) else layer.source
        for source in sources:
            # A negative source is relative to the current layer (-1 means the previous node).
            source = source + (source < 0) * idx
            model_mat[source, idx] = True

    pattern_mat = []
    if v7_base:
        # (display name, window size, first row of the template inside the adjacency matrix)
        pattern_list = [("ELAN", 8, 3), ("ELAN", 8, 55), ("MP", 5, 11)]
        for name, size, position in pattern_list:
            pattern_mat.append(
                (name, size, model_mat[position : position + size, position + 1 : position + 1 + size].copy())
            )

    dot = Digraph(comment="Model Flow Chart")

    for idx in range(model_size):
        # The upper bound keeps the widest (8-column) template window inside the matrix.
        for jdx in range(idx, model_size - 7):
            for name, size, pattern in pattern_mat:
                if (model_mat[idx : idx + size, jdx : jdx + size] == pattern).all():
                    # Collapse the matched window: rename the node, drop the window's internal
                    # edges, and link the node straight to the row index just past the window.
                    layer_name[idx] = name
                    model_mat[idx : idx + size, jdx : jdx + size] = False
                    model_mat[idx, idx + size] = True
        dot.node(str(idx), f"{layer_name[idx]}")
        for jdx in range(idx, model_size):
            if model_mat[idx, jdx]:
                dot.edge(str(idx), str(jdx))

    try:
        dot.render("Model-arch", format="png", cleanup=True)
        logger.info("🎨 Drawing Model Architecture at Model-arch.png")
    except Exception as error:
        # Deliberately best effort, but never swallow KeyboardInterrupt/SystemExit.
        logger.warning("⚠️ Could not find graphviz backend, continue without drawing the model architecture")
        logger.debug(f"Model architecture rendering failed: {error!r}")