✨ [Finish] model forward function
Browse files- model/module.py +1 -1
- model/yolo.py +25 -16
model/module.py
CHANGED
@@ -297,7 +297,7 @@ class CSPELAN(nn.Module):
|
|
297 |
|
298 |
|
299 |
class Concat(nn.Module):
|
300 |
-
def __init__(self, dim
|
301 |
super(Concat, self).__init__()
|
302 |
self.dim = dim
|
303 |
|
|
|
297 |
|
298 |
|
299 |
class Concat(nn.Module):
|
300 |
+
def __init__(self, dim=1):
|
301 |
super(Concat, self).__init__()
|
302 |
self.dim = dim
|
303 |
|
model/yolo.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
import torch.nn as nn
|
|
|
2 |
from loguru import logger
|
3 |
-
from typing import Dict, Any, List
|
4 |
import inspect
|
5 |
from utils.tools import load_model_cfg
|
6 |
-
|
7 |
from model import module
|
8 |
|
9 |
|
@@ -23,7 +23,6 @@ def get_layer_map():
|
|
23 |
class YOLO(nn.Module):
|
24 |
"""
|
25 |
A preliminary YOLO (You Only Look Once) model class still under development.
|
26 |
-
#TODO: Next: Finish forward proccess
|
27 |
|
28 |
Parameters:
|
29 |
model_cfg: Configuration for the YOLO model. Expected to define the layers,
|
@@ -33,10 +32,8 @@ class YOLO(nn.Module):
|
|
33 |
def __init__(self, model_cfg: Dict[str, Any]):
|
34 |
super(YOLO, self).__init__()
|
35 |
self.nc = model_cfg["nc"]
|
36 |
-
self.layer_map = get_layer_map() #
|
37 |
self.build_model(model_cfg["model"])
|
38 |
-
print(self.model)
|
39 |
-
# raise NotImplementedError("Constructor not implemented.")
|
40 |
|
41 |
def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
|
42 |
model_list = nn.ModuleList()
|
@@ -44,7 +41,7 @@ class YOLO(nn.Module):
|
|
44 |
layer_indices_by_tag = {}
|
45 |
|
46 |
for arch_name, arch in model_arch.items():
|
47 |
-
logger.info(f"Building model-{arch_name}")
|
48 |
for layer_idx, layer_spec in enumerate(arch, start=1):
|
49 |
layer_type, layer_info = next(iter(layer_spec.items()))
|
50 |
layer_args = layer_info.get("args", {})
|
@@ -56,8 +53,9 @@ class YOLO(nn.Module):
|
|
56 |
layer_args["in_channels"] = output_dim[source]
|
57 |
if "Detect" in layer_type:
|
58 |
layer_args["nc"] = self.nc
|
|
|
59 |
|
60 |
-
layer = self.create_layer(layer_type, **layer_args)
|
61 |
model_list.append(layer)
|
62 |
|
63 |
if "tags" in layer_info:
|
@@ -69,22 +67,32 @@ class YOLO(nn.Module):
|
|
69 |
output_dim.append(out_channels)
|
70 |
self.model = model_list
|
71 |
|
72 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
if "Conv" in layer_type:
|
74 |
return layer_args["out_channels"]
|
|
|
|
|
75 |
if layer_type == "Concat":
|
76 |
return sum(output_dim[idx] for idx in source)
|
77 |
-
if "Pool" in layer_type:
|
78 |
-
return output_dim[source] // 2
|
79 |
-
if layer_type == "UpSample":
|
80 |
-
return output_dim[source] * 2
|
81 |
if layer_type == "IDetect":
|
82 |
return None
|
83 |
|
84 |
-
def create_layer(self, layer_type: str, **kwargs):
|
85 |
-
# Dictionary mapping layer names to actual layer classes
|
86 |
if layer_type in self.layer_map:
|
87 |
-
|
|
|
|
|
88 |
else:
|
89 |
raise ValueError(f"Unsupported layer type: {layer_type}")
|
90 |
|
@@ -99,6 +107,7 @@ def get_model(model_cfg: dict) -> YOLO:
|
|
99 |
YOLO: An instance of the model defined by the given configuration.
|
100 |
"""
|
101 |
model = YOLO(model_cfg)
|
|
|
102 |
return model
|
103 |
|
104 |
|
|
|
1 |
import torch.nn as nn
|
2 |
+
import torch
|
3 |
from loguru import logger
|
4 |
+
from typing import Dict, Any, List, Union
|
5 |
import inspect
|
6 |
from utils.tools import load_model_cfg
|
|
|
7 |
from model import module
|
8 |
|
9 |
|
|
|
23 |
class YOLO(nn.Module):
|
24 |
"""
|
25 |
A preliminary YOLO (You Only Look Once) model class still under development.
|
|
|
26 |
|
27 |
Parameters:
|
28 |
model_cfg: Configuration for the YOLO model. Expected to define the layers,
|
|
|
32 |
def __init__(self, model_cfg: Dict[str, Any]):
|
33 |
super(YOLO, self).__init__()
|
34 |
self.nc = model_cfg["nc"]
|
35 |
+
self.layer_map = get_layer_map() # Get the map Dict[str: Module]
|
36 |
self.build_model(model_cfg["model"])
|
|
|
|
|
37 |
|
38 |
def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
|
39 |
model_list = nn.ModuleList()
|
|
|
41 |
layer_indices_by_tag = {}
|
42 |
|
43 |
for arch_name, arch in model_arch.items():
|
44 |
+
logger.info(f"🏗️ Building model-{arch_name}")
|
45 |
for layer_idx, layer_spec in enumerate(arch, start=1):
|
46 |
layer_type, layer_info = next(iter(layer_spec.items()))
|
47 |
layer_args = layer_info.get("args", {})
|
|
|
53 |
layer_args["in_channels"] = output_dim[source]
|
54 |
if "Detect" in layer_type:
|
55 |
layer_args["nc"] = self.nc
|
56 |
+
layer_args["ch"] = [output_dim[idx] for idx in source]
|
57 |
|
58 |
+
layer = self.create_layer(layer_type, source, **layer_args)
|
59 |
model_list.append(layer)
|
60 |
|
61 |
if "tags" in layer_info:
|
|
|
67 |
output_dim.append(out_channels)
|
68 |
self.model = model_list
|
69 |
|
70 |
+
def forward(self, x):
|
71 |
+
y = [x]
|
72 |
+
for layer in self.model:
|
73 |
+
if isinstance(layer.source, list):
|
74 |
+
model_input = [y[idx] for idx in layer.source]
|
75 |
+
else:
|
76 |
+
model_input = y[layer.source]
|
77 |
+
x = layer(model_input)
|
78 |
+
y.append(x)
|
79 |
+
return x
|
80 |
+
|
81 |
+
def get_out_channels(self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list]):
|
82 |
if "Conv" in layer_type:
|
83 |
return layer_args["out_channels"]
|
84 |
+
if layer_type in ["MaxPool", "UpSample"]:
|
85 |
+
return output_dim[source]
|
86 |
if layer_type == "Concat":
|
87 |
return sum(output_dim[idx] for idx in source)
|
|
|
|
|
|
|
|
|
88 |
if layer_type == "IDetect":
|
89 |
return None
|
90 |
|
91 |
+
def create_layer(self, layer_type: str, source: Union[int, list], **kwargs):
|
|
|
92 |
if layer_type in self.layer_map:
|
93 |
+
layer = self.layer_map[layer_type](**kwargs)
|
94 |
+
layer.source = source
|
95 |
+
return layer
|
96 |
else:
|
97 |
raise ValueError(f"Unsupported layer type: {layer_type}")
|
98 |
|
|
|
107 |
YOLO: An instance of the model defined by the given configuration.
|
108 |
"""
|
109 |
model = YOLO(model_cfg)
|
110 |
+
logger.info("✅ Success load model")
|
111 |
return model
|
112 |
|
113 |
|