henry000 commited on
Commit
542860e
·
1 Parent(s): bb5c520

✨ [Finish] model forward function

Browse files
Files changed (2) hide show
  1. model/module.py +1 -1
  2. model/yolo.py +25 -16
model/module.py CHANGED
@@ -297,7 +297,7 @@ class CSPELAN(nn.Module):
297
 
298
 
299
  class Concat(nn.Module):
300
- def __init__(self, dim=-1):
301
  super(Concat, self).__init__()
302
  self.dim = dim
303
 
 
297
 
298
 
299
  class Concat(nn.Module):
300
+ def __init__(self, dim=1):
301
  super(Concat, self).__init__()
302
  self.dim = dim
303
 
model/yolo.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch.nn as nn
 
2
  from loguru import logger
3
- from typing import Dict, Any, List
4
  import inspect
5
  from utils.tools import load_model_cfg
6
-
7
  from model import module
8
 
9
 
@@ -23,7 +23,6 @@ def get_layer_map():
23
  class YOLO(nn.Module):
24
  """
25
  A preliminary YOLO (You Only Look Once) model class still under development.
26
- #TODO: Next: Finish forward proccess
27
 
28
  Parameters:
29
  model_cfg: Configuration for the YOLO model. Expected to define the layers,
@@ -33,10 +32,8 @@ class YOLO(nn.Module):
33
  def __init__(self, model_cfg: Dict[str, Any]):
34
  super(YOLO, self).__init__()
35
  self.nc = model_cfg["nc"]
36
- self.layer_map = get_layer_map() # Dynamically get the mapping
37
  self.build_model(model_cfg["model"])
38
- print(self.model)
39
- # raise NotImplementedError("Constructor not implemented.")
40
 
41
  def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
42
  model_list = nn.ModuleList()
@@ -44,7 +41,7 @@ class YOLO(nn.Module):
44
  layer_indices_by_tag = {}
45
 
46
  for arch_name, arch in model_arch.items():
47
- logger.info(f"Building model-{arch_name}")
48
  for layer_idx, layer_spec in enumerate(arch, start=1):
49
  layer_type, layer_info = next(iter(layer_spec.items()))
50
  layer_args = layer_info.get("args", {})
@@ -56,8 +53,9 @@ class YOLO(nn.Module):
56
  layer_args["in_channels"] = output_dim[source]
57
  if "Detect" in layer_type:
58
  layer_args["nc"] = self.nc
 
59
 
60
- layer = self.create_layer(layer_type, **layer_args)
61
  model_list.append(layer)
62
 
63
  if "tags" in layer_info:
@@ -69,22 +67,32 @@ class YOLO(nn.Module):
69
  output_dim.append(out_channels)
70
  self.model = model_list
71
 
72
- def get_out_channels(self, layer_type, layer_args, output_dim, source):
 
 
 
 
 
 
 
 
 
 
 
73
  if "Conv" in layer_type:
74
  return layer_args["out_channels"]
 
 
75
  if layer_type == "Concat":
76
  return sum(output_dim[idx] for idx in source)
77
- if "Pool" in layer_type:
78
- return output_dim[source] // 2
79
- if layer_type == "UpSample":
80
- return output_dim[source] * 2
81
  if layer_type == "IDetect":
82
  return None
83
 
84
- def create_layer(self, layer_type: str, **kwargs):
85
- # Dictionary mapping layer names to actual layer classes
86
  if layer_type in self.layer_map:
87
- return self.layer_map[layer_type](**kwargs)
 
 
88
  else:
89
  raise ValueError(f"Unsupported layer type: {layer_type}")
90
 
@@ -99,6 +107,7 @@ def get_model(model_cfg: dict) -> YOLO:
99
  YOLO: An instance of the model defined by the given configuration.
100
  """
101
  model = YOLO(model_cfg)
 
102
  return model
103
 
104
 
 
1
  import torch.nn as nn
2
+ import torch
3
  from loguru import logger
4
+ from typing import Dict, Any, List, Union
5
  import inspect
6
  from utils.tools import load_model_cfg
 
7
  from model import module
8
 
9
 
 
23
  class YOLO(nn.Module):
24
  """
25
  A preliminary YOLO (You Only Look Once) model class still under development.
 
26
 
27
  Parameters:
28
  model_cfg: Configuration for the YOLO model. Expected to define the layers,
 
32
  def __init__(self, model_cfg: Dict[str, Any]):
33
  super(YOLO, self).__init__()
34
  self.nc = model_cfg["nc"]
35
+ self.layer_map = get_layer_map() # Get the map Dict[str: Module]
36
  self.build_model(model_cfg["model"])
 
 
37
 
38
  def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
39
  model_list = nn.ModuleList()
 
41
  layer_indices_by_tag = {}
42
 
43
  for arch_name, arch in model_arch.items():
44
+ logger.info(f"🏗️ Building model-{arch_name}")
45
  for layer_idx, layer_spec in enumerate(arch, start=1):
46
  layer_type, layer_info = next(iter(layer_spec.items()))
47
  layer_args = layer_info.get("args", {})
 
53
  layer_args["in_channels"] = output_dim[source]
54
  if "Detect" in layer_type:
55
  layer_args["nc"] = self.nc
56
+ layer_args["ch"] = [output_dim[idx] for idx in source]
57
 
58
+ layer = self.create_layer(layer_type, source, **layer_args)
59
  model_list.append(layer)
60
 
61
  if "tags" in layer_info:
 
67
  output_dim.append(out_channels)
68
  self.model = model_list
69
 
70
+ def forward(self, x):
71
+ y = [x]
72
+ for layer in self.model:
73
+ if isinstance(layer.source, list):
74
+ model_input = [y[idx] for idx in layer.source]
75
+ else:
76
+ model_input = y[layer.source]
77
+ x = layer(model_input)
78
+ y.append(x)
79
+ return x
80
+
81
+ def get_out_channels(self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list]):
82
  if "Conv" in layer_type:
83
  return layer_args["out_channels"]
84
+ if layer_type in ["MaxPool", "UpSample"]:
85
+ return output_dim[source]
86
  if layer_type == "Concat":
87
  return sum(output_dim[idx] for idx in source)
 
 
 
 
88
  if layer_type == "IDetect":
89
  return None
90
 
91
+ def create_layer(self, layer_type: str, source: Union[int, list], **kwargs):
 
92
  if layer_type in self.layer_map:
93
+ layer = self.layer_map[layer_type](**kwargs)
94
+ layer.source = source
95
+ return layer
96
  else:
97
  raise ValueError(f"Unsupported layer type: {layer_type}")
98
 
 
107
  YOLO: An instance of the model defined by the given configuration.
108
  """
109
  model = YOLO(model_cfg)
110
+ logger.info("✅ Success load model")
111
  return model
112
 
113