File size: 3,370 Bytes
1197f7d
 
e94b3ff
1197f7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e94b3ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from typing import List, Union

import numpy as np
import torch
from loguru import logger
from PIL import Image, ImageDraw, ImageFont
from torchvision.transforms.functional import to_pil_image


def draw_bboxes(img: Union[Image.Image, torch.Tensor], bboxes: List[List[Union[int, float]]]):
    """
    Draw bounding boxes on an image and save the annotated result to ``visualize.jpg``.

    Args:
    - img (PIL Image or torch.Tensor): Image on which to draw the bounding boxes. A tensor is
      converted with ``to_pil_image``; if it has more than three dimensions, only its first
      entry (together with the first entry of ``bboxes``) is used. A PIL image is drawn on
      in place.
    - bboxes (List of Lists/Tensors): Bounding boxes with [class_id, x_min, y_min, x_max, y_max],
      where coordinates are normalized [0, 1].
    """
    output_path = "visualize.jpg"

    # Convert tensor image to PIL Image if necessary
    if isinstance(img, torch.Tensor):
        if img.dim() > 3:
            logger.info("Multi-frame tensor detected, using the first image.")
            img = img[0]
            bboxes = bboxes[0]
        img = to_pil_image(img)

    draw = ImageDraw.Draw(img)
    width, height = img.size
    try:
        font = ImageFont.load_default(30)
    except TypeError:
        # Older Pillow releases do not accept a size argument; fall back to the built-in font.
        font = ImageFont.load_default()

    for class_id, x_min, y_min, x_max, y_max in bboxes:
        # Scale the normalized corners to pixel coordinates of this image.
        top_left = (x_min * width, y_min * height)
        bottom_right = (x_max * width, y_max * height)
        draw.rectangle([top_left, bottom_right], outline="red", width=3)
        draw.text(top_left, str(int(class_id)), font=font, fill="blue")

    img.save(output_path)  # Save the image with annotations
    # Report the path that was actually written (the message used to name a .png file).
    logger.info(f"Saved visualize image at {output_path}")


def draw_model(*, model_cfg=None, model=None):
    """
    Render the layer connectivity of a model to ``Model-arch.png`` with graphviz.

    Args:
    - model_cfg: Configuration handed to ``yolo.model.yolo.get_model`` to build the model.
      When truthy it takes precedence over ``model``.
    - model: Already constructed model, used only when ``model_cfg`` is falsy. ``model.model``
      is iterated as the layer sequence and every layer must expose ``source``: an int or an
      iterable of ints, where a negative value is an offset from the layer's own index.

    Raises:
    - ValueError: If ``model_cfg`` is falsy and ``model`` is None.
    """
    from graphviz import Digraph

    if model_cfg:
        from yolo.model.yolo import get_model

        model = get_model(model_cfg)
    elif model is None:
        raise ValueError("Drawing Object is None")

    model_size = len(model.model)
    # model_mat[src, dst] is True when layer `dst` reads the output of layer `src`.
    model_mat = np.zeros((model_size, model_size), dtype=bool)

    layer_name = []
    for idx, layer in enumerate(model.model):
        layer_name.append(type(layer).__name__)
        sources = [layer.source] if isinstance(layer.source, int) else layer.source
        for source in sources:
            # Turn a negative (relative) source into an absolute layer index.
            source = source + (source < 0) * idx
            model_mat[source, idx] = True

    # (label, block size, layer index at which a reference instance of the block starts)
    pattern_list = [("ELAN", 8, 3), ("ELAN", 8, 55), ("MP", 5, 11)]
    pattern_mat = []
    for name, size, position in pattern_list:
        template = model_mat[position : position + size, position + 1 : position + 1 + size].copy()
        # The reference positions are hard-coded. On a model with fewer layers the slice is
        # truncated or empty and cannot be compared element-wise with a full size x size
        # window; a template without any edge would match every empty window. Skip both.
        if template.shape == (size, size) and template.any():
            pattern_mat.append((name, size, template))

    # Column offsets are scanned only while the largest block still fits inside the matrix.
    scan_limit = model_size - max(size for _, size, _ in pattern_list) + 1

    dot = Digraph(comment="Model Flow Chart")
    node_idx = 0

    for idx in range(model_size):
        for jdx in range(idx, scan_limit):
            for name, size, pattern in pattern_mat:
                if (model_mat[idx : idx + size, jdx : jdx + size] == pattern).all():
                    # Collapse the matched block into the single node `idx`: relabel it,
                    # drop the block's internal edges and link it to the layer that follows.
                    layer_name[idx] = name
                    model_mat[idx : idx + size, jdx : jdx + size] = False
                    if idx + size < model_size:
                        # A block ending at the last layer has no following layer to link to.
                        model_mat[idx, idx + size] = True

        # Only layers that still have an outgoing edge get a labelled, renumbered node.
        if model_mat[idx].any():
            dot.node(str(idx), f"{node_idx}-{layer_name[idx]}")
            node_idx += 1
        for jdx in range(idx, model_size):
            if model_mat[idx, jdx]:
                dot.edge(str(idx), str(jdx))

    dot.render("Model-arch", format="png", cleanup=True)
    logger.info("🎨 Drawing Model Architecture at Model-arch.png")