henry000 committed
Commit 86ef0ef · 1 parent: 2c1f270

✨ [Add] New model, anchor2box move into model
yolo/config/model/v9-c.yaml CHANGED
@@ -121,5 +121,24 @@ model:
       tags: A5
 
   - MultiheadDetection:
-      source: [A3, A4, A5, P3, P4, P5]
-      output: True
+      source: [A3, A4, A5]
+      tags: aux_head
+  - Anchor2Box:
+      source: aux_head
+      output: True
+      args:
+        reg_max: ${model.anchor.reg_max}
+        strides: ${model.anchor.strides}
+      tags: aux_bbox
+
+  detection:
+    - MultiheadDetection:
+        source: [P3, P4, P5]
+        tags: reg_head
+    - Anchor2Box:
+        source: reg_head
+        output: True
+        args:
+          reg_max: ${model.anchor.reg_max}
+          strides: ${model.anchor.strides}
+        tags: reg_bbox
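The `${model.anchor.reg_max}` and `${model.anchor.strides}` references are OmegaConf-style interpolations into the shared `model.anchor` block, so both Anchor2Box entries stay in sync with a single source of truth. A minimal sketch of how such references resolve, assuming the config is loaded through OmegaConf (the `blocks` key here is illustrative, not the repo's exact layout):

```python
# Minimal interpolation sketch; "blocks" is a made-up key for illustration.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "model": {
            "anchor": {"reg_max": 16, "strides": [8, 16, 32]},
            "blocks": [{"Anchor2Box": {"args": {"reg_max": "${model.anchor.reg_max}"}}}],
        }
    }
)
# Interpolations resolve lazily, on access:
print(cfg.model.blocks[0].Anchor2Box.args.reg_max)  # 16
```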
yolo/model/module.py CHANGED
@@ -2,10 +2,12 @@ from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 import torch.nn.functional as F
+from einops import rearrange
 from loguru import logger
 from torch import Tensor, nn
 from torch.nn.common_types import _size_2_t
 
+from yolo.utils.bounding_box_utils import generate_anchors
 from yolo.utils.module_utils import auto_pad, create_activation_function, round_up
 
 
@@ -56,7 +58,6 @@ class Detection(nn.Module):
         anchor_channels = 4 * reg_max
 
         first_neck, in_channels = in_channels
-        # TODO: round up head[0] channels or each head?
         anchor_neck = max(round_up(first_neck // 4, groups), anchor_channels, 16)
         class_neck = max(first_neck, min(num_classes * 2, 128))
 
@@ -83,18 +84,50 @@ class MultiheadDetection(nn.Module):
 
     def __init__(self, in_channels: List[int], num_classes: int, **head_kwargs):
         super().__init__()
-        # TODO: Refactor these parts
         self.heads = nn.ModuleList(
-            [
-                Detection((in_channels[3 * (idx // 3)], in_channel), num_classes, **head_kwargs)
-                for idx, in_channel in enumerate(in_channels)
-            ]
+            [Detection((in_channels[0], in_channel), num_classes, **head_kwargs) for in_channel in in_channels]
         )
 
     def forward(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]:
         return [head(x) for x, head in zip(x_list, self.heads)]
 
 
+class Anchor2Box(nn.Module):
+    def __init__(self, reg_max, strides) -> None:
+        super().__init__()
+        self.reg_max = reg_max
+        self.strides = strides
+        # TODO: read by cfg!
+        image_size = [640, 640]
+        self.class_num = 80
+        self.anchors, self.scaler = generate_anchors(image_size, self.strides)
+        reverse_reg = torch.arange(self.reg_max, dtype=torch.float32)
+        self.reverse_reg = nn.Parameter(reverse_reg, requires_grad=False)
+        self.anchors = nn.Parameter(self.anchors, requires_grad=False)
+        self.scaler = nn.Parameter(self.scaler, requires_grad=False)
+
+    def forward(self, predicts: List[Tensor]) -> Tuple[Tensor, Tensor]:
+        """
+        args:
+            [B x AnchorClass x h1 x w1, B x AnchorClass x h2 x w2, B x AnchorClass x h3 x w3] // AnchorClass = 4 * 16 + 80
+        return:
+            [B x HW x ClassBbox] // HW = h1*w1 + h2*w2 + h3*w3, ClassBbox = 80 + 4 (xyXY)
+        """
+        preds = []
+        for pred in predicts:
+            preds.append(rearrange(pred, "B AC h w -> B (h w) AC"))  # B x AC x h x w -> B x hw x AC
+        preds = torch.concat(preds, dim=1)  # -> B x (H W) x AC
+        preds_anc, preds_cls = torch.split(preds, (self.reg_max * 4, self.class_num), dim=-1)
+        preds_anc = rearrange(preds_anc, "B hw (P R) -> B hw P R", P=4)
+
+        pred_LTRB = preds_anc.softmax(dim=-1) @ self.reverse_reg * self.scaler.view(1, -1, 1)
+
+        lt, rb = pred_LTRB.chunk(2, dim=-1)
+        preds_box = torch.cat([self.anchors - lt, self.anchors + rb], dim=-1)
+        predicts = torch.cat([preds_cls, preds_box], dim=-1)
+        return predicts, preds_anc
+
+
 # ----------- Backbone Class ----------- #
 class RepConv(nn.Module):
     """A convolutional block that combines two convolution layers (kernel and point-wise)."""
yolo/model/yolo.py CHANGED
@@ -130,7 +130,7 @@ def create_model(cfg: Config) -> YOLO:
     logger.info("✅ Success load model")
     if cfg.weight:
         if os.path.exists(cfg.weight):
-            model.model.load_state_dict(torch.load(cfg.weight))
+            model.model.load_state_dict(torch.load(cfg.weight), strict=False)
             logger.info("✅ Success load model weight")
         else:
            logger.info(f"🌐 Weight {cfg.weight} not found, try downloading")
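With `strict=False`, load_state_dict skips the new Anchor2Box buffers (and any other keys that do not line up) instead of raising. PyTorch returns the mismatches, which can be worth surfacing; a sketch with a hypothetical helper, not part of this commit:

```python
# Sketch: logging what strict=False skipped (hypothetical helper).
import torch
from loguru import logger

def load_weight(model: torch.nn.Module, weight_path: str) -> None:
    result = model.load_state_dict(torch.load(weight_path), strict=False)
    if result.missing_keys or result.unexpected_keys:
        logger.warning(f"missing: {result.missing_keys}, unexpected: {result.unexpected_keys}")
```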
yolo/tools/format_converters.py CHANGED
@@ -17,13 +17,15 @@ def convert_weight(old_state_dict, new_state_dict, model_size: int = 38):
             continue
         _, _, conv_name, conv_idx, *details = weight_name.split(".")
         if conv_name == "cv4" or conv_name == "cv5":
-            conv_idx = str(int(conv_idx) + 3)
+            layer_idx = 39
+        else:
+            layer_idx = 37
 
         if conv_name == "cv2" or conv_name == "cv4":
             conv_task = "anchor_conv"
         if conv_name == "cv3" or conv_name == "cv5":
             conv_task = "class_conv"
 
-        weight_name = ".".join(["37", "heads", conv_idx, conv_task, *details])
+        weight_name = ".".join([str(layer_idx), "heads", conv_idx, conv_task, *details])
         new_state_dict[weight_name] = weight_value
     return new_state_dict
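Tracing one hypothetical key through the new branch: cv4/cv5 weights now land on layer 39 (the auxiliary head) while cv2/cv3 stay on layer 37, and `conv_idx` is no longer shifted by 3. The key below is made up to illustrate the split/join logic, not taken from a real checkpoint:

```python
# Hypothetical key, following the split/join logic above.
weight_name = "model.22.cv4.0.conv.weight"
_, _, conv_name, conv_idx, *details = weight_name.split(".")
# conv_name == "cv4" -> layer_idx = 39, conv_task = "anchor_conv"
print(".".join(["39", "heads", conv_idx, "anchor_conv", *details]))
# 39.heads.0.anchor_conv.conv.weight
```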
yolo/tools/loss_functions.py CHANGED
@@ -8,12 +8,7 @@ from torch import Tensor, nn
 from torch.nn import BCEWithLogitsLoss
 
 from yolo.config.config import Config
-from yolo.utils.bounding_box_utils import (
-    AnchorBoxConverter,
-    BoxMatcher,
-    calculate_iou,
-    generate_anchors,
-)
+from yolo.utils.bounding_box_utils import BoxMatcher, calculate_iou, generate_anchors
 from yolo.utils.module_utils import divide_into_chunks
 
 
@@ -80,14 +75,15 @@ class YOLOLoss:
         self.strides = cfg.model.anchor.strides
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-        self.anchors, self.scaler = generate_anchors(self.image_size, self.strides, device)
+        self.anchors, self.scaler = generate_anchors(self.image_size, self.strides)
+        self.anchors = self.anchors.to(device)
+        self.scaler = self.scaler.to(device)
 
         self.cls = BCELoss()
         self.dfl = DFLoss(self.anchors, self.scaler, self.reg_max)
         self.iou = BoxLoss()
 
         self.matcher = BoxMatcher(cfg.task.loss.matcher, self.class_num, self.anchors)
-        self.box_converter = AnchorBoxConverter(cfg.model, self.image_size, device)
 
     def separate_anchor(self, anchors):
         """
@@ -97,16 +93,17 @@ class YOLOLoss:
         anchors_box = anchors_box / self.scaler[None, :, None]
         return anchors_cls, anchors_box
 
-    def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
+    def __call__(
+        self, predicts_box: List[Tensor], predicts_anc: Tensor, targets: Tensor
+    ) -> Tuple[Tensor, Tensor, Tensor]:
         # Batch_Size x (Anchor + Class) x H x W
         # TODO: check datatype, why targets has a little bit error with origin version
-        predicts, predicts_anc = self.box_converter(predicts)
 
         # For each predicted targets, assign a best suitable ground truth box.
-        align_targets, valid_masks = self.matcher(targets, predicts)
+        align_targets, valid_masks = self.matcher(targets, predicts_box)
 
         targets_cls, targets_bbox = self.separate_anchor(align_targets)
-        predicts_cls, predicts_bbox = self.separate_anchor(predicts)
+        predicts_cls, predicts_bbox = self.separate_anchor(predicts_box)
 
         cls_norm = targets_cls.sum()
         box_norm = targets_cls.sum(-1)[valid_masks]
@@ -133,9 +130,8 @@ class DualLoss:
     def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]:
 
         # TODO: Need Refactor this region, make it flexible!
-        predicts = divide_into_chunks(predicts[0], 2)
-        aux_iou, aux_dfl, aux_cls = self.loss(predicts[0], targets)
-        main_iou, main_dfl, main_cls = self.loss(predicts[1], targets)
+        aux_iou, aux_dfl, aux_cls = self.loss(*predicts[0], targets)
+        main_iou, main_dfl, main_cls = self.loss(*predicts[1], targets)
 
         loss_dict = {
             "BoxLoss": self.iou_rate * (aux_iou * self.aux_rate + main_iou),
yolo/tools/solver.py CHANGED
@@ -10,7 +10,7 @@ from yolo.model.yolo import YOLO
 from yolo.tools.data_loader import StreamDataLoader, create_dataloader
 from yolo.tools.drawer import draw_bboxes
 from yolo.tools.loss_functions import get_loss_function
-from yolo.utils.bounding_box_utils import AnchorBoxConverter, bbox_nms, calculate_map
+from yolo.utils.bounding_box_utils import bbox_nms, calculate_map
 from yolo.utils.logging_utils import ProgressTracker
 from yolo.utils.model_utils import (
     ExponentialMovingAverage,
@@ -30,11 +30,8 @@ class ModelTrainer:
         self.progress = ProgressTracker(cfg.name, save_path, cfg.use_wandb)
         self.num_epochs = cfg.task.epoch
 
-        validation_dataloader = create_dataloader(cfg.task.validation.data, cfg.dataset, cfg.task.validation.task)
-        anchor2box = AnchorBoxConverter(cfg.model, cfg.image_size, device)
-        self.validator = ModelValidator(
-            cfg.task.validation, model, save_path, device, self.progress, anchor2box, validation_dataloader
-        )
+        self.validation_dataloader = create_dataloader(cfg.task.validation.data, cfg.dataset, cfg.task.validation.task)
+        self.validator = ModelValidator(cfg.task.validation, model, save_path, device, self.progress)
 
         if getattr(train_cfg.ema, "enabled", False):
             self.ema = ExponentialMovingAverage(model, decay=train_cfg.ema.decay)
@@ -95,7 +92,7 @@ class ModelTrainer:
             epoch_loss = self.train_one_epoch(dataloader)
             self.progress.finish_one_epoch()
 
-            self.validator.solve()
+            self.validator.solve(self.validation_dataloader)
 
 
 class ModelTester:
@@ -104,7 +101,6 @@ class ModelTester:
         self.device = device
         self.progress = ProgressTracker(cfg, save_path, cfg.use_wandb)
 
-        self.anchor2box = AnchorBoxConverter(cfg.model, cfg.image_size, device)
         self.nms = cfg.task.nms
         self.idx2label = cfg.class_list
         self.save_path = save_path
@@ -117,8 +113,7 @@ class ModelTester:
             images = images.to(self.device)
             with torch.no_grad():
                 raw_output = self.model(images)
-            predict, _ = self.anchor2box(raw_output[0][3:], with_logits=True)
-            nms_out = bbox_nms(predict, self.nms)
+            nms_out = bbox_nms(raw_output[-1][0], self.nms)
             draw_bboxes(
                 images[0],
                 nms_out[0],
@@ -144,33 +139,27 @@ class ModelValidator:
         model: YOLO,
         save_path: str,
         device,
+        # TODO: think Progress?
         progress: ProgressTracker,
-        anchor2box,
-        validation_dataloader,
     ):
         self.model = model
         self.device = device
         self.progress = progress
         self.save_path = save_path
-
-        self.anchor2box = anchor2box
         self.nms = validation_cfg.nms
-        self.validdataloader = validation_dataloader
 
-    def solve(self):
+    def solve(self, dataloader):
         # logger.info("🧪 Start Validation!")
         self.model.eval()
-
+        # TODO: choice mAP metrics?
         iou_thresholds = torch.arange(0.5, 1.0, 0.05)
         map_all = []
-        self.progress.start_one_epoch(len(self.validdataloader))
-        for data, targets in self.validdataloader:
+        self.progress.start_one_epoch(len(dataloader))
+        for data, targets in dataloader:
             data, targets = data.to(self.device), targets.to(self.device)
             with torch.no_grad():
                 raw_output = self.model(data)
-            predict, _ = self.anchor2box(raw_output[0][3:], with_logits=True)
-
-            nms_out = bbox_nms(predict, self.nms)
+            nms_out = bbox_nms(raw_output[-1][0], self.nms)
             for idx, predict in enumerate(nms_out):
                 map_value = calculate_map(predict, targets[idx], iou_thresholds)
                 map_all.append(map_value[0])
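Since Anchor2Box is now the model's final output-producing layer and returns (predicts, preds_anc), the tester and validator index the decoded class-plus-box tensor as `raw_output[-1][0]` and hand it straight to bbox_nms. A sketch of that unpacking, reusing the names from the diff above and assuming the output layout described in yolo/model/module.py:

```python
# Sketch of the new inference-path unpacking; "model", "images", and "nms_cfg"
# stand in for the objects already present in ModelTester/ModelValidator.
with torch.no_grad():
    raw_output = model(images)
predicts, preds_anc = raw_output[-1]   # (B x HW x 84, B x HW x 4 x reg_max)
nms_out = bbox_nms(predicts, nms_cfg)  # same as bbox_nms(raw_output[-1][0], nms_cfg)
```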
yolo/utils/bounding_box_utils.py CHANGED
@@ -106,16 +106,16 @@ def transform_bbox(bbox: Tensor, indicator="xywh -> xyxy"):
     return bbox.to(dtype=data_type)
 
 
-def generate_anchors(image_size: List[int], strides: List[int], device):
+def generate_anchors(image_size: List[int], strides: List[int]):
     W, H = image_size
     anchors = []
     scaler = []
     for stride in strides:
         anchor_num = W // stride * H // stride
-        scaler.append(torch.full((anchor_num,), stride, device=device))
+        scaler.append(torch.full((anchor_num,), stride))
         shift = stride // 2
-        x = torch.arange(0, W, stride, device=device) + shift
-        y = torch.arange(0, H, stride, device=device) + shift
+        x = torch.arange(0, W, stride) + shift
+        y = torch.arange(0, H, stride) + shift
         anchor_x, anchor_y = torch.meshgrid(x, y, indexing="ij")
         anchor = torch.stack([anchor_y.flatten(), anchor_x.flatten()], dim=-1)
         anchors.append(anchor)
@@ -124,44 +124,6 @@ def generate_anchors(image_size: List[int], strides: List[int], device):
     return all_anchors, all_scalers
 
 
-class AnchorBoxConverter:
-    def __init__(self, model_cfg: ModelConfig, image_size: List[int], device: torch.device) -> None:
-        self.reg_max = model_cfg.anchor.reg_max
-        self.class_num = model_cfg.class_num
-        self.strides = model_cfg.anchor.strides
-
-        self.anchors, self.scaler = generate_anchors(image_size, self.strides, device)
-        self.reverse_reg = torch.arange(self.reg_max, dtype=torch.float32, device=device)
-
-    def __call__(self, predicts: List[Tensor], with_logits=False) -> Tensor:
-        """
-        args:
-            [B x AnchorClass x h1 x w1, B x AnchorClass x h2 x w2, B x AnchorClass x h3 x w3] // AnchorClass = 4 * 16 + 80
-        return:
-            [B x HW x ClassBbox] // HW = h1*w1 + h2*w2 + h3*w3, ClassBbox = 80 + 4 (xyXY)
-        """
-        preds = []
-        for pred in predicts:
-            preds.append(rearrange(pred, "B AC h w -> B (h w) AC"))  # B x AC x h x w -> B x hw x AC
-        preds = torch.concat(preds, dim=1)  # -> B x (H W) x AC
-
-        preds_anc, preds_cls = torch.split(preds, (self.reg_max * 4, self.class_num), dim=-1)
-        preds_anc = rearrange(preds_anc, "B hw (P R) -> B hw P R", P=4)
-        if with_logits:
-            preds_cls = preds_cls.sigmoid()
-
-        pred_LTRB = preds_anc.softmax(dim=-1) @ self.reverse_reg * self.scaler.view(1, -1, 1)
-
-        lt, rb = pred_LTRB.chunk(2, dim=-1)
-        pred_minXY = self.anchors - lt
-        pred_maxXY = self.anchors + rb
-        preds_box = torch.cat([pred_minXY, pred_maxXY], dim=-1)
-
-        predicts = torch.cat([preds_cls, preds_box], dim=-1)
-
-        return predicts, preds_anc
-
-
 class BoxMatcher:
     def __init__(self, cfg: MatcherConfig, class_num: int, anchors: Tensor) -> None:
         self.class_num = class_num
@@ -291,7 +253,8 @@ class BoxMatcher:
 
 def bbox_nms(predicts: Tensor, nms_cfg: NMSConfig):
     # TODO change function to class or set 80 to class_num instead of a number
-    cls_dist, bbox = predicts.split([80, 4], dim=-1)
+    cls_dist, bbox = torch.split(predicts, [80, 4], dim=-1)
+    cls_dist = cls_dist.sigmoid()
 
     # filter class by confidence
     cls_val, cls_idx = cls_dist.max(dim=-1, keepdim=True)
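generate_anchors is now device-agnostic; callers move the tensors themselves (YOLOLoss) or register them as parameters (Anchor2Box). For a 640x640 image with strides [8, 16, 32] it yields 80*80 + 40*40 + 20*20 = 8400 anchor centers:

```python
# Quick shape check for the device-agnostic generate_anchors.
import torch

from yolo.utils.bounding_box_utils import generate_anchors

anchors, scaler = generate_anchors([640, 640], [8, 16, 32])
print(anchors.shape)  # torch.Size([8400, 2]) -- one grid-cell center per anchor, all levels
print(scaler.shape)   # torch.Size([8400])    -- the stride each anchor came from
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
anchors, scaler = anchors.to(device), scaler.to(device)  # callers handle placement now
```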