henry000 commited on
Commit
cc70c05
Β·
2 Parent(s): 21a413f 2ae070c

Merge branch 'MODELv2' into SETUP

Browse files
Files changed (1) hide show
  1. yolo/model/yolo.py +3 -8
yolo/model/yolo.py CHANGED
@@ -82,17 +82,14 @@ class YOLO(nn.Module):
82
  return output
83
 
84
  def get_out_channels(self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list]):
85
- # TODO refactor, check out_channels in layer_args, next check source is list, CBFuse|Concat
86
- if any(module in layer_type for module in ["Conv", "ELAN", "ADown", "AConv"]):
87
  return layer_args["out_channels"]
88
  if layer_type == "CBFuse":
89
  return output_dim[source[-1]]
90
- if layer_type in ["Pool", "UpSample"]:
91
  return output_dim[source]
92
- if layer_type == "Concat":
93
  return sum(output_dim[idx] for idx in source)
94
- if layer_type == "IDetect":
95
- return None
96
 
97
  def get_source_idx(self, source: Union[ListConfig, str, int], layer_idx: int):
98
  if isinstance(source, ListConfig):
@@ -128,7 +125,6 @@ def create_model(model_cfg: ModelConfig, weight_path: Union[bool, str] = True, c
128
  Returns:
129
  YOLO: An instance of the model defined by the given configuration.
130
  """
131
- # TODO: "weight_path -> weight = [True|None-False|Path]: True should be default of model name?"
132
  OmegaConf.set_struct(model_cfg, False)
133
  model = YOLO(model_cfg, class_num)
134
  if weight_path:
@@ -138,7 +134,6 @@ def create_model(model_cfg: ModelConfig, weight_path: Union[bool, str] = True, c
138
  logger.info(f"🌐 Weight {weight_path} not found, try downloading")
139
  prepare_weight(weight_path=weight_path)
140
  if os.path.exists(weight_path):
141
- # TODO: fix map_location
142
  model.model.load_state_dict(torch.load(weight_path, map_location=torch.device("cpu")), strict=False)
143
  logger.info("βœ… Success load model & weight")
144
  else:
 
82
  return output
83
 
84
  def get_out_channels(self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list]):
85
+ if hasattr(layer_args, "out_channels"):
 
86
  return layer_args["out_channels"]
87
  if layer_type == "CBFuse":
88
  return output_dim[source[-1]]
89
+ if isinstance(source, int):
90
  return output_dim[source]
91
+ if isinstance(source, list):
92
  return sum(output_dim[idx] for idx in source)
 
 
93
 
94
  def get_source_idx(self, source: Union[ListConfig, str, int], layer_idx: int):
95
  if isinstance(source, ListConfig):
 
125
  Returns:
126
  YOLO: An instance of the model defined by the given configuration.
127
  """
 
128
  OmegaConf.set_struct(model_cfg, False)
129
  model = YOLO(model_cfg, class_num)
130
  if weight_path:
 
134
  logger.info(f"🌐 Weight {weight_path} not found, try downloading")
135
  prepare_weight(weight_path=weight_path)
136
  if os.path.exists(weight_path):
 
137
  model.model.load_state_dict(torch.load(weight_path, map_location=torch.device("cpu")), strict=False)
138
  logger.info("βœ… Success load model & weight")
139
  else: