henry000 commited on
Commit
6972568
·
1 Parent(s): 868c821

🦺 [Update] load weight, can load trained weight

Browse files
Files changed (1) hide show
  1. yolo/model/yolo.py +3 -0
yolo/model/yolo.py CHANGED
@@ -124,6 +124,9 @@ class YOLO(nn.Module):
124
  """
125
  if isinstance(weights, Path):
126
  weights = torch.load(weights, map_location=torch.device("cpu"))
 
 
 
127
  model_state_dict = self.model.state_dict()
128
 
129
  # TODO1: autoload old version weight
 
124
  """
125
  if isinstance(weights, Path):
126
  weights = torch.load(weights, map_location=torch.device("cpu"))
127
+ if "model_state_dict" in weights:
128
+ weights = weights["model_state_dict"]
129
+
130
  model_state_dict = self.model.state_dict()
131
 
132
  # TODO1: autoload old version weight