henry000 commited on
Commit
78e3679
Β·
1 Parent(s): 3a6d42f

πŸ› [Fix] some bugs, fit the create_model, device

Browse files
Files changed (3) hide show
  1. yolo/lazy.py +1 -1
  2. yolo/model/yolo.py +3 -3
  3. yolo/tools/solver.py +1 -2
yolo/lazy.py CHANGED
@@ -22,7 +22,7 @@ def main(cfg: Config):
22
  dataloader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task)
23
  device = torch.device(cfg.device)
24
  if getattr(cfg.task, "fast_inference", False):
25
- model = FastModelLoader(cfg).load_model()
26
  device = torch.device(cfg.device)
27
  else:
28
  model = create_model(cfg.model, class_num=cfg.class_num, weight_path=cfg.weight, device=device)
 
22
  dataloader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task)
23
  device = torch.device(cfg.device)
24
  if getattr(cfg.task, "fast_inference", False):
25
+ model = FastModelLoader(cfg, device).load_model()
26
  device = torch.device(cfg.device)
27
  else:
28
  model = create_model(cfg.model, class_num=cfg.class_num, weight_path=cfg.weight, device=device)
yolo/model/yolo.py CHANGED
@@ -43,7 +43,7 @@ class YOLO(nn.Module):
43
  source = self.get_source_idx(layer_info.get("source", -1), layer_idx)
44
 
45
  # Find in channels
46
- if any(module in layer_type for module in ["Conv", "ELAN", "ADown", "CBLinear"]):
47
  layer_args["in_channels"] = output_dim[source]
48
  if "Detection" in layer_type:
49
  layer_args["in_channels"] = [output_dim[idx] for idx in source]
@@ -81,7 +81,7 @@ class YOLO(nn.Module):
81
  return output
82
 
83
  def get_out_channels(self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list]):
84
- if any(module in layer_type for module in ["Conv", "ELAN", "ADown"]):
85
  return layer_args["out_channels"]
86
  if layer_type == "CBFuse":
87
  return output_dim[source[-1]]
@@ -134,7 +134,7 @@ def create_model(model_cfg: ModelConfig, weight_path: Optional[str], device: dev
134
  logger.info(f"🌐 Weight {weight_path} not found, try downloading")
135
  prepare_weight(weight_path=weight_path)
136
  if os.path.exists(weight_path):
137
- model.model.load_state_dict(torch.load(weight_path, map_location=device))
138
  logger.info("βœ… Success load model weight")
139
 
140
  log_model_structure(model.model)
 
43
  source = self.get_source_idx(layer_info.get("source", -1), layer_idx)
44
 
45
  # Find in channels
46
+ if any(module in layer_type for module in ["Conv", "ELAN", "ADown", "AConv", "CBLinear"]):
47
  layer_args["in_channels"] = output_dim[source]
48
  if "Detection" in layer_type:
49
  layer_args["in_channels"] = [output_dim[idx] for idx in source]
 
81
  return output
82
 
83
  def get_out_channels(self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list]):
84
+ if any(module in layer_type for module in ["Conv", "ELAN", "ADown", "AConv"]):
85
  return layer_args["out_channels"]
86
  if layer_type == "CBFuse":
87
  return output_dim[source[-1]]
 
134
  logger.info(f"🌐 Weight {weight_path} not found, try downloading")
135
  prepare_weight(weight_path=weight_path)
136
  if os.path.exists(weight_path):
137
+ model.model.load_state_dict(torch.load(weight_path, map_location=device), strict=False)
138
  logger.info("βœ… Success load model weight")
139
 
140
  log_model_structure(model.model)
yolo/tools/solver.py CHANGED
@@ -143,8 +143,7 @@ class ModelTester:
143
  break
144
  if not self.save_predict:
145
  continue
146
-
147
- if self.save_predict == False:
148
  save_image_path = os.path.join(self.save_path, f"frame{idx:03d}.png")
149
  img.save(save_image_path)
150
  logger.info(f"πŸ’Ύ Saved visualize image at {save_image_path}")
 
143
  break
144
  if not self.save_predict:
145
  continue
146
+ if self.save_predict != False:
 
147
  save_image_path = os.path.join(self.save_path, f"frame{idx:03d}.png")
148
  img.save(save_image_path)
149
  logger.info(f"πŸ’Ύ Saved visualize image at {save_image_path}")