henry000 commited on
Commit
e94b3ff
·
1 Parent(s): b44d6bb

🎨 [Add] Visualization of YOLO model

Browse files
Files changed (2) hide show
  1. examples/example_train.py +2 -0
  2. yolo/utils/drawer.py +54 -0
examples/example_train.py CHANGED
@@ -13,6 +13,7 @@ from yolo.model.yolo import get_model
13
  from yolo.tools.log_helper import custom_logger
14
  from yolo.tools.trainer import Trainer
15
  from yolo.utils.dataloader import get_dataloader
 
16
  from yolo.utils.get_dataset import prepare_dataset
17
 
18
 
@@ -23,6 +24,7 @@ def main(cfg: Config):
23
 
24
  dataloader = get_dataloader(cfg)
25
  model = get_model(cfg.model)
 
26
  # TODO: get_device or rank, for DDP mode
27
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
 
 
13
  from yolo.tools.log_helper import custom_logger
14
  from yolo.tools.trainer import Trainer
15
  from yolo.utils.dataloader import get_dataloader
16
+ from yolo.utils.drawer import draw_model
17
  from yolo.utils.get_dataset import prepare_dataset
18
 
19
 
 
24
 
25
  dataloader = get_dataloader(cfg)
26
  model = get_model(cfg.model)
27
+ draw_model(model=model)
28
  # TODO: get_device or rank, for DDP mode
29
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
 
yolo/utils/drawer.py CHANGED
@@ -1,5 +1,6 @@
1
  from typing import List, Union
2
 
 
3
  import torch
4
  from loguru import logger
5
  from PIL import Image, ImageDraw, ImageFont
@@ -39,3 +40,56 @@ def draw_bboxes(img: Union[Image.Image, torch.Tensor], bboxes: List[List[Union[i
39
 
40
  img.save("visualize.jpg") # Save the image with annotations
41
  logger.info("Saved visualize image at visualize.png")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from typing import List, Union
2
 
3
+ import numpy as np
4
  import torch
5
  from loguru import logger
6
  from PIL import Image, ImageDraw, ImageFont
 
40
 
41
  img.save("visualize.jpg") # Save the image with annotations
42
  logger.info("Saved visualize image at visualize.png")
43
+
44
+
45
+ def draw_model(*, model_cfg=None, model=None):
46
+ from graphviz import Digraph
47
+
48
+ if model_cfg:
49
+ from yolo.model.yolo import get_model
50
+
51
+ model = get_model(model_cfg)
52
+ elif model is None:
53
+ raise ValueError("Drawing Object is None")
54
+
55
+ model_size = len(model.model)
56
+ model_mat = np.zeros((model_size, model_size), dtype=bool)
57
+
58
+ layer_name = []
59
+ for idx, layer in enumerate(model.model):
60
+ layer_name.append(str(type(layer)).split(".")[-1][:-2])
61
+ if isinstance(layer.source, int):
62
+ source = layer.source + (layer.source < 0) * idx
63
+ model_mat[source, idx] = True
64
+ else:
65
+ for source in layer.source:
66
+ source = source + (source < 0) * idx
67
+ model_mat[source, idx] = True
68
+
69
+ pattern_list = [("ELAN", 8, 3), ("ELAN", 8, 55), ("MP", 5, 11)]
70
+ pattern_mat = []
71
+ for name, size, position in pattern_list:
72
+ pattern_mat.append(
73
+ (name, size, model_mat[position : position + size, position + 1 : position + 1 + size].copy())
74
+ )
75
+
76
+ dot = Digraph(comment="Model Flow Chart")
77
+ node_idx = 0
78
+
79
+ for idx in range(model_size):
80
+ for jdx in range(idx, model_size - 7):
81
+ for name, size, pattern in pattern_mat:
82
+ if (model_mat[idx : idx + size, jdx : jdx + size] == pattern).all():
83
+ layer_name[idx] = name
84
+ model_mat[idx : idx + size, jdx : jdx + size] = False
85
+ model_mat[idx, idx + size] = True
86
+
87
+ if model_mat[idx].any():
88
+ dot.node(str(idx), f"{node_idx}-{layer_name[idx]}")
89
+ node_idx += 1
90
+ for jdx in range(idx, model_size):
91
+ if model_mat[idx, jdx] == 1:
92
+ dot.edge(str(idx), str(jdx))
93
+
94
+ dot.render("Model-arch", format="png", cleanup=True)
95
+ logger.info("🎨 Drawing Model Architecture at Model-arch.png")