henry000 committed
Commit f95a3d7 · 1 Parent(s): 2dd2ae5

🐛 [Update] Fix some bugs and variable names in Vec2Box

yolo/__init__.py CHANGED

@@ -3,7 +3,7 @@ from yolo.model.yolo import create_model
 from yolo.tools.data_loader import AugmentationComposer, create_dataloader
 from yolo.tools.drawer import draw_bboxes
 from yolo.tools.solver import ModelTester, ModelTrainer, ModelValidator
-from yolo.utils.bounding_box_utils import bbox_nms
+from yolo.utils.bounding_box_utils import Vec2Box, bbox_nms
 from yolo.utils.deploy_utils import FastModelLoader
 from yolo.utils.logging_utils import custom_logger
 
@@ -13,6 +13,7 @@ all = [
     "custom_logger",
     "validate_log_directory",
     "draw_bboxes",
+    "Vec2Box",
     "bbox_nms",
     "AugmentationComposer",
     "create_dataloader",
yolo/config/model/v9-c.yaml CHANGED

@@ -68,6 +68,14 @@ model:
         args: {out_channels: 512, part_channels: 512}
         tags: P5
 
+  detection:
+    - MultiheadDetection:
+        source: [P3, P4, P5]
+        tags: Main
+        args:
+          reg_max: ${model.anchor.reg_max}
+        output: True
+
   auxiliary:
     - CBLinear:
         source: B3
@@ -123,11 +131,3 @@ model:
         args:
           reg_max: ${model.anchor.reg_max}
         output: True
-
-  detection:
-    - MultiheadDetection:
-        source: [P3, P4, P5]
-        tags: Main
-        args:
-          reg_max: ${model.anchor.reg_max}
-        output: True
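The config change moves the Main detection head declaration ahead of the auxiliary branch instead of leaving it at the end of the file; the entries themselves are unchanged. A rough sketch of inspecting the reordered file with OmegaConf (the access paths below are assumptions based on the YAML shown above; in the project the config may also be composed through Hydra):

    from omegaconf import OmegaConf

    # Load the model definition; the path matches the file shown above.
    cfg = OmegaConf.load("yolo/config/model/v9-c.yaml")

    # After this commit "detection" is declared before "auxiliary",
    # so it shows up earlier when iterating the mapping.
    print(list(cfg.model.keys()))

    # reg_max stays an OmegaConf interpolation until it is resolved against
    # the anchor section; it is left unresolved here on purpose.
    head = cfg.model.detection[0]["MultiheadDetection"]
    print(OmegaConf.to_container(head, resolve=False))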
yolo/tools/format_converters.py CHANGED

@@ -1,12 +1,13 @@
 def convert_weight(old_state_dict, new_state_dict, model_size: int = 38):
     # TODO: need to refactor
+    shift = 1
     for idx in range(model_size):
         new_list, old_list = [], []
         for weight_name, weight_value in new_state_dict.items():
             if weight_name.split(".")[0] == str(idx):
                 new_list.append((weight_name, None))
         for weight_name, weight_value in old_state_dict.items():
-            if f"model.{idx+1}." in weight_name:
+            if f"model.{idx+shift}." in weight_name:
                 old_list.append((weight_name, weight_value))
         if len(new_list) == len(old_list):
             for (weight_name, _), (_, weight_value) in zip(new_list, old_list):
@@ -17,7 +18,8 @@ def convert_weight(old_state_dict, new_state_dict, model_size: int = 38):
             continue
         _, _, conv_name, conv_idx, *details = weight_name.split(".")
         if conv_name == "cv4" or conv_name == "cv5":
-            layer_idx = 38
+            layer_idx = 22
+            shift = 2
         else:
             layer_idx = 37
 
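To make the renumbering concrete, here is an illustrative sketch of the new-to-old index mapping the shift variable produces (the key prefixes are made up for the example, not taken from a real checkpoint): each new layer index idx is matched against old keys prefixed model.{idx+shift}., and the diff sets shift = 2 once a cv4/cv5 detection-head weight is encountered, so later lookups use the larger offset.

    # Illustrative only: how shift changes the new->old layer-index mapping.
    def old_key_prefix(idx: int, shift: int) -> str:
        return f"model.{idx + shift}."

    shift = 1
    for idx in range(3):
        print(idx, "->", old_key_prefix(idx, shift))   # model.1., model.2., model.3.

    # After the cv4/cv5 head weights are encountered, the old checkpoint's
    # numbering is offset by one extra position, hence shift = 2 afterwards.
    shift = 2
    for idx in range(3, 6):
        print(idx, "->", old_key_prefix(idx, shift))   # model.5., model.6., model.7.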
yolo/tools/solver.py CHANGED

@@ -32,7 +32,7 @@ class ModelTrainer:
         self.num_epochs = cfg.task.epoch
 
         self.validation_dataloader = create_dataloader(cfg.task.validation.data, cfg.dataset, cfg.task.validation.task)
-        self.validator = ModelValidator(cfg.task.validation, model, save_path, device, self.progress)
+        self.validator = ModelValidator(cfg.task.validation, model, vec2box, save_path, device, self.progress)
 
         if getattr(train_cfg.ema, "enabled", False):
             self.ema = ExponentialMovingAverage(model, decay=train_cfg.ema.decay)
@@ -40,14 +40,14 @@ class ModelTrainer:
             self.ema = None
         self.scaler = GradScaler()
 
-    def train_one_batch(self, data: Tensor, targets: Tensor):
-        data, targets = data.to(self.device), targets.to(self.device)
+    def train_one_batch(self, images: Tensor, targets: Tensor):
+        images, targets = images.to(self.device), targets.to(self.device)
         self.optimizer.zero_grad()
 
         with autocast():
-            outputs = self.model(data)
-            aux_predicts = self.vec2box(outputs["AUX"])
-            main_predicts = self.vec2box(outputs["Main"])
+            predicts = self.model(images)
+            aux_predicts = self.vec2box(predicts["AUX"])
+            main_predicts = self.vec2box(predicts["Main"])
             loss, loss_item = self.loss_fn(aux_predicts, main_predicts, targets)
 
         self.scaler.scale(loss).backward()
@@ -60,8 +60,8 @@ class ModelTrainer:
         self.model.train()
         total_loss = 0
 
-        for data, targets in dataloader:
-            loss, loss_each = self.train_one_batch(data, targets)
+        for images, targets in dataloader:
+            loss, loss_each = self.train_one_batch(images, targets)
 
             total_loss += loss
             self.progress.one_batch(loss_each)
@@ -111,14 +111,15 @@ class ModelTester:
 
     def solve(self, dataloader: StreamDataLoader):
         logger.info("👀 Start Inference!")
-
+        if isinstance(self.model, torch.nn.Module):
+            self.model.eval()
         try:
             for idx, images in enumerate(dataloader):
                 images = images.to(self.device)
                 with torch.no_grad():
-                    outputs = self.model(images)
-                    outputs = self.vec2box(outputs["Main"])
-                    nms_out = bbox_nms(outputs[0], outputs[2], self.nms)
+                    predicts = self.model(images)
+                    predicts = self.vec2box(predicts["Main"])
+                    nms_out = bbox_nms(predicts[0], predicts[2], self.nms)
                 draw_bboxes(
                     images[0],
                     nms_out[0],
@@ -141,15 +142,18 @@ class ModelValidator:
         self,
         validation_cfg: ValidationConfig,
         model: YOLO,
+        vec2box: Vec2Box,
         save_path: str,
         device,
         # TODO: think Progress?
         progress: ProgressTracker,
     ):
         self.model = model
+        self.vec2box = vec2box
         self.device = device
         self.progress = progress
         self.save_path = save_path
+
         self.nms = validation_cfg.nms
 
     def solve(self, dataloader):
@@ -159,11 +163,12 @@ class ModelValidator:
         iou_thresholds = torch.arange(0.5, 1.0, 0.05)
         map_all = []
         self.progress.start_one_epoch(len(dataloader))
-        for data, targets in dataloader:
-            data, targets = data.to(self.device), targets.to(self.device)
+        for images, targets in dataloader:
+            images, targets = images.to(self.device), targets.to(self.device)
             with torch.no_grad():
-                raw_output = self.model(data)
-                nms_out = bbox_nms(raw_output[-1][0], self.nms)
+                predicts = self.model(images)
+                predicts = self.vec2box(predicts["Main"])
+                nms_out = bbox_nms(predicts[0], predicts[2], self.nms)
             for idx, predict in enumerate(nms_out):
                 map_value = calculate_map(predict, targets[idx], iou_thresholds)
                 map_all.append(map_value[0])
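
Taken together, the trainer, tester, and validator now share one decode-then-NMS path through Vec2Box. A minimal usage sketch of that pattern (the surrounding objects, model, images, image_size, device, and the NMS config, are assumed to exist; only the call sequence mirrors the code above):

    import torch
    from yolo.utils.bounding_box_utils import Vec2Box, bbox_nms

    # Built once; its constructor runs a dummy forward pass to size the anchors.
    vec2box = Vec2Box(model, image_size, device)

    with torch.no_grad():
        predicts = model(images)              # raw multi-head outputs
        predicts = vec2box(predicts["Main"])  # decode the Main head
        # The first and third elements of the decoded output go to NMS,
        # exactly as in ModelTester.solve and ModelValidator.solve above.
        nms_out = bbox_nms(predicts[0], predicts[2], nms_cfg)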
yolo/utils/bounding_box_utils.py CHANGED

@@ -4,6 +4,7 @@ from typing import List, Tuple
 import torch
 import torch.nn.functional as F
 from einops import rearrange
+from loguru import logger
 from torch import Tensor
 from torchvision.ops import batched_nms
 
@@ -264,6 +265,7 @@ class BoxMatcher:
 
 class Vec2Box:
     def __init__(self, model, image_size, device):
+        logger.info("🧸 Make a dummy test for auto-anchor size")
         dummy_input = torch.zeros(1, 3, *image_size).to(device)
         dummy_output = model(dummy_input)
         anchors_num = []
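
The new log line refers to the zero-tensor probe that follows it: Vec2Box runs the model once on a blank input so it can derive anchor counts from the output shapes instead of hard-coding them. A rough sketch of that idea (how dummy_output is unpacked below is an assumption; only the zero-tensor probe itself appears in the diff):

    import torch

    def probe_anchor_counts(model, image_size, device):
        # Throwaway forward pass on a blank image, as in Vec2Box.__init__.
        dummy_input = torch.zeros(1, 3, *image_size).to(device)
        with torch.no_grad():
            dummy_output = model(dummy_input)
        # Assumed structure: one prediction tensor per detection scale under
        # "Main", with spatial dims last; each cell contributes one anchor.
        anchors_num = []
        for pred in dummy_output["Main"]:
            *_, height, width = pred.shape
            anchors_num.append(height * width)
        return anchors_num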