henry000 commited on
Commit
1a069e1
Β·
2 Parent(s): c40db62 a3f8ecb

πŸ”€ [Merge] branch 'MODEL' into TEST

Browse files
yolo/config/config.py CHANGED
@@ -1,6 +1,8 @@
1
  from dataclasses import dataclass
2
  from typing import Dict, List, Union
3
 
 
 
4
 
5
  @dataclass
6
  class AnchorConfig:
@@ -100,6 +102,17 @@ class Download:
100
  datasets: Datasets
101
 
102
 
 
 
 
 
 
 
 
 
 
 
 
103
  @dataclass
104
  class Config:
105
  model: Model
 
1
  from dataclasses import dataclass
2
  from typing import Dict, List, Union
3
 
4
+ from torch import nn
5
+
6
 
7
  @dataclass
8
  class AnchorConfig:
 
102
  datasets: Datasets
103
 
104
 
105
+ @dataclass
106
+ class YOLOLayer(nn.Module):
107
+ source: Union[int, str, List[int]]
108
+ output: bool
109
+ tags: str
110
+ layer_type: str
111
+
112
+ def __post_init__(self):
113
+ super().__init__()
114
+
115
+
116
  @dataclass
117
  class Config:
118
  model: Model
yolo/model/module.py CHANGED
@@ -24,7 +24,7 @@ class Conv(nn.Module):
24
  ):
25
  super().__init__()
26
  kwargs.setdefault("padding", auto_pad(kernel_size, **kwargs))
27
- self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, **kwargs)
28
  self.bn = nn.BatchNorm2d(out_channels)
29
  self.act = get_activation(activation)
30
 
@@ -49,14 +49,16 @@ class Pool(nn.Module):
49
  class Detection(nn.Module):
50
  """A single YOLO Detection head for detection models"""
51
 
52
- def __init__(self, in_channels: int, num_classes: int, *, reg_max: int = 16, use_group: bool = True):
53
  super().__init__()
54
 
55
  groups = 4 if use_group else 1
56
  anchor_channels = 4 * reg_max
 
 
57
  # TODO: round up head[0] channels or each head?
58
- anchor_neck = max(round_up(in_channels // 4, groups), anchor_channels, 16)
59
- class_neck = max(in_channels, min(num_classes * 2, 128))
60
 
61
  self.anchor_conv = nn.Sequential(
62
  Conv(in_channels, anchor_neck, 3),
@@ -78,8 +80,12 @@ class MultiheadDetection(nn.Module):
78
 
79
  def __init__(self, in_channels: List[int], num_classes: int, **head_kwargs):
80
  super().__init__()
 
81
  self.heads = nn.ModuleList(
82
- [Detection(head_in_channels, num_classes, **head_kwargs) for head_in_channels in in_channels]
 
 
 
83
  )
84
 
85
  def forward(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]:
@@ -118,7 +124,7 @@ class RepNBottleneck(nn.Module):
118
  *,
119
  kernel_size: Tuple[int, int] = (3, 3),
120
  residual: bool = True,
121
- expand: float = 0.5,
122
  **kwargs
123
  ):
124
  super().__init__()
 
24
  ):
25
  super().__init__()
26
  kwargs.setdefault("padding", auto_pad(kernel_size, **kwargs))
27
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, bias=False, **kwargs)
28
  self.bn = nn.BatchNorm2d(out_channels)
29
  self.act = get_activation(activation)
30
 
 
49
  class Detection(nn.Module):
50
  """A single YOLO Detection head for detection models"""
51
 
52
+ def __init__(self, in_channels: Tuple[int], num_classes: int, *, reg_max: int = 16, use_group: bool = True):
53
  super().__init__()
54
 
55
  groups = 4 if use_group else 1
56
  anchor_channels = 4 * reg_max
57
+
58
+ first_neck, in_channels = in_channels
59
  # TODO: round up head[0] channels or each head?
60
+ anchor_neck = max(round_up(first_neck // 4, groups), anchor_channels, 16)
61
+ class_neck = max(first_neck, min(num_classes * 2, 128))
62
 
63
  self.anchor_conv = nn.Sequential(
64
  Conv(in_channels, anchor_neck, 3),
 
80
 
81
  def __init__(self, in_channels: List[int], num_classes: int, **head_kwargs):
82
  super().__init__()
83
+ # TODO: Refactor these parts
84
  self.heads = nn.ModuleList(
85
+ [
86
+ Detection((in_channels[3 * (idx // 3)], in_channel), num_classes, **head_kwargs)
87
+ for idx, in_channel in enumerate(in_channels)
88
+ ]
89
  )
90
 
91
  def forward(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]:
 
124
  *,
125
  kernel_size: Tuple[int, int] = (3, 3),
126
  residual: bool = True,
127
+ expand: float = 1.0,
128
  **kwargs
129
  ):
130
  super().__init__()
yolo/model/yolo.py CHANGED
@@ -4,8 +4,9 @@ import torch.nn as nn
4
  from loguru import logger
5
  from omegaconf import ListConfig, OmegaConf
6
 
7
- from yolo.config.config import Config, Model
8
  from yolo.tools.layer_helper import get_layer_map
 
9
 
10
 
11
  class YOLO(nn.Module):
@@ -21,13 +22,13 @@ class YOLO(nn.Module):
21
  super(YOLO, self).__init__()
22
  self.num_classes = num_classes
23
  self.layer_map = get_layer_map() # Get the map Dict[str: Module]
 
24
  self.build_model(model_cfg.model)
 
25
 
26
  def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
27
- model_list = nn.ModuleList()
28
- output_dim = [3]
29
- layer_indices_by_tag = {}
30
- layer_idx = 1
31
  logger.info(f"🚜 Building YOLO")
32
  for arch_name in model_arch:
33
  logger.info(f" πŸ—οΈ Building {arch_name}")
@@ -36,11 +37,7 @@ class YOLO(nn.Module):
36
  layer_args = layer_info.get("args", {})
37
 
38
  # Get input source
39
- source = layer_info.get("source", -1)
40
- if isinstance(source, str):
41
- source = layer_indices_by_tag[source]
42
- elif isinstance(source, ListConfig):
43
- source = [layer_indices_by_tag[idx] if isinstance(idx, str) else idx for idx in source]
44
 
45
  # Find in channels
46
  if any(module in layer_type for module in ["Conv", "ELAN", "ADown", "CBLinear"]):
@@ -51,29 +48,29 @@ class YOLO(nn.Module):
51
 
52
  # create layers
53
  layer = self.create_layer(layer_type, source, layer_info, **layer_args)
54
- model_list.append(layer)
55
 
56
- if "tags" in layer_info:
57
- if layer_info["tags"] in layer_indices_by_tag:
58
  raise ValueError(f"Duplicate tag '{layer_info['tags']}' found.")
59
- layer_indices_by_tag[layer_info["tags"]] = layer_idx
60
 
61
  out_channels = self.get_out_channels(layer_type, layer_args, output_dim, source)
62
  output_dim.append(out_channels)
 
63
  layer_idx += 1
64
 
65
- self.model = model_list
66
-
67
  def forward(self, x):
68
- y = [x]
69
  output = []
70
- for layer in self.model:
71
  if isinstance(layer.source, list):
72
  model_input = [y[idx] for idx in layer.source]
73
  else:
74
  model_input = y[layer.source]
75
  x = layer(model_input)
76
- y.append(x)
 
77
  if layer.output:
78
  output.append(x)
79
  return output
@@ -90,10 +87,23 @@ class YOLO(nn.Module):
90
  if layer_type == "IDetect":
91
  return None
92
 
93
- def create_layer(self, layer_type: str, source: Union[int, list], layer_info, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
94
  if layer_type in self.layer_map:
95
  layer = self.layer_map[layer_type](**kwargs)
 
96
  setattr(layer, "source", source)
 
97
  setattr(layer, "output", layer_info.get("output", False))
98
  setattr(layer, "tags", layer_info.get("tags", None))
99
  return layer
 
4
  from loguru import logger
5
  from omegaconf import ListConfig, OmegaConf
6
 
7
+ from yolo.config.config import Config, Model, YOLOLayer
8
  from yolo.tools.layer_helper import get_layer_map
9
+ from yolo.tools.log_helper import log_model
10
 
11
 
12
  class YOLO(nn.Module):
 
22
  super(YOLO, self).__init__()
23
  self.num_classes = num_classes
24
  self.layer_map = get_layer_map() # Get the map Dict[str: Module]
25
+ self.model: List[YOLOLayer] = nn.ModuleList()
26
  self.build_model(model_cfg.model)
27
+ log_model(self.model)
28
 
29
  def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
30
+ self.layer_index = {}
31
+ output_dim, layer_idx = [3], 1
 
 
32
  logger.info(f"🚜 Building YOLO")
33
  for arch_name in model_arch:
34
  logger.info(f" πŸ—οΈ Building {arch_name}")
 
37
  layer_args = layer_info.get("args", {})
38
 
39
  # Get input source
40
+ source = self.get_source_idx(layer_info.get("source", -1), layer_idx)
 
 
 
 
41
 
42
  # Find in channels
43
  if any(module in layer_type for module in ["Conv", "ELAN", "ADown", "CBLinear"]):
 
48
 
49
  # create layers
50
  layer = self.create_layer(layer_type, source, layer_info, **layer_args)
51
+ self.model.append(layer)
52
 
53
+ if layer.tags:
54
+ if layer.tags in self.layer_index:
55
  raise ValueError(f"Duplicate tag '{layer_info['tags']}' found.")
56
+ self.layer_index[layer.tags] = layer_idx
57
 
58
  out_channels = self.get_out_channels(layer_type, layer_args, output_dim, source)
59
  output_dim.append(out_channels)
60
+ setattr(layer, "out_c", out_channels)
61
  layer_idx += 1
62
 
 
 
63
  def forward(self, x):
64
+ y = {0: x}
65
  output = []
66
+ for index, layer in enumerate(self.model, start=1):
67
  if isinstance(layer.source, list):
68
  model_input = [y[idx] for idx in layer.source]
69
  else:
70
  model_input = y[layer.source]
71
  x = layer(model_input)
72
+ if hasattr(layer, "save"):
73
+ y[index] = x
74
  if layer.output:
75
  output.append(x)
76
  return output
 
87
  if layer_type == "IDetect":
88
  return None
89
 
90
+ def get_source_idx(self, source: Union[ListConfig, str, int], layer_idx: int):
91
+ if isinstance(source, ListConfig):
92
+ return [self.get_source_idx(index, layer_idx) for index in source]
93
+ if isinstance(source, str):
94
+ source = self.layer_index[source]
95
+ if source < 0:
96
+ source += layer_idx
97
+ if source > 0:
98
+ setattr(self.model[source - 1], "save", True)
99
+ return source
100
+
101
+ def create_layer(self, layer_type: str, source: Union[int, list], layer_info: Dict, **kwargs) -> YOLOLayer:
102
  if layer_type in self.layer_map:
103
  layer = self.layer_map[layer_type](**kwargs)
104
+ setattr(layer, "layer_type", layer_type)
105
  setattr(layer, "source", source)
106
+ setattr(layer, "in_c", kwargs.get("in_channels", None))
107
  setattr(layer, "output", layer_info.get("output", False))
108
  setattr(layer, "tags", layer_info.get("tags", None))
109
  return layer
yolo/tools/log_helper.py CHANGED
@@ -12,8 +12,13 @@ Example:
12
  """
13
 
14
  import sys
 
15
 
16
  from loguru import logger
 
 
 
 
17
 
18
 
19
  def custom_logger():
@@ -22,3 +27,24 @@ def custom_logger():
22
  sys.stderr,
23
  format="<green>{time:MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <level>{message}</level>",
24
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  """
13
 
14
  import sys
15
+ from typing import List
16
 
17
  from loguru import logger
18
+ from rich.console import Console
19
+ from rich.table import Table
20
+
21
+ from yolo.config.config import YOLOLayer
22
 
23
 
24
  def custom_logger():
 
27
  sys.stderr,
28
  format="<green>{time:MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <level>{message}</level>",
29
  )
30
+
31
+
32
+ def log_model(model: List[YOLOLayer]):
33
+ console = Console()
34
+ table = Table(title="Model Layers")
35
+
36
+ table.add_column("Index", justify="center")
37
+ table.add_column("Layer Type", justify="center")
38
+ table.add_column("Tags", justify="center")
39
+ table.add_column("Params", justify="right")
40
+ table.add_column("Channels (IN->OUT)", justify="center")
41
+
42
+ for idx, layer in enumerate(model, start=1):
43
+ layer_param = sum(x.numel() for x in layer.parameters()) # number parameters
44
+ in_channels, out_channels = getattr(layer, "in_c", None), getattr(layer, "out_c", None)
45
+ if in_channels and out_channels:
46
+ channels = f"{in_channels:4} -> {out_channels:4}"
47
+ else:
48
+ channels = "-"
49
+ table.add_row(str(idx), layer.layer_type, layer.tags, f"{layer_param:,}", channels)
50
+ console.print(table)
yolo/tools/module_helper.py CHANGED
@@ -31,7 +31,7 @@ def get_activation(activation: str) -> nn.Module:
31
  if isinstance(obj, type) and issubclass(obj, nn.Module)
32
  }
33
  if activation.lower() in activation_map:
34
- return activation_map[activation.lower()]()
35
  else:
36
  raise ValueError(f"Activation function '{activation}' is not found in torch.nn")
37
 
 
31
  if isinstance(obj, type) and issubclass(obj, nn.Module)
32
  }
33
  if activation.lower() in activation_map:
34
+ return activation_map[activation.lower()](inplace=True)
35
  else:
36
  raise ValueError(f"Activation function '{activation}' is not found in torch.nn")
37