henry000 commited on
Commit
0f9ffa2
·
1 Parent(s): 198ddce

✨ [Add] model logger to print model as table

Browse files
Files changed (2) hide show
  1. yolo/model/yolo.py +5 -0
  2. yolo/tools/log_helper.py +26 -0
yolo/model/yolo.py CHANGED
@@ -6,6 +6,7 @@ from omegaconf import ListConfig, OmegaConf
6
 
7
  from yolo.config.config import Config, Model, YOLOLayer
8
  from yolo.tools.layer_helper import get_layer_map
 
9
 
10
 
11
  class YOLO(nn.Module):
@@ -23,6 +24,7 @@ class YOLO(nn.Module):
23
  self.layer_map = get_layer_map() # Get the map Dict[str: Module]
24
  self.model: List[YOLOLayer] = nn.ModuleList()
25
  self.build_model(model_cfg.model)
 
26
 
27
  def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
28
  self.layer_index = {}
@@ -55,6 +57,7 @@ class YOLO(nn.Module):
55
 
56
  out_channels = self.get_out_channels(layer_type, layer_args, output_dim, source)
57
  output_dim.append(out_channels)
 
58
  layer_idx += 1
59
 
60
  def forward(self, x):
@@ -98,7 +101,9 @@ class YOLO(nn.Module):
98
  def create_layer(self, layer_type: str, source: Union[int, list], layer_info: Dict, **kwargs) -> YOLOLayer:
99
  if layer_type in self.layer_map:
100
  layer = self.layer_map[layer_type](**kwargs)
 
101
  setattr(layer, "source", source)
 
102
  setattr(layer, "output", layer_info.get("output", False))
103
  setattr(layer, "tags", layer_info.get("tags", None))
104
  return layer
 
6
 
7
  from yolo.config.config import Config, Model, YOLOLayer
8
  from yolo.tools.layer_helper import get_layer_map
9
+ from yolo.tools.log_helper import log_model
10
 
11
 
12
  class YOLO(nn.Module):
 
24
  self.layer_map = get_layer_map() # Get the map Dict[str: Module]
25
  self.model: List[YOLOLayer] = nn.ModuleList()
26
  self.build_model(model_cfg.model)
27
+ log_model(self.model)
28
 
29
  def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
30
  self.layer_index = {}
 
57
 
58
  out_channels = self.get_out_channels(layer_type, layer_args, output_dim, source)
59
  output_dim.append(out_channels)
60
+ setattr(layer, "out_c", out_channels)
61
  layer_idx += 1
62
 
63
  def forward(self, x):
 
101
  def create_layer(self, layer_type: str, source: Union[int, list], layer_info: Dict, **kwargs) -> YOLOLayer:
102
  if layer_type in self.layer_map:
103
  layer = self.layer_map[layer_type](**kwargs)
104
+ setattr(layer, "layer_type", layer_type)
105
  setattr(layer, "source", source)
106
+ setattr(layer, "in_c", kwargs.get("in_channels", None))
107
  setattr(layer, "output", layer_info.get("output", False))
108
  setattr(layer, "tags", layer_info.get("tags", None))
109
  return layer
yolo/tools/log_helper.py CHANGED
@@ -12,8 +12,13 @@ Example:
12
  """
13
 
14
  import sys
 
15
 
16
  from loguru import logger
 
 
 
 
17
 
18
 
19
  def custom_logger():
@@ -22,3 +27,24 @@ def custom_logger():
22
  sys.stderr,
23
  format="<green>{time:MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <level>{message}</level>",
24
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  """
13
 
14
  import sys
15
+ from typing import List
16
 
17
  from loguru import logger
18
+ from rich.console import Console
19
+ from rich.table import Table
20
+
21
+ from yolo.config.config import YOLOLayer
22
 
23
 
24
  def custom_logger():
 
27
  sys.stderr,
28
  format="<green>{time:MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <level>{message}</level>",
29
  )
30
+
31
+
32
+ def log_model(model: List[YOLOLayer]):
33
+ console = Console()
34
+ table = Table(title="Model Layers")
35
+
36
+ table.add_column("Index", justify="center")
37
+ table.add_column("Layer Type", justify="center")
38
+ table.add_column("Tags", justify="center")
39
+ table.add_column("Params", justify="right")
40
+ table.add_column("Channels (IN->OUT)", justify="center")
41
+
42
+ for idx, layer in enumerate(model, start=1):
43
+ layer_param = sum(x.numel() for x in layer.parameters()) # number parameters
44
+ in_channels, out_channels = getattr(layer, "in_c", None), getattr(layer, "out_c", None)
45
+ if in_channels and out_channels:
46
+ channels = f"{in_channels:4} -> {out_channels:4}"
47
+ else:
48
+ channels = "-"
49
+ table.add_row(str(idx), layer.layer_type, layer.tags, f"{layer_param:,}", channels)
50
+ console.print(table)