henry000 committed on
Commit
2dd2ae5
·
1 Parent(s): 0a3c9de

✨ [New] Framework anc2box -> anc2vec + vec2box

Browse files
yolo/config/model/v9-c.yaml CHANGED
@@ -1,6 +1,5 @@
1
  anchor:
2
  reg_max: 16
3
- strides: [8, 16, 32]
4
 
5
  model:
6
  backbone:
@@ -120,23 +119,15 @@ model:
120
 
121
  - MultiheadDetection:
122
  source: [A3, A4, A5]
123
- tags: aux_head
124
- - Anchor2Box:
125
- source: aux_head
126
- output: True
127
  args:
128
  reg_max: ${model.anchor.reg_max}
129
- strides: ${model.anchor.strides}
130
- tags: aux_bbox
131
 
132
  detection:
133
  - MultiheadDetection:
134
  source: [P3, P4, P5]
135
- tags: reg_head
136
- - Anchor2Box:
137
- source: reg_head
138
- output: True
139
  args:
140
  reg_max: ${model.anchor.reg_max}
141
- strides: ${model.anchor.strides}
142
- tags: reg_bbox
 
1
  anchor:
2
  reg_max: 16
 
3
 
4
  model:
5
  backbone:
 
119
 
120
  - MultiheadDetection:
121
  source: [A3, A4, A5]
122
+ tags: AUX
 
 
 
123
  args:
124
  reg_max: ${model.anchor.reg_max}
125
+ output: True
 
126
 
127
  detection:
128
  - MultiheadDetection:
129
  source: [P3, P4, P5]
130
+ tags: Main
 
 
 
131
  args:
132
  reg_max: ${model.anchor.reg_max}
133
+ output: True
 
yolo/lazy.py CHANGED
@@ -11,6 +11,7 @@ from yolo.config.config import Config
11
  from yolo.model.yolo import create_model
12
  from yolo.tools.data_loader import create_dataloader
13
  from yolo.tools.solver import ModelTester, ModelTrainer
 
14
  from yolo.utils.deploy_utils import FastModelLoader
15
  from yolo.utils.logging_utils import custom_logger, validate_log_directory
16
 
@@ -27,12 +28,14 @@ def main(cfg: Config):
27
  else:
28
  model = create_model(cfg.model, class_num=cfg.class_num, weight_path=cfg.weight).to(device)
29
 
 
 
30
  if cfg.task.task == "train":
31
- trainer = ModelTrainer(cfg, model, save_path, device)
32
  trainer.solve(dataloader)
33
 
34
  if cfg.task.task == "inference":
35
- tester = ModelTester(cfg, model, save_path, device)
36
  tester.solve(dataloader)
37
 
38
 
 
11
  from yolo.model.yolo import create_model
12
  from yolo.tools.data_loader import create_dataloader
13
  from yolo.tools.solver import ModelTester, ModelTrainer
14
+ from yolo.utils.bounding_box_utils import Vec2Box
15
  from yolo.utils.deploy_utils import FastModelLoader
16
  from yolo.utils.logging_utils import custom_logger, validate_log_directory
17
 
 
28
  else:
29
  model = create_model(cfg.model, class_num=cfg.class_num, weight_path=cfg.weight).to(device)
30
 
31
+ vec2box = Vec2Box(model, cfg.image_size, device)
32
+
33
  if cfg.task.task == "train":
34
+ trainer = ModelTrainer(cfg, model, vec2box, save_path, device)
35
  trainer.solve(dataloader)
36
 
37
  if cfg.task.task == "inference":
38
+ tester = ModelTester(cfg, model, vec2box, save_path, device)
39
  tester.solve(dataloader)
40
 
41
 
yolo/model/module.py CHANGED
@@ -58,7 +58,7 @@ class Detection(nn.Module):
58
  anchor_channels = 4 * reg_max
59
 
60
  first_neck, in_channels = in_channels
61
- anchor_neck = max(round_up(first_neck // 4, groups), anchor_channels, 16)
62
  class_neck = max(first_neck, min(num_classes * 2, 128))
63
 
64
  self.anchor_conv = nn.Sequential(
@@ -70,13 +70,16 @@ class Detection(nn.Module):
70
  Conv(in_channels, class_neck, 3), Conv(class_neck, class_neck, 3), nn.Conv2d(class_neck, num_classes, 1)
71
  )
72
 
 
 
73
  self.anchor_conv[-1].bias.data.fill_(1.0)
74
  self.class_conv[-1].bias.data.fill_(-10)
75
 
76
- def forward(self, x: List[Tensor]) -> List[Tensor]:
77
  anchor_x = self.anchor_conv(x)
78
  class_x = self.class_conv(x)
79
- return torch.cat([anchor_x, class_x], dim=1)
 
80
 
81
 
82
  class MultiheadDetection(nn.Module):
@@ -92,40 +95,18 @@ class MultiheadDetection(nn.Module):
92
  return [head(x) for x, head in zip(x_list, self.heads)]
93
 
94
 
95
- class Anchor2Box(nn.Module):
96
- def __init__(self, reg_max, strides, num_classes: int) -> None:
97
  super().__init__()
98
- self.reg_max = reg_max
99
- self.strides = strides
100
- # TODO: read by cfg!
101
- image_size = [640, 640]
102
- self.num_classes = num_classes
103
- self.anchors, self.scaler = generate_anchors(image_size, self.strides)
104
- reverse_reg = torch.arange(self.reg_max, dtype=torch.float32)
105
- self.reverse_reg = nn.Parameter(reverse_reg, requires_grad=False)
106
- self.anchors = nn.Parameter(self.anchors, requires_grad=False)
107
- self.scaler = nn.Parameter(self.scaler, requires_grad=False)
108
-
109
- def forward(self, predicts: List[Tensor]) -> Tensor:
110
- """
111
- args:
112
- [B x AnchorClass x h1 x w1, B x AnchorClass x h2 x w2, B x AnchorClass x h3 x w3] // AnchorClass = 4 * 16 + 80
113
- return:
114
- [B x HW x ClassBbox] // HW = h1*w1 + h2*w2 + h3*w3, ClassBox = 80 + 4 (xyXY)
115
- """
116
- preds = []
117
- for pred in predicts:
118
- preds.append(rearrange(pred, "B AC h w -> B (h w) AC")) # B x AC x h x w-> B x hw x AC
119
- preds = torch.concat(preds, dim=1) # -> B x (H W) x AC
120
- preds_anc, preds_cls = torch.split(preds, (self.reg_max * 4, self.num_classes), dim=-1)
121
- preds_anc = rearrange(preds_anc, "B hw (P R)-> B hw P R", P=4)
122
-
123
- pred_LTRB = preds_anc.softmax(dim=-1) @ self.reverse_reg * self.scaler.view(1, -1, 1)
124
-
125
- lt, rb = pred_LTRB.chunk(2, dim=-1)
126
- preds_box = torch.cat([self.anchors - lt, self.anchors + rb], dim=-1)
127
- predicts = torch.cat([preds_cls, preds_box], dim=-1)
128
- return predicts, preds_anc
129
 
130
 
131
  # ----------- Backbone Class ----------- #
 
58
  anchor_channels = 4 * reg_max
59
 
60
  first_neck, in_channels = in_channels
61
+ anchor_neck = max(round_up(first_neck // 4, groups), anchor_channels, reg_max)
62
  class_neck = max(first_neck, min(num_classes * 2, 128))
63
 
64
  self.anchor_conv = nn.Sequential(
 
70
  Conv(in_channels, class_neck, 3), Conv(class_neck, class_neck, 3), nn.Conv2d(class_neck, num_classes, 1)
71
  )
72
 
73
+ self.anc2vec = Anchor2Vec(reg_max=reg_max)
74
+
75
  self.anchor_conv[-1].bias.data.fill_(1.0)
76
  self.class_conv[-1].bias.data.fill_(-10)
77
 
78
+ def forward(self, x: Tensor) -> Tuple[Tensor]:
79
  anchor_x = self.anchor_conv(x)
80
  class_x = self.class_conv(x)
81
+ anchor_x, vector_x = self.anc2vec(anchor_x)
82
+ return class_x, anchor_x, vector_x
83
 
84
 
85
  class MultiheadDetection(nn.Module):
 
95
  return [head(x) for x, head in zip(x_list, self.heads)]
96
 
97
 
98
+ class Anchor2Vec(nn.Module):
99
+ def __init__(self, reg_max: int = 16) -> None:
100
  super().__init__()
101
+ reverse_reg = torch.arange(reg_max, dtype=torch.float32).view(1, reg_max, 1, 1, 1)
102
+ self.anc2vec = nn.Conv3d(in_channels=reg_max, out_channels=1, kernel_size=1, bias=False)
103
+ self.anc2vec.weight = nn.Parameter(reverse_reg, requires_grad=False)
104
+
105
+ def forward(self, anchor_x: Tensor) -> Tensor:
106
+ anchor_x = rearrange(anchor_x, "B (P R) h w -> B R P h w", P=4)
107
+ vector_x = anchor_x.softmax(dim=1)
108
+ vector_x = self.anc2vec(vector_x).squeeze(1)
109
+ return anchor_x, vector_x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
 
112
  # ----------- Backbone Class ----------- #
yolo/model/yolo.py CHANGED
@@ -66,7 +66,7 @@ class YOLO(nn.Module):
66
 
67
  def forward(self, x):
68
  y = {0: x}
69
- output = []
70
  for index, layer in enumerate(self.model, start=1):
71
  if isinstance(layer.source, list):
72
  model_input = [y[idx] for idx in layer.source]
@@ -77,7 +77,7 @@ class YOLO(nn.Module):
77
  if layer.usable:
78
  y[index] = x
79
  if layer.output:
80
- output.append(x)
81
  return output
82
 
83
  def get_out_channels(self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list]):
@@ -131,7 +131,7 @@ def create_model(model_cfg: ModelConfig, class_num: int = 80, weight_path: str =
131
  logger.info("✅ Success load model")
132
  if weight_path:
133
  if os.path.exists(weight_path):
134
- model.model.load_state_dict(torch.load(weight_path), strict=False)
135
  logger.info("✅ Success load model weight")
136
  else:
137
  logger.info(f"🌐 Weight {weight_path} not found, try downloading")
 
66
 
67
  def forward(self, x):
68
  y = {0: x}
69
+ output = dict()
70
  for index, layer in enumerate(self.model, start=1):
71
  if isinstance(layer.source, list):
72
  model_input = [y[idx] for idx in layer.source]
 
77
  if layer.usable:
78
  y[index] = x
79
  if layer.output:
80
+ output[layer.tags] = x
81
  return output
82
 
83
  def get_out_channels(self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list]):
 
131
  logger.info("✅ Success load model")
132
  if weight_path:
133
  if os.path.exists(weight_path):
134
+ model.model.load_state_dict(torch.load(weight_path))
135
  logger.info("✅ Success load model weight")
136
  else:
137
  logger.info(f"🌐 Weight {weight_path} not found, try downloading")
yolo/tools/format_converters.py CHANGED
@@ -17,7 +17,7 @@ def convert_weight(old_state_dict, new_state_dict, model_size: int = 38):
17
  continue
18
  _, _, conv_name, conv_idx, *details = weight_name.split(".")
19
  if conv_name == "cv4" or conv_name == "cv5":
20
- layer_idx = 39
21
  else:
22
  layer_idx = 37
23
 
 
17
  continue
18
  _, _, conv_name, conv_idx, *details = weight_name.split(".")
19
  if conv_name == "cv4" or conv_name == "cv5":
20
+ layer_idx = 38
21
  else:
22
  layer_idx = 37
23
 
yolo/tools/loss_functions.py CHANGED
@@ -2,14 +2,12 @@ from typing import Any, Dict, List, Tuple
2
 
3
  import torch
4
  import torch.nn.functional as F
5
- from einops import rearrange
6
  from loguru import logger
7
  from torch import Tensor, nn
8
  from torch.nn import BCEWithLogitsLoss
9
 
10
- from yolo.config.config import Config
11
- from yolo.utils.bounding_box_utils import BoxMatcher, calculate_iou, generate_anchors
12
- from yolo.utils.module_utils import divide_into_chunks
13
 
14
 
15
  class BCELoss(nn.Module):
@@ -40,10 +38,9 @@ class BoxLoss(nn.Module):
40
 
41
 
42
  class DFLoss(nn.Module):
43
- def __init__(self, anchors: Tensor, scaler: Tensor, reg_max: int) -> None:
44
  super().__init__()
45
- self.anchors = anchors
46
- self.scaler = scaler
47
  self.reg_max = reg_max
48
 
49
  def forward(
@@ -51,8 +48,9 @@ class DFLoss(nn.Module):
51
  ) -> Any:
52
  valid_bbox = valid_masks[..., None].expand(-1, -1, 4)
53
  bbox_lt, bbox_rb = targets_bbox.chunk(2, -1)
54
- anchors_norm = (self.anchors / self.scaler[:, None])[None]
55
- targets_dist = torch.cat(((anchors_norm - bbox_lt), (bbox_rb - anchors_norm)), -1).clamp(0, self.reg_max - 1.01)
 
56
  picked_targets = targets_dist[valid_bbox].view(-1)
57
  picked_predict = predicts_anc[valid_bbox].view(-1, self.reg_max)
58
 
@@ -68,42 +66,31 @@ class DFLoss(nn.Module):
68
 
69
 
70
  class YOLOLoss:
71
- def __init__(self, cfg: Config) -> None:
72
- self.reg_max = cfg.model.anchor.reg_max
73
- self.class_num = cfg.class_num
74
- self.image_size = list(cfg.image_size)
75
- self.strides = cfg.model.anchor.strides
76
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
77
-
78
- self.anchors, self.scaler = generate_anchors(self.image_size, self.strides)
79
- self.anchors = self.anchors.to(device)
80
- self.scaler = self.scaler.to(device)
81
 
82
  self.cls = BCELoss()
83
- self.dfl = DFLoss(self.anchors, self.scaler, self.reg_max)
84
  self.iou = BoxLoss()
85
 
86
- self.matcher = BoxMatcher(cfg.task.loss.matcher, self.class_num, self.anchors)
87
 
88
  def separate_anchor(self, anchors):
89
  """
90
  separate anchor and bounding box
91
  """
92
  anchors_cls, anchors_box = torch.split(anchors, (self.class_num, 4), dim=-1)
93
- anchors_box = anchors_box / self.scaler[None, :, None]
94
  return anchors_cls, anchors_box
95
 
96
- def __call__(
97
- self, predicts_box: List[Tensor], predicts_anc: Tensor, targets: Tensor
98
- ) -> Tuple[Tensor, Tensor, Tensor]:
99
- # Batch_Size x (Anchor + Class) x H x W
100
- # TODO: check datatype, why targets has a little bit error with origin version
101
-
102
  # For each predicted targets, assign a best suitable ground truth box.
103
- align_targets, valid_masks = self.matcher(targets, predicts_box)
104
 
105
  targets_cls, targets_bbox = self.separate_anchor(align_targets)
106
- predicts_cls, predicts_bbox = self.separate_anchor(predicts_box)
107
 
108
  cls_norm = targets_cls.sum()
109
  box_norm = targets_cls.sum(-1)[valid_masks]
@@ -111,7 +98,7 @@ class YOLOLoss:
111
  ## -- CLS -- ##
112
  loss_cls = self.cls(predicts_cls, targets_cls, cls_norm)
113
  ## -- IOU -- ##
114
- loss_iou = self.iou(predicts_bbox, targets_bbox, valid_masks, box_norm, cls_norm)
115
  ## -- DFL -- ##
116
  loss_dfl = self.dfl(predicts_anc, targets_bbox, valid_masks, box_norm, cls_norm)
117
 
@@ -119,19 +106,22 @@ class YOLOLoss:
119
 
120
 
121
  class DualLoss:
122
- def __init__(self, cfg: Config) -> None:
123
- self.loss = YOLOLoss(cfg)
124
- self.aux_rate = cfg.task.loss.aux
125
 
126
- self.iou_rate = cfg.task.loss.objective["BoxLoss"]
127
- self.dfl_rate = cfg.task.loss.objective["DFLoss"]
128
- self.cls_rate = cfg.task.loss.objective["BCELoss"]
129
 
130
- def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]:
 
 
131
 
 
 
 
132
  # TODO: Need Refactor this region, make it flexible!
133
- aux_iou, aux_dfl, aux_cls = self.loss(*predicts[0], targets)
134
- main_iou, main_dfl, main_cls = self.loss(*predicts[1], targets)
135
 
136
  loss_dict = {
137
  "BoxLoss": self.iou_rate * (aux_iou * self.aux_rate + main_iou),
@@ -142,7 +132,7 @@ class DualLoss:
142
  return loss_sum, loss_dict
143
 
144
 
145
- def get_loss_function(cfg: Config) -> YOLOLoss:
146
- loss_function = DualLoss(cfg)
147
  logger.info("✅ Success load loss function")
148
  return loss_function
 
2
 
3
  import torch
4
  import torch.nn.functional as F
 
5
  from loguru import logger
6
  from torch import Tensor, nn
7
  from torch.nn import BCEWithLogitsLoss
8
 
9
+ from yolo.config.config import Config, LossConfig
10
+ from yolo.utils.bounding_box_utils import BoxMatcher, Vec2Box, calculate_iou
 
11
 
12
 
13
  class BCELoss(nn.Module):
 
38
 
39
 
40
  class DFLoss(nn.Module):
41
+ def __init__(self, anchors_norm: Tensor, reg_max: int) -> None:
42
  super().__init__()
43
+ self.anchors_norm = anchors_norm
 
44
  self.reg_max = reg_max
45
 
46
  def forward(
 
48
  ) -> Any:
49
  valid_bbox = valid_masks[..., None].expand(-1, -1, 4)
50
  bbox_lt, bbox_rb = targets_bbox.chunk(2, -1)
51
+ targets_dist = torch.cat(((self.anchors_norm - bbox_lt), (bbox_rb - self.anchors_norm)), -1).clamp(
52
+ 0, self.reg_max - 1.01
53
+ )
54
  picked_targets = targets_dist[valid_bbox].view(-1)
55
  picked_predict = predicts_anc[valid_bbox].view(-1, self.reg_max)
56
 
 
66
 
67
 
68
  class YOLOLoss:
69
+ def __init__(self, loss_cfg: LossConfig, vec2box: Vec2Box, class_num: int = 80, reg_max: int = 16) -> None:
70
+ self.class_num = class_num
71
+ self.vec2box = vec2box
 
 
 
 
 
 
 
72
 
73
  self.cls = BCELoss()
74
+ self.dfl = DFLoss(vec2box.anchor_norm, reg_max)
75
  self.iou = BoxLoss()
76
 
77
+ self.matcher = BoxMatcher(loss_cfg.matcher, self.class_num, vec2box.anchor_grid)
78
 
79
  def separate_anchor(self, anchors):
80
  """
81
  separate anchor and bounding box
82
  """
83
  anchors_cls, anchors_box = torch.split(anchors, (self.class_num, 4), dim=-1)
84
+ anchors_box = anchors_box / self.vec2box.scaler[None, :, None]
85
  return anchors_cls, anchors_box
86
 
87
+ def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
88
+ predicts_cls, predicts_anc, predicts_box = predicts
 
 
 
 
89
  # For each predicted targets, assign a best suitable ground truth box.
90
+ align_targets, valid_masks = self.matcher(targets, (predicts_cls, predicts_box))
91
 
92
  targets_cls, targets_bbox = self.separate_anchor(align_targets)
93
+ predicts_box = predicts_box / self.vec2box.scaler[None, :, None]
94
 
95
  cls_norm = targets_cls.sum()
96
  box_norm = targets_cls.sum(-1)[valid_masks]
 
98
  ## -- CLS -- ##
99
  loss_cls = self.cls(predicts_cls, targets_cls, cls_norm)
100
  ## -- IOU -- ##
101
+ loss_iou = self.iou(predicts_box, targets_bbox, valid_masks, box_norm, cls_norm)
102
  ## -- DFL -- ##
103
  loss_dfl = self.dfl(predicts_anc, targets_bbox, valid_masks, box_norm, cls_norm)
104
 
 
106
 
107
 
108
  class DualLoss:
109
+ def __init__(self, cfg: Config, vec2box) -> None:
110
+ loss_cfg = cfg.task.loss
111
+ self.loss = YOLOLoss(loss_cfg, vec2box, class_num=cfg.class_num, reg_max=cfg.model.anchor.reg_max)
112
 
113
+ self.aux_rate = loss_cfg.aux
 
 
114
 
115
+ self.iou_rate = loss_cfg.objective["BoxLoss"]
116
+ self.dfl_rate = loss_cfg.objective["DFLoss"]
117
+ self.cls_rate = loss_cfg.objective["BCELoss"]
118
 
119
+ def __call__(
120
+ self, aux_predicts: List[Tensor], main_predicts: List[Tensor], targets: Tensor
121
+ ) -> Tuple[Tensor, Dict[str, Tensor]]:
122
  # TODO: Need Refactor this region, make it flexible!
123
+ aux_iou, aux_dfl, aux_cls = self.loss(aux_predicts, targets)
124
+ main_iou, main_dfl, main_cls = self.loss(main_predicts, targets)
125
 
126
  loss_dict = {
127
  "BoxLoss": self.iou_rate * (aux_iou * self.aux_rate + main_iou),
 
132
  return loss_sum, loss_dict
133
 
134
 
135
+ def get_loss_function(cfg: Config, vec2box) -> DualLoss:
136
+ loss_function = DualLoss(cfg, vec2box)
137
  logger.info("✅ Success load loss function")
138
  return loss_function
yolo/tools/solver.py CHANGED
@@ -10,7 +10,7 @@ from yolo.model.yolo import YOLO
10
  from yolo.tools.data_loader import StreamDataLoader, create_dataloader
11
  from yolo.tools.drawer import draw_bboxes
12
  from yolo.tools.loss_functions import get_loss_function
13
- from yolo.utils.bounding_box_utils import bbox_nms, calculate_map
14
  from yolo.utils.logging_utils import ProgressTracker
15
  from yolo.utils.model_utils import (
16
  ExponentialMovingAverage,
@@ -20,13 +20,14 @@ from yolo.utils.model_utils import (
20
 
21
 
22
  class ModelTrainer:
23
- def __init__(self, cfg: Config, model: YOLO, save_path: str, device):
24
  train_cfg: TrainConfig = cfg.task
25
  self.model = model
 
26
  self.device = device
27
  self.optimizer = create_optimizer(model, train_cfg.optimizer)
28
  self.scheduler = create_scheduler(self.optimizer, train_cfg.scheduler)
29
- self.loss_fn = get_loss_function(cfg)
30
  self.progress = ProgressTracker(cfg.name, save_path, cfg.use_wandb)
31
  self.num_epochs = cfg.task.epoch
32
 
@@ -45,7 +46,9 @@ class ModelTrainer:
45
 
46
  with autocast():
47
  outputs = self.model(data)
48
- loss, loss_item = self.loss_fn(outputs, targets)
 
 
49
 
50
  self.scaler.scale(loss).backward()
51
  self.scaler.step(self.optimizer)
@@ -96,9 +99,10 @@ class ModelTrainer:
96
 
97
 
98
  class ModelTester:
99
- def __init__(self, cfg: Config, model: YOLO, save_path: str, device):
100
  self.model = model
101
  self.device = device
 
102
  self.progress = ProgressTracker(cfg, save_path, cfg.use_wandb)
103
 
104
  self.nms = cfg.task.nms
@@ -112,8 +116,9 @@ class ModelTester:
112
  for idx, images in enumerate(dataloader):
113
  images = images.to(self.device)
114
  with torch.no_grad():
115
- raw_output = self.model(images)
116
- nms_out = bbox_nms(raw_output[-1][0], self.nms)
 
117
  draw_bboxes(
118
  images[0],
119
  nms_out[0],
 
10
  from yolo.tools.data_loader import StreamDataLoader, create_dataloader
11
  from yolo.tools.drawer import draw_bboxes
12
  from yolo.tools.loss_functions import get_loss_function
13
+ from yolo.utils.bounding_box_utils import Vec2Box, bbox_nms, calculate_map
14
  from yolo.utils.logging_utils import ProgressTracker
15
  from yolo.utils.model_utils import (
16
  ExponentialMovingAverage,
 
20
 
21
 
22
  class ModelTrainer:
23
+ def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, save_path: str, device):
24
  train_cfg: TrainConfig = cfg.task
25
  self.model = model
26
+ self.vec2box = vec2box
27
  self.device = device
28
  self.optimizer = create_optimizer(model, train_cfg.optimizer)
29
  self.scheduler = create_scheduler(self.optimizer, train_cfg.scheduler)
30
+ self.loss_fn = get_loss_function(cfg, vec2box)
31
  self.progress = ProgressTracker(cfg.name, save_path, cfg.use_wandb)
32
  self.num_epochs = cfg.task.epoch
33
 
 
46
 
47
  with autocast():
48
  outputs = self.model(data)
49
+ aux_predicts = self.vec2box(outputs["AUX"])
50
+ main_predicts = self.vec2box(outputs["Main"])
51
+ loss, loss_item = self.loss_fn(aux_predicts, main_predicts, targets)
52
 
53
  self.scaler.scale(loss).backward()
54
  self.scaler.step(self.optimizer)
 
99
 
100
 
101
  class ModelTester:
102
+ def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, save_path: str, device):
103
  self.model = model
104
  self.device = device
105
+ self.vec2box = vec2box
106
  self.progress = ProgressTracker(cfg, save_path, cfg.use_wandb)
107
 
108
  self.nms = cfg.task.nms
 
116
  for idx, images in enumerate(dataloader):
117
  images = images.to(self.device)
118
  with torch.no_grad():
119
+ outputs = self.model(images)
120
+ outputs = self.vec2box(outputs["Main"])
121
+ nms_out = bbox_nms(outputs[0], outputs[2], self.nms)
122
  draw_bboxes(
123
  images[0],
124
  nms_out[0],
yolo/utils/bounding_box_utils.py CHANGED
@@ -106,12 +106,23 @@ def transform_bbox(bbox: Tensor, indicator="xywh -> xyxy"):
106
  return bbox.to(dtype=data_type)
107
 
108
 
109
- def generate_anchors(image_size: List[int], strides: List[int]):
 
 
 
 
 
 
 
 
 
 
110
  W, H = image_size
111
  anchors = []
112
  scaler = []
113
- for stride in strides:
114
- anchor_num = W // stride * H // stride
 
115
  scaler.append(torch.full((anchor_num,), stride))
116
  shift = stride // 2
117
  x = torch.arange(0, W, stride) + shift
@@ -207,13 +218,13 @@ class BoxMatcher:
207
  unique_indices = target_matrix.argmax(dim=1)
208
  return unique_indices[..., None]
209
 
210
- def __call__(self, target: Tensor, predict: Tensor) -> Tuple[Tensor, Tensor]:
211
  """
212
  1. For each anchor prediction, find the highest suitability targets
213
  2. Select the targets
214
  3. Normalize the class probabilities of targets
215
  """
216
- predict_cls, predict_bbox = predict.split(self.class_num, dim=-1) # B, HW x (C B) -> B x HW x C, B x HW x B
217
  target_cls, target_bbox = target.split([1, 4], dim=-1) # B x N x (C B) -> B x N x C, B x N x B
218
  target_cls = target_cls.long().clamp(0)
219
 
@@ -251,9 +262,37 @@ class BoxMatcher:
251
  return torch.cat([align_cls, align_bbox], dim=-1), valid_mask.bool()
252
 
253
 
254
- def bbox_nms(predicts: Tensor, nms_cfg: NMSConfig):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
  # TODO change function to class or set 80 to class_num instead of a number
256
- cls_dist, bbox = torch.split(predicts, [80, 4], dim=-1)
257
  cls_dist = cls_dist.sigmoid()
258
 
259
  # filter class by confidence
@@ -266,7 +305,7 @@ def bbox_nms(predicts: Tensor, nms_cfg: NMSConfig):
266
  batch_idx, *_ = torch.where(valid_mask)
267
  nms_idx = batched_nms(valid_box, valid_cls, batch_idx, nms_cfg.min_iou)
268
  predicts_nms = []
269
- for idx in range(predicts.size(0)):
270
  instance_idx = nms_idx[idx == batch_idx[nms_idx]]
271
 
272
  predict_nms = torch.cat(
 
106
  return bbox.to(dtype=data_type)
107
 
108
 
109
+ def generate_anchors(image_size: List[int], anchors_list: List[Tuple[int]]):
110
+ """
111
+ Find the anchor maps for each w, h.
112
+
113
+ Args:
114
+ anchors_list List[[w1, h1], [w2, h2], ...]: the anchor num for each predicted anchor
115
+
116
+ Returns:
117
+ all_anchors [HW x 2]:
118
+ all_scalers [HW]: The stride (scale factor) associated with each anchor
119
+ """
120
  W, H = image_size
121
  anchors = []
122
  scaler = []
123
+ for anchor_wh in anchors_list:
124
+ stride = W // anchor_wh[0]
125
+ anchor_num = anchor_wh[0] * anchor_wh[1]
126
  scaler.append(torch.full((anchor_num,), stride))
127
  shift = stride // 2
128
  x = torch.arange(0, W, stride) + shift
 
218
  unique_indices = target_matrix.argmax(dim=1)
219
  return unique_indices[..., None]
220
 
221
+ def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tensor]:
222
  """
223
  1. For each anchor prediction, find the highest suitability targets
224
  2. Select the targets
225
  3. Normalize the class probabilities of targets
226
  """
227
+ predict_cls, predict_bbox = predict
228
  target_cls, target_bbox = target.split([1, 4], dim=-1) # B x N x (C B) -> B x N x C, B x N x B
229
  target_cls = target_cls.long().clamp(0)
230
 
 
262
  return torch.cat([align_cls, align_bbox], dim=-1), valid_mask.bool()
263
 
264
 
265
+ class Vec2Box:
266
+ def __init__(self, model, image_size, device):
267
+ dummy_input = torch.zeros(1, 3, *image_size).to(device)
268
+ dummy_output = model(dummy_input)
269
+ anchors_num = []
270
+ for predict_head in dummy_output["Main"]:
271
+ _, _, *anchor_num = predict_head[2].shape
272
+ anchors_num.append(anchor_num)
273
+ anchor_grid, scaler = generate_anchors(image_size, anchors_num)
274
+ self.anchor_grid, self.scaler = anchor_grid.to(device), scaler.to(device)
275
+ self.anchor_norm = (anchor_grid / scaler[:, None])[None].to(device)
276
+
277
+ def __call__(self, predicts):
278
+ preds_cls, preds_anc, preds_box = [], [], []
279
+ for layer_output in predicts:
280
+ pred_cls, pred_anc, pred_box = layer_output
281
+ preds_cls.append(rearrange(pred_cls, "B C h w -> B (h w) C"))
282
+ preds_anc.append(rearrange(pred_anc, "B A R h w -> B (h w) R A"))
283
+ preds_box.append(rearrange(pred_box, "B X h w -> B (h w) X"))
284
+ preds_cls = torch.concat(preds_cls, dim=1)
285
+ preds_anc = torch.concat(preds_anc, dim=1)
286
+ preds_box = torch.concat(preds_box, dim=1)
287
+
288
+ pred_LTRB = preds_box * self.scaler.view(1, -1, 1)
289
+ lt, rb = pred_LTRB.chunk(2, dim=-1)
290
+ preds_box = torch.cat([self.anchor_grid - lt, self.anchor_grid + rb], dim=-1)
291
+ return preds_cls, preds_anc, preds_box
292
+
293
+
294
+ def bbox_nms(cls_dist: Tensor, bbox: Tensor, nms_cfg: NMSConfig):
295
  # TODO change function to class or set 80 to class_num instead of a number
 
296
  cls_dist = cls_dist.sigmoid()
297
 
298
  # filter class by confidence
 
305
  batch_idx, *_ = torch.where(valid_mask)
306
  nms_idx = batched_nms(valid_box, valid_cls, batch_idx, nms_cfg.min_iou)
307
  predicts_nms = []
308
+ for idx in range(cls_dist.size(0)):
309
  instance_idx = nms_idx[idx == batch_idx[nms_idx]]
310
 
311
  predict_nms = torch.cat(