✨ [Add] New model, anchor2box move into model

- yolo/config/model/v9-c.yaml +21 -2
- yolo/model/module.py +39 -6
- yolo/model/yolo.py +1 -1
- yolo/tools/format_converters.py +4 -2
- yolo/tools/loss_functions.py +11 -15
- yolo/tools/solver.py +11 -22
- yolo/utils/bounding_box_utils.py +6 -43

yolo/config/model/v9-c.yaml

@@ -121,5 +121,24 @@ model:
         tags: A5

     - MultiheadDetection:
-        source: [A3, A4, A5]
-
+        source: [A3, A4, A5]
+        tags: aux_head
+    - Anchor2Box:
+        source: aux_head
+        output: True
+        args:
+          reg_max: ${model.anchor.reg_max}
+          strides: ${model.anchor.strides}
+        tags: aux_bbox
+
+  detection:
+    - MultiheadDetection:
+        source: [P3, P4, P5]
+        tags: reg_head
+    - Anchor2Box:
+        source: reg_head
+        output: True
+        args:
+          reg_max: ${model.anchor.reg_max}
+          strides: ${model.anchor.strides}
+        tags: reg_bbox
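
For orientation, the tensor sizes this config implies can be checked with a few lines of arithmetic. This is not part of the commit; it assumes reg_max=16, class_num=80 and a 640x640 input, which match the values hard-coded in the new Anchor2Box module below, plus strides of [8, 16, 32], which the config only references as ${model.anchor.strides}.

# Back-of-envelope check of the sizes the new detection/auxiliary branches produce (illustrative only).
reg_max, class_num = 16, 80
image_size, strides = 640, [8, 16, 32]             # strides are an assumption, not stated in this diff

head_channels = 4 * reg_max + class_num            # 144 channels per grid cell from each MultiheadDetection scale
cells = [(image_size // s) ** 2 for s in strides]  # [6400, 1600, 400]
print(head_channels, sum(cells))                   # 144 8400 -> the "AC" and "HW" in Anchor2Box's docstring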

yolo/model/module.py

@@ -2,10 +2,12 @@ from typing import Any, Dict, List, Optional, Tuple

 import torch
 import torch.nn.functional as F
+from einops import rearrange
 from loguru import logger
 from torch import Tensor, nn
 from torch.nn.common_types import _size_2_t

+from yolo.utils.bounding_box_utils import generate_anchors
 from yolo.utils.module_utils import auto_pad, create_activation_function, round_up


@@ -56,7 +58,6 @@ class Detection(nn.Module):
         anchor_channels = 4 * reg_max

         first_neck, in_channels = in_channels
-        # TODO: round up head[0] channels or each head?
         anchor_neck = max(round_up(first_neck // 4, groups), anchor_channels, 16)
         class_neck = max(first_neck, min(num_classes * 2, 128))

@@ -83,18 +84,50 @@ class MultiheadDetection(nn.Module):

     def __init__(self, in_channels: List[int], num_classes: int, **head_kwargs):
         super().__init__()
-        # TODO: Refactor these parts
         self.heads = nn.ModuleList(
-            [
-                Detection((in_channels[3 * (idx // 3)], in_channel), num_classes, **head_kwargs)
-                for idx, in_channel in enumerate(in_channels)
-            ]
+            [Detection((in_channels[0], in_channel), num_classes, **head_kwargs) for in_channel in in_channels]
         )

     def forward(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]:
         return [head(x) for x, head in zip(x_list, self.heads)]


+class Anchor2Box(nn.Module):
+    def __init__(self, reg_max, strides) -> None:
+        super().__init__()
+        self.reg_max = reg_max
+        self.strides = strides
+        # TODO: read by cfg!
+        image_size = [640, 640]
+        self.class_num = 80
+        self.anchors, self.scaler = generate_anchors(image_size, self.strides)
+        reverse_reg = torch.arange(self.reg_max, dtype=torch.float32)
+        self.reverse_reg = nn.Parameter(reverse_reg, requires_grad=False)
+        self.anchors = nn.Parameter(self.anchors, requires_grad=False)
+        self.scaler = nn.Parameter(self.scaler, requires_grad=False)
+
+    def forward(self, predicts: List[Tensor]) -> Tensor:
+        """
+        args:
+            [B x AnchorClass x h1 x w1, B x AnchorClass x h2 x w2, B x AnchorClass x h3 x w3] // AnchorClass = 4 * 16 + 80
+        return:
+            [B x HW x ClassBbox] // HW = h1*w1 + h2*w2 + h3*w3, ClassBox = 80 + 4 (xyXY)
+        """
+        preds = []
+        for pred in predicts:
+            preds.append(rearrange(pred, "B AC h w -> B (h w) AC"))  # B x AC x h x w -> B x hw x AC
+        preds = torch.concat(preds, dim=1)  # -> B x (H W) x AC
+        preds_anc, preds_cls = torch.split(preds, (self.reg_max * 4, self.class_num), dim=-1)
+        preds_anc = rearrange(preds_anc, "B hw (P R) -> B hw P R", P=4)
+
+        pred_LTRB = preds_anc.softmax(dim=-1) @ self.reverse_reg * self.scaler.view(1, -1, 1)
+
+        lt, rb = pred_LTRB.chunk(2, dim=-1)
+        preds_box = torch.cat([self.anchors - lt, self.anchors + rb], dim=-1)
+        predicts = torch.cat([preds_cls, preds_box], dim=-1)
+        return predicts, preds_anc
+
+
 # ----------- Backbone Class ----------- #
 class RepConv(nn.Module):
     """A convolutional block that combines two convolution layers (kernel and point-wise)."""
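
To see what Anchor2Box emits, here is a minimal standalone sketch of the same decode on dummy tensors. It is not repository code: the anchor generation is inlined rather than imported, and it assumes the hard-coded defaults above (reg_max=16, 80 classes, 640x640 input) plus strides of [8, 16, 32], which this diff does not state.

# Minimal standalone sketch of the Anchor2Box decode on dummy tensors.
import torch
from einops import rearrange

reg_max, class_num, strides, size = 16, 80, [8, 16, 32], 640

# anchor centres and per-anchor stride, built the same way as generate_anchors
anchors, scaler = [], []
for stride in strides:
    shift = stride // 2
    x = torch.arange(0, size, stride) + shift
    y = torch.arange(0, size, stride) + shift
    ax, ay = torch.meshgrid(x, y, indexing="ij")
    anchors.append(torch.stack([ay.flatten(), ax.flatten()], dim=-1))
    scaler.append(torch.full(((size // stride) ** 2,), stride))
anchors, scaler = torch.cat(anchors).float(), torch.cat(scaler).float()   # [8400, 2], [8400]

# fake multi-scale head outputs: B x (4*reg_max + class_num) x h x w
feats = [torch.randn(2, 4 * reg_max + class_num, size // s, size // s) for s in strides]

preds = torch.cat([rearrange(p, "B AC h w -> B (h w) AC") for p in feats], dim=1)   # [2, 8400, 144]
preds_anc, preds_cls = torch.split(preds, (4 * reg_max, class_num), dim=-1)
preds_anc = rearrange(preds_anc, "B hw (P R) -> B hw P R", P=4)                     # [2, 8400, 4, 16]

reverse_reg = torch.arange(reg_max, dtype=torch.float32)
pred_ltrb = preds_anc.softmax(dim=-1) @ reverse_reg * scaler.view(1, -1, 1)   # expected bin index * stride
lt, rb = pred_ltrb.chunk(2, dim=-1)
boxes = torch.cat([anchors - lt, anchors + rb], dim=-1)                       # xyXY around each anchor centre

out = torch.cat([preds_cls, boxes], dim=-1)
print(out.shape, preds_anc.shape)   # torch.Size([2, 8400, 84]) torch.Size([2, 8400, 4, 16])

The (decoded predictions, raw distribution) pair returned here is what the loss and NMS call sites below now consume.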

yolo/model/yolo.py

@@ -130,7 +130,7 @@ def create_model(cfg: Config) -> YOLO:
     logger.info("✅ Success load model")
     if cfg.weight:
         if os.path.exists(cfg.weight):
-            model.model.load_state_dict(torch.load(cfg.weight))
+            model.model.load_state_dict(torch.load(cfg.weight), strict=False)
             logger.info("✅ Success load model weight")
         else:
             logger.info(f"🌐 Weight {cfg.weight} not found, try downloading")
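
strict=False matters here because the rebuilt graph contains Anchor2Box parameters (anchors, scaler, reverse_reg) that older checkpoints do not have. A toy illustration of the behaviour, on stand-in modules rather than the YOLO classes:

import torch
from torch import nn

class Old(nn.Module):
    def __init__(self):
        super().__init__()
        self.head = nn.Linear(4, 2)            # stands in for the checkpointed layers

class New(nn.Module):
    def __init__(self):
        super().__init__()
        self.head = nn.Linear(4, 2)
        # a parameter with no counterpart in the old checkpoint, analogous to Anchor2Box's anchors
        self.anchors = nn.Parameter(torch.zeros(8400, 2), requires_grad=False)

state = Old().state_dict()
result = New().load_state_dict(state, strict=False)
print(result.missing_keys)     # ['anchors'] -> tolerated instead of raising, unlike strict=True
print(result.unexpected_keys)  # []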

yolo/tools/format_converters.py

@@ -17,13 +17,15 @@ def convert_weight(old_state_dict, new_state_dict, model_size: int = 38):
             continue
         _, _, conv_name, conv_idx, *details = weight_name.split(".")
         if conv_name == "cv4" or conv_name == "cv5":
-
+            layer_idx = 39
+        else:
+            layer_idx = 37

         if conv_name == "cv2" or conv_name == "cv4":
             conv_task = "anchor_conv"
         if conv_name == "cv3" or conv_name == "cv5":
             conv_task = "class_conv"

-        weight_name = ".".join([
+        weight_name = ".".join([str(layer_idx), "heads", conv_idx, conv_task, *details])
         new_state_dict[weight_name] = weight_value
     return new_state_dict
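
The key rewrite can be traced on a single entry. The old key below is a made-up example in the usual YOLOv9 naming scheme (model.<idx>.cv*....), chosen only to show the string manipulation; the indices 37 and 39 come from the patch and correspond to the two MultiheadDetection layers in the rebuilt graph.

# Illustrative only: how the new key construction rewrites one hypothetical old key.
weight_name = "model.22.cv2.0.0.conv.weight"            # assumed example, not taken from the diff
_, _, conv_name, conv_idx, *details = weight_name.split(".")

layer_idx = 39 if conv_name in ("cv4", "cv5") else 37   # which MultiheadDetection layer the weight belongs to
conv_task = "anchor_conv" if conv_name in ("cv2", "cv4") else "class_conv"

new_name = ".".join([str(layer_idx), "heads", conv_idx, conv_task, *details])
print(new_name)   # 37.heads.0.anchor_conv.0.conv.weight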

yolo/tools/loss_functions.py

@@ -8,12 +8,7 @@ from torch import Tensor, nn
 from torch.nn import BCEWithLogitsLoss

 from yolo.config.config import Config
-from yolo.utils.bounding_box_utils import (
-    AnchorBoxConverter,
-    BoxMatcher,
-    calculate_iou,
-    generate_anchors,
-)
+from yolo.utils.bounding_box_utils import BoxMatcher, calculate_iou, generate_anchors
 from yolo.utils.module_utils import divide_into_chunks


@@ -80,14 +75,15 @@ class YOLOLoss:
         self.strides = cfg.model.anchor.strides
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

-        self.anchors, self.scaler = generate_anchors(self.image_size, self.strides
+        self.anchors, self.scaler = generate_anchors(self.image_size, self.strides)
+        self.anchors = self.anchors.to(device)
+        self.scaler = self.scaler.to(device)

         self.cls = BCELoss()
         self.dfl = DFLoss(self.anchors, self.scaler, self.reg_max)
         self.iou = BoxLoss()

         self.matcher = BoxMatcher(cfg.task.loss.matcher, self.class_num, self.anchors)
-        self.box_converter = AnchorBoxConverter(cfg.model, self.image_size, device)

     def separate_anchor(self, anchors):
         """
@@ -97,16 +93,17 @@ class YOLOLoss:
         anchors_box = anchors_box / self.scaler[None, :, None]
         return anchors_cls, anchors_box

-    def __call__(
+    def __call__(
+        self, predicts_box: List[Tensor], predicts_anc: Tensor, targets: Tensor
+    ) -> Tuple[Tensor, Tensor, Tensor]:
         # Batch_Size x (Anchor + Class) x H x W
         # TODO: check datatype, why targets has a little bit error with origin version
-        predicts, predicts_anc = self.box_converter(predicts)

         # For each predicted targets, assign a best suitable ground truth box.
-        align_targets, valid_masks = self.matcher(targets,
+        align_targets, valid_masks = self.matcher(targets, predicts_box)

         targets_cls, targets_bbox = self.separate_anchor(align_targets)
-        predicts_cls, predicts_bbox = self.separate_anchor(
+        predicts_cls, predicts_bbox = self.separate_anchor(predicts_box)

         cls_norm = targets_cls.sum()
         box_norm = targets_cls.sum(-1)[valid_masks]
@@ -133,9 +130,8 @@ class DualLoss:
     def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]:

         # TODO: Need Refactor this region, make it flexible!
-
-
-        main_iou, main_dfl, main_cls = self.loss(predicts[1], targets)
+        aux_iou, aux_dfl, aux_cls = self.loss(*predicts[0], targets)
+        main_iou, main_dfl, main_cls = self.loss(*predicts[1], targets)

         loss_dict = {
             "BoxLoss": self.iou_rate * (aux_iou * self.aux_rate + main_iou),
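
The DualLoss change follows from the model now returning, per branch, the (decoded boxes, raw anchor distribution) pair produced by Anchor2Box: self.loss(*predicts[0], targets) splats that pair into the new __call__(self, predicts_box, predicts_anc, targets) signature. A dummy-data sketch of the calling convention, with a stand-in function rather than the repository classes:

from typing import Tuple
import torch
from torch import Tensor

def yolo_loss(predicts_box: Tensor, predicts_anc: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
    # stands in for YOLOLoss.__call__; returns placeholder iou/dfl/cls terms
    return predicts_box.mean(), predicts_anc.mean(), targets.mean()

aux_branch = (torch.randn(2, 8400, 84), torch.randn(2, 8400, 4, 16))    # assumed aux_bbox output
main_branch = (torch.randn(2, 8400, 84), torch.randn(2, 8400, 4, 16))   # assumed reg_bbox output
predicts, targets = [aux_branch, main_branch], torch.zeros(2, 5)

aux_iou, aux_dfl, aux_cls = yolo_loss(*predicts[0], targets)    # the * splats (box, anchor) into two arguments
main_iou, main_dfl, main_cls = yolo_loss(*predicts[1], targets)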

yolo/tools/solver.py

@@ -10,7 +10,7 @@ from yolo.model.yolo import YOLO
 from yolo.tools.data_loader import StreamDataLoader, create_dataloader
 from yolo.tools.drawer import draw_bboxes
 from yolo.tools.loss_functions import get_loss_function
-from yolo.utils.bounding_box_utils import
+from yolo.utils.bounding_box_utils import bbox_nms, calculate_map
 from yolo.utils.logging_utils import ProgressTracker
 from yolo.utils.model_utils import (
     ExponentialMovingAverage,
@@ -30,11 +30,8 @@ class ModelTrainer:
         self.progress = ProgressTracker(cfg.name, save_path, cfg.use_wandb)
         self.num_epochs = cfg.task.epoch

-        validation_dataloader = create_dataloader(cfg.task.validation.data, cfg.dataset, cfg.task.validation.task)
-
-        self.validator = ModelValidator(
-            cfg.task.validation, model, save_path, device, self.progress, anchor2box, validation_dataloader
-        )
+        self.validation_dataloader = create_dataloader(cfg.task.validation.data, cfg.dataset, cfg.task.validation.task)
+        self.validator = ModelValidator(cfg.task.validation, model, save_path, device, self.progress)

         if getattr(train_cfg.ema, "enabled", False):
             self.ema = ExponentialMovingAverage(model, decay=train_cfg.ema.decay)
@@ -95,7 +92,7 @@ class ModelTrainer:
             epoch_loss = self.train_one_epoch(dataloader)
             self.progress.finish_one_epoch()

-            self.validator.solve()
+            self.validator.solve(self.validation_dataloader)


 class ModelTester:
@@ -104,7 +101,6 @@ class ModelTester:
         self.device = device
         self.progress = ProgressTracker(cfg, save_path, cfg.use_wandb)

-        self.anchor2box = AnchorBoxConverter(cfg.model, cfg.image_size, device)
         self.nms = cfg.task.nms
         self.idx2label = cfg.class_list
         self.save_path = save_path
@@ -117,8 +113,7 @@ class ModelTester:
             images = images.to(self.device)
             with torch.no_grad():
                 raw_output = self.model(images)
-
-            nms_out = bbox_nms(predict, self.nms)
+            nms_out = bbox_nms(raw_output[-1][0], self.nms)
             draw_bboxes(
                 images[0],
                 nms_out[0],
@@ -144,33 +139,27 @@ class ModelValidator:
         model: YOLO,
         save_path: str,
         device,
+        # TODO: think Progress?
         progress: ProgressTracker,
-        anchor2box,
-        validation_dataloader,
     ):
         self.model = model
         self.device = device
         self.progress = progress
         self.save_path = save_path
-
-        self.anchor2box = anchor2box
         self.nms = validation_cfg.nms
-        self.validdataloader = validation_dataloader

-    def solve(self):
+    def solve(self, dataloader):
         # logger.info("🧪 Start Validation!")
         self.model.eval()
-
+        # TODO: choice mAP metrics?
         iou_thresholds = torch.arange(0.5, 1.0, 0.05)
         map_all = []
-        self.progress.start_one_epoch(len(
-        for data, targets in
+        self.progress.start_one_epoch(len(dataloader))
+        for data, targets in dataloader:
             data, targets = data.to(self.device), targets.to(self.device)
             with torch.no_grad():
                 raw_output = self.model(data)
-
-
-            nms_out = bbox_nms(predict, self.nms)
+            nms_out = bbox_nms(raw_output[-1][0], self.nms)
             for idx, predict in enumerate(nms_out):
                 map_value = calculate_map(predict, targets[idx], iou_thresholds)
                 map_all.append(map_value[0])
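
The indexing raw_output[-1][0] in ModelTester and ModelValidator relies on the new output structure: each layer tagged output: True contributes one entry, and the Anchor2Box entries are (decoded predictions, raw distribution) tuples, with the main reg_bbox branch assumed to come last. That reading of the diff, on fake tensors only:

import torch

aux_bbox = (torch.randn(1, 8400, 84), torch.randn(1, 8400, 4, 16))   # assumed aux branch entry
reg_bbox = (torch.randn(1, 8400, 84), torch.randn(1, 8400, 4, 16))   # assumed main branch entry
raw_output = [aux_bbox, reg_bbox]                                     # order assumed: aux first, main last

decoded_main = raw_output[-1][0]    # B x 8400 x (80 + 4), what bbox_nms receives
print(decoded_main.shape)           # torch.Size([1, 8400, 84])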

yolo/utils/bounding_box_utils.py

@@ -106,16 +106,16 @@ def transform_bbox(bbox: Tensor, indicator="xywh -> xyxy"):
     return bbox.to(dtype=data_type)


-def generate_anchors(image_size: List[int], strides: List[int], device):
+def generate_anchors(image_size: List[int], strides: List[int]):
     W, H = image_size
     anchors = []
     scaler = []
     for stride in strides:
         anchor_num = W // stride * H // stride
-        scaler.append(torch.full((anchor_num,), stride
+        scaler.append(torch.full((anchor_num,), stride))
         shift = stride // 2
-        x = torch.arange(0, W, stride
-        y = torch.arange(0, H, stride
+        x = torch.arange(0, W, stride) + shift
+        y = torch.arange(0, H, stride) + shift
         anchor_x, anchor_y = torch.meshgrid(x, y, indexing="ij")
         anchor = torch.stack([anchor_y.flatten(), anchor_x.flatten()], dim=-1)
         anchors.append(anchor)
@@ -124,44 +124,6 @@ def generate_anchors(image_size: List[int], strides: List[int], device):
     return all_anchors, all_scalers


-class AnchorBoxConverter:
-    def __init__(self, model_cfg: ModelConfig, image_size: List[int], device: torch.device) -> None:
-        self.reg_max = model_cfg.anchor.reg_max
-        self.class_num = model_cfg.class_num
-        self.strides = model_cfg.anchor.strides
-
-        self.anchors, self.scaler = generate_anchors(image_size, self.strides, device)
-        self.reverse_reg = torch.arange(self.reg_max, dtype=torch.float32, device=device)
-
-    def __call__(self, predicts: List[Tensor], with_logits=False) -> Tensor:
-        """
-        args:
-            [B x AnchorClass x h1 x w1, B x AnchorClass x h2 x w2, B x AnchorClass x h3 x w3] // AnchorClass = 4 * 16 + 80
-        return:
-            [B x HW x ClassBbox] // HW = h1*w1 + h2*w2 + h3*w3, ClassBox = 80 + 4 (xyXY)
-        """
-        preds = []
-        for pred in predicts:
-            preds.append(rearrange(pred, "B AC h w -> B (h w) AC"))  # B x AC x h x w -> B x hw x AC
-        preds = torch.concat(preds, dim=1)  # -> B x (H W) x AC
-
-        preds_anc, preds_cls = torch.split(preds, (self.reg_max * 4, self.class_num), dim=-1)
-        preds_anc = rearrange(preds_anc, "B hw (P R) -> B hw P R", P=4)
-        if with_logits:
-            preds_cls = preds_cls.sigmoid()
-
-        pred_LTRB = preds_anc.softmax(dim=-1) @ self.reverse_reg * self.scaler.view(1, -1, 1)
-
-        lt, rb = pred_LTRB.chunk(2, dim=-1)
-        pred_minXY = self.anchors - lt
-        pred_maxXY = self.anchors + rb
-        preds_box = torch.cat([pred_minXY, pred_maxXY], dim=-1)
-
-        predicts = torch.cat([preds_cls, preds_box], dim=-1)
-
-        return predicts, preds_anc
-
-
 class BoxMatcher:
     def __init__(self, cfg: MatcherConfig, class_num: int, anchors: Tensor) -> None:
         self.class_num = class_num
@@ -291,7 +253,8 @@ class BoxMatcher:

 def bbox_nms(predicts: Tensor, nms_cfg: NMSConfig):
     # TODO change function to class or set 80 to class_num instead of a number
-    cls_dist, bbox =
+    cls_dist, bbox = torch.split(predicts, [80, 4], dim=-1)
+    cls_dist = cls_dist.sigmoid()

     # filter class by confidence
     cls_val, cls_idx = cls_dist.max(dim=-1, keepdim=True)
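
A tiny worked example of generate_anchors after the device argument is removed. The body is reproduced standalone below; the final concatenation, which the hunk does not show, is assumed to be a plain torch.cat, and callers such as YOLOLoss now move the result to the right device themselves.

import torch

def generate_anchors(image_size, strides):
    # same per-stride logic as the patched function; the cat at the end is an assumption
    W, H = image_size
    anchors, scaler = [], []
    for stride in strides:
        anchor_num = W // stride * H // stride
        scaler.append(torch.full((anchor_num,), stride))
        shift = stride // 2
        x = torch.arange(0, W, stride) + shift
        y = torch.arange(0, H, stride) + shift
        anchor_x, anchor_y = torch.meshgrid(x, y, indexing="ij")
        anchors.append(torch.stack([anchor_y.flatten(), anchor_x.flatten()], dim=-1))
    return torch.cat(anchors), torch.cat(scaler)

# hypothetical 32x32 input with a single stride of 16 gives a 2x2 grid of cell centres
all_anchors, all_scalers = generate_anchors([32, 32], [16])
print(all_anchors)    # the four centres: [[8, 8], [24, 8], [8, 24], [24, 24]]
print(all_scalers)    # [16, 16, 16, 16]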