π [Fix] some bugs, fit the create_model, device
Browse files- yolo/lazy.py +1 -1
- yolo/model/yolo.py +3 -3
- yolo/tools/solver.py +1 -2
yolo/lazy.py
CHANGED
@@ -22,7 +22,7 @@ def main(cfg: Config):
|
|
22 |
dataloader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task)
|
23 |
device = torch.device(cfg.device)
|
24 |
if getattr(cfg.task, "fast_inference", False):
|
25 |
-
model = FastModelLoader(cfg).load_model()
|
26 |
device = torch.device(cfg.device)
|
27 |
else:
|
28 |
model = create_model(cfg.model, class_num=cfg.class_num, weight_path=cfg.weight, device=device)
|
|
|
22 |
dataloader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task)
|
23 |
device = torch.device(cfg.device)
|
24 |
if getattr(cfg.task, "fast_inference", False):
|
25 |
+
model = FastModelLoader(cfg, device).load_model()
|
26 |
device = torch.device(cfg.device)
|
27 |
else:
|
28 |
model = create_model(cfg.model, class_num=cfg.class_num, weight_path=cfg.weight, device=device)
|
yolo/model/yolo.py
CHANGED
@@ -43,7 +43,7 @@ class YOLO(nn.Module):
|
|
43 |
source = self.get_source_idx(layer_info.get("source", -1), layer_idx)
|
44 |
|
45 |
# Find in channels
|
46 |
-
if any(module in layer_type for module in ["Conv", "ELAN", "ADown", "CBLinear"]):
|
47 |
layer_args["in_channels"] = output_dim[source]
|
48 |
if "Detection" in layer_type:
|
49 |
layer_args["in_channels"] = [output_dim[idx] for idx in source]
|
@@ -81,7 +81,7 @@ class YOLO(nn.Module):
|
|
81 |
return output
|
82 |
|
83 |
def get_out_channels(self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list]):
|
84 |
-
if any(module in layer_type for module in ["Conv", "ELAN", "ADown"]):
|
85 |
return layer_args["out_channels"]
|
86 |
if layer_type == "CBFuse":
|
87 |
return output_dim[source[-1]]
|
@@ -134,7 +134,7 @@ def create_model(model_cfg: ModelConfig, weight_path: Optional[str], device: dev
|
|
134 |
logger.info(f"π Weight {weight_path} not found, try downloading")
|
135 |
prepare_weight(weight_path=weight_path)
|
136 |
if os.path.exists(weight_path):
|
137 |
-
model.model.load_state_dict(torch.load(weight_path, map_location=device))
|
138 |
logger.info("β
Success load model weight")
|
139 |
|
140 |
log_model_structure(model.model)
|
|
|
43 |
source = self.get_source_idx(layer_info.get("source", -1), layer_idx)
|
44 |
|
45 |
# Find in channels
|
46 |
+
if any(module in layer_type for module in ["Conv", "ELAN", "ADown", "AConv", "CBLinear"]):
|
47 |
layer_args["in_channels"] = output_dim[source]
|
48 |
if "Detection" in layer_type:
|
49 |
layer_args["in_channels"] = [output_dim[idx] for idx in source]
|
|
|
81 |
return output
|
82 |
|
83 |
def get_out_channels(self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list]):
|
84 |
+
if any(module in layer_type for module in ["Conv", "ELAN", "ADown", "AConv"]):
|
85 |
return layer_args["out_channels"]
|
86 |
if layer_type == "CBFuse":
|
87 |
return output_dim[source[-1]]
|
|
|
134 |
logger.info(f"π Weight {weight_path} not found, try downloading")
|
135 |
prepare_weight(weight_path=weight_path)
|
136 |
if os.path.exists(weight_path):
|
137 |
+
model.model.load_state_dict(torch.load(weight_path, map_location=device), strict=False)
|
138 |
logger.info("β
Success load model weight")
|
139 |
|
140 |
log_model_structure(model.model)
|
yolo/tools/solver.py
CHANGED
@@ -143,8 +143,7 @@ class ModelTester:
|
|
143 |
break
|
144 |
if not self.save_predict:
|
145 |
continue
|
146 |
-
|
147 |
-
if self.save_predict == False:
|
148 |
save_image_path = os.path.join(self.save_path, f"frame{idx:03d}.png")
|
149 |
img.save(save_image_path)
|
150 |
logger.info(f"πΎ Saved visualize image at {save_image_path}")
|
|
|
143 |
break
|
144 |
if not self.save_predict:
|
145 |
continue
|
146 |
+
if self.save_predict != False:
|
|
|
147 |
save_image_path = os.path.join(self.save_path, f"frame{idx:03d}.png")
|
148 |
img.save(save_image_path)
|
149 |
logger.info(f"πΎ Saved visualize image at {save_image_path}")
|