henry000 commited on
Commit
2144126
Β·
2 Parent(s): 2cd6792 868c821

πŸ”€ [Merge] branch 'main' into SETUP

Browse files
Files changed (2) hide show
  1. yolo/model/yolo.py +34 -1
  2. yolo/utils/dataset_utils.py +4 -1
yolo/model/yolo.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from pathlib import Path
2
  from typing import Dict, List, Union
3
 
@@ -114,6 +115,36 @@ class YOLO(nn.Module):
114
  else:
115
  raise ValueError(f"Unsupported layer type: {layer_type}")
116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
  def create_model(model_cfg: ModelConfig, weight_path: Union[bool, Path] = True, class_num: int = 80) -> YOLO:
119
  """Constructs and returns a model from a Dictionary configuration file.
@@ -129,11 +160,13 @@ def create_model(model_cfg: ModelConfig, weight_path: Union[bool, Path] = True,
129
  if weight_path:
130
  if weight_path == True:
131
  weight_path = Path("weights") / f"{model_cfg.name}.pt"
 
 
132
  if not weight_path.exists():
133
  logger.info(f"🌐 Weight {weight_path} not found, try downloading")
134
  prepare_weight(weight_path=weight_path)
135
  if weight_path.exists():
136
- model.model.load_state_dict(torch.load(weight_path, map_location=torch.device("cpu")), strict=False)
137
  logger.info("βœ… Success load model & weight")
138
  else:
139
  logger.info("βœ… Success load model")
 
1
+ from collections import OrderedDict
2
  from pathlib import Path
3
  from typing import Dict, List, Union
4
 
 
115
  else:
116
  raise ValueError(f"Unsupported layer type: {layer_type}")
117
 
118
+ def save_load_weights(self, weights: Union[Path, OrderedDict]):
119
+ """
120
+ Update the model's weights with the provided weights.
121
+
122
+ args:
123
+ weights: A OrderedDict containing the new weights.
124
+ """
125
+ if isinstance(weights, Path):
126
+ weights = torch.load(weights, map_location=torch.device("cpu"))
127
+ model_state_dict = self.model.state_dict()
128
+
129
+ # TODO1: autoload old version weight
130
+ # TODO2: weight transform if num_class difference
131
+
132
+ error_dict = {"Mismatch": set(), "Not Found": set()}
133
+ for model_key, model_weight in model_state_dict.items():
134
+ if model_key not in weights:
135
+ error_dict["Not Found"].add(tuple(model_key.split(".")[:-2]))
136
+ continue
137
+ if model_weight.shape != weights[model_key].shape:
138
+ error_dict["Mismatch"].add(tuple(model_key.split(".")[:-2]))
139
+ continue
140
+ model_state_dict[model_key] = weights[model_key]
141
+
142
+ for error_name, error_set in error_dict.items():
143
+ for weight_name in error_set:
144
+ logger.warning(f"⚠️ Weight {error_name} for key: {'.'.join(weight_name)}")
145
+
146
+ self.model.load_state_dict(model_state_dict)
147
+
148
 
149
  def create_model(model_cfg: ModelConfig, weight_path: Union[bool, Path] = True, class_num: int = 80) -> YOLO:
150
  """Constructs and returns a model from a Dictionary configuration file.
 
160
  if weight_path:
161
  if weight_path == True:
162
  weight_path = Path("weights") / f"{model_cfg.name}.pt"
163
+ if isinstance(weight_path, str):
164
+ weight_path = Path(weight_path)
165
  if not weight_path.exists():
166
  logger.info(f"🌐 Weight {weight_path} not found, try downloading")
167
  prepare_weight(weight_path=weight_path)
168
  if weight_path.exists():
169
+ model.save_load_weights(weight_path)
170
  logger.info("βœ… Success load model & weight")
171
  else:
172
  logger.info("βœ… Success load model")
yolo/utils/dataset_utils.py CHANGED
@@ -100,7 +100,10 @@ def scale_segmentation(
100
  h, w = image_dimensions["height"], image_dimensions["width"]
101
  for anno in annotations:
102
  category_id = anno["category_id"]
103
- seg_list = [item for sublist in anno["segmentation"] for item in sublist]
 
 
 
104
  scaled_seg_data = (
105
  np.array(seg_list).reshape(-1, 2) / [w, h]
106
  ).tolist() # make the list group in x, y pairs and scaled with image width, height
 
100
  h, w = image_dimensions["height"], image_dimensions["width"]
101
  for anno in annotations:
102
  category_id = anno["category_id"]
103
+ if "segmentation" in anno:
104
+ seg_list = [item for sublist in anno["segmentation"] for item in sublist]
105
+ elif "bbox" in anno:
106
+ seg_list = anno["bbox"]
107
  scaled_seg_data = (
108
  np.array(seg_list).reshape(-1, 2) / [w, h]
109
  ).tolist() # make the list group in x, y pairs and scaled with image width, height