Merge branch 'MODELv2' into SETUP
Browse files- yolo/model/yolo.py +3 -8
yolo/model/yolo.py
CHANGED
@@ -82,17 +82,14 @@ class YOLO(nn.Module):
|
|
82 |
return output
|
83 |
|
84 |
def get_out_channels(self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list]):
|
85 |
-
|
86 |
-
if any(module in layer_type for module in ["Conv", "ELAN", "ADown", "AConv"]):
|
87 |
return layer_args["out_channels"]
|
88 |
if layer_type == "CBFuse":
|
89 |
return output_dim[source[-1]]
|
90 |
-
if
|
91 |
return output_dim[source]
|
92 |
-
if
|
93 |
return sum(output_dim[idx] for idx in source)
|
94 |
-
if layer_type == "IDetect":
|
95 |
-
return None
|
96 |
|
97 |
def get_source_idx(self, source: Union[ListConfig, str, int], layer_idx: int):
|
98 |
if isinstance(source, ListConfig):
|
@@ -128,7 +125,6 @@ def create_model(model_cfg: ModelConfig, weight_path: Union[bool, str] = True, c
|
|
128 |
Returns:
|
129 |
YOLO: An instance of the model defined by the given configuration.
|
130 |
"""
|
131 |
-
# TODO: "weight_path -> weight = [True|None-False|Path]: True should be default of model name?"
|
132 |
OmegaConf.set_struct(model_cfg, False)
|
133 |
model = YOLO(model_cfg, class_num)
|
134 |
if weight_path:
|
@@ -138,7 +134,6 @@ def create_model(model_cfg: ModelConfig, weight_path: Union[bool, str] = True, c
|
|
138 |
logger.info(f"π Weight {weight_path} not found, try downloading")
|
139 |
prepare_weight(weight_path=weight_path)
|
140 |
if os.path.exists(weight_path):
|
141 |
-
# TODO: fix map_location
|
142 |
model.model.load_state_dict(torch.load(weight_path, map_location=torch.device("cpu")), strict=False)
|
143 |
logger.info("β
Success load model & weight")
|
144 |
else:
|
|
|
82 |
return output
|
83 |
|
84 |
def get_out_channels(self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list]):
|
85 |
+
if hasattr(layer_args, "out_channels"):
|
|
|
86 |
return layer_args["out_channels"]
|
87 |
if layer_type == "CBFuse":
|
88 |
return output_dim[source[-1]]
|
89 |
+
if isinstance(source, int):
|
90 |
return output_dim[source]
|
91 |
+
if isinstance(source, list):
|
92 |
return sum(output_dim[idx] for idx in source)
|
|
|
|
|
93 |
|
94 |
def get_source_idx(self, source: Union[ListConfig, str, int], layer_idx: int):
|
95 |
if isinstance(source, ListConfig):
|
|
|
125 |
Returns:
|
126 |
YOLO: An instance of the model defined by the given configuration.
|
127 |
"""
|
|
|
128 |
OmegaConf.set_struct(model_cfg, False)
|
129 |
model = YOLO(model_cfg, class_num)
|
130 |
if weight_path:
|
|
|
134 |
logger.info(f"π Weight {weight_path} not found, try downloading")
|
135 |
prepare_weight(weight_path=weight_path)
|
136 |
if os.path.exists(weight_path):
|
|
|
137 |
model.model.load_state_dict(torch.load(weight_path, map_location=torch.device("cpu")), strict=False)
|
138 |
logger.info("β
Success load model & weight")
|
139 |
else:
|