✨ [New] Framework anc2box -> anc2vec + vec2box
Changed files:

- yolo/config/model/v9-c.yaml  +4 -13
- yolo/lazy.py  +5 -2
- yolo/model/module.py  +17 -36
- yolo/model/yolo.py  +3 -3
- yolo/tools/format_converters.py  +1 -1
- yolo/tools/loss_functions.py  +32 -42
- yolo/tools/solver.py  +12 -7
- yolo/utils/bounding_box_utils.py  +47 -8
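In short: the old Anchor2Box graph node is split into two pieces. Anchor2Vec now lives inside each Detection head and turns the 4 × reg_max distance distributions into expected LTRB distances, so every head returns a (class, anchor-distribution, vector) triple. Vec2Box lives outside the model: built once in lazy.py via a dummy forward pass, it flattens the multi-scale outputs, rescales the vectors by stride, and offsets them from the anchor grid into xyXY boxes. Accordingly, YOLO.forward returns a dict keyed by the new AUX/Main tags instead of a positional list, and strides no longer need to be listed in the config.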
yolo/config/model/v9-c.yaml CHANGED

```diff
@@ -1,6 +1,5 @@
 anchor:
   reg_max: 16
-  strides: [8, 16, 32]
 
 model:
   backbone:
@@ -120,23 +119,15 @@ model:
 
     - MultiheadDetection:
         source: [A3, A4, A5]
-        tags:
-    - Anchor2Box:
-        source: aux_head
-        output: True
+        tags: AUX
         args:
             reg_max: ${model.anchor.reg_max}
-
-        tags: aux_bbox
+        output: True
 
   detection:
     - MultiheadDetection:
         source: [P3, P4, P5]
-        tags:
-    - Anchor2Box:
-        source: reg_head
-        output: True
+        tags: Main
         args:
             reg_max: ${model.anchor.reg_max}
-
-        tags: reg_bbox
+        output: True
```
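With `strides` gone from the config, the anchor grid is inferred at runtime: Vec2Box probes the model's `Main` heads with a dummy input and hands each head's feature-map size to `generate_anchors`, which recovers the stride as `W // w_head`. A minimal sketch of that arithmetic (my own example, assuming a 640×640 input and the usual three scales):

```python
# Stride inference as done by Vec2Box + generate_anchors (toy values):
image_size = [640, 640]
head_sizes = [[80, 80], [40, 40], [20, 20]]  # feature-map w, h per detection head

for w, h in head_sizes:
    stride = image_size[0] // w
    print(f"{w}x{h} head -> stride {stride}, {w * h} anchors")
# 80x80 head -> stride 8, 6400 anchors
# 40x40 head -> stride 16, 1600 anchors
# 20x20 head -> stride 32, 400 anchors
```

The recovered strides 8/16/32 match the `strides: [8, 16, 32]` entry this commit deletes, which is why the config no longer needs it.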
yolo/lazy.py CHANGED

```diff
@@ -11,6 +11,7 @@ from yolo.config.config import Config
 from yolo.model.yolo import create_model
 from yolo.tools.data_loader import create_dataloader
 from yolo.tools.solver import ModelTester, ModelTrainer
+from yolo.utils.bounding_box_utils import Vec2Box
 from yolo.utils.deploy_utils import FastModelLoader
 from yolo.utils.logging_utils import custom_logger, validate_log_directory
 
@@ -27,12 +28,14 @@ def main(cfg: Config):
     else:
         model = create_model(cfg.model, class_num=cfg.class_num, weight_path=cfg.weight).to(device)
 
+    vec2box = Vec2Box(model, cfg.image_size, device)
+
     if cfg.task.task == "train":
-        trainer = ModelTrainer(cfg, model, save_path, device)
+        trainer = ModelTrainer(cfg, model, vec2box, save_path, device)
        trainer.solve(dataloader)
 
     if cfg.task.task == "inference":
-        tester = ModelTester(cfg, model, save_path, device)
+        tester = ModelTester(cfg, model, vec2box, save_path, device)
         tester.solve(dataloader)
```
yolo/model/module.py CHANGED

```diff
@@ -58,7 +58,7 @@ class Detection(nn.Module):
         anchor_channels = 4 * reg_max
 
         first_neck, in_channels = in_channels
-        anchor_neck = max(round_up(first_neck // 4, groups), anchor_channels, 16)
+        anchor_neck = max(round_up(first_neck // 4, groups), anchor_channels, reg_max)
         class_neck = max(first_neck, min(num_classes * 2, 128))
 
         self.anchor_conv = nn.Sequential(
@@ -70,13 +70,16 @@ class Detection(nn.Module):
             Conv(in_channels, class_neck, 3), Conv(class_neck, class_neck, 3), nn.Conv2d(class_neck, num_classes, 1)
         )
 
+        self.anc2vec = Anchor2Vec(reg_max=reg_max)
+
         self.anchor_conv[-1].bias.data.fill_(1.0)
         self.class_conv[-1].bias.data.fill_(-10)
 
-    def forward(self, x: Tensor) -> Tensor:
+    def forward(self, x: Tensor) -> Tuple[Tensor]:
         anchor_x = self.anchor_conv(x)
         class_x = self.class_conv(x)
-        return torch.cat([anchor_x, class_x], dim=1)
+        anchor_x, vector_x = self.anc2vec(anchor_x)
+        return class_x, anchor_x, vector_x
 
 
 class MultiheadDetection(nn.Module):
@@ -92,40 +95,18 @@ class MultiheadDetection(nn.Module):
         return [head(x) for x, head in zip(x_list, self.heads)]
 
 
-class Anchor2Box(nn.Module):
-    def __init__(self, reg_max, ...):
+class Anchor2Vec(nn.Module):
+    def __init__(self, reg_max: int = 16) -> None:
         super().__init__()
-        ...
-        self.scaler = nn.Parameter(self.scaler, requires_grad=False)
-
-    def forward(self, predicts: List[Tensor]) -> Tensor:
-        """
-        args:
-            [B x AnchorClass x h1 x w1, B x AnchorClass x h2 x w2, B x AnchorClass x h3 x w3] // AnchorClass = 4 * 16 + 80
-        return:
-            [B x HW x ClassBbox] // HW = h1*w1 + h2*w2 + h3*w3, ClassBox = 80 + 4 (xyXY)
-        """
-        preds = []
-        for pred in predicts:
-            preds.append(rearrange(pred, "B AC h w -> B (h w) AC"))  # B x AC x h x w -> B x hw x AC
-        preds = torch.concat(preds, dim=1)  # -> B x (H W) x AC
-        preds_anc, preds_cls = torch.split(preds, (self.reg_max * 4, self.num_classes), dim=-1)
-        preds_anc = rearrange(preds_anc, "B hw (P R) -> B hw P R", P=4)
-
-        pred_LTRB = preds_anc.softmax(dim=-1) @ self.reverse_reg * self.scaler.view(1, -1, 1)
-
-        lt, rb = pred_LTRB.chunk(2, dim=-1)
-        preds_box = torch.cat([self.anchors - lt, self.anchors + rb], dim=-1)
-        predicts = torch.cat([preds_cls, preds_box], dim=-1)
-        return predicts, preds_anc
+        reverse_reg = torch.arange(reg_max, dtype=torch.float32).view(1, reg_max, 1, 1, 1)
+        self.anc2vec = nn.Conv3d(in_channels=reg_max, out_channels=1, kernel_size=1, bias=False)
+        self.anc2vec.weight = nn.Parameter(reverse_reg, requires_grad=False)
+
+    def forward(self, anchor_x: Tensor) -> Tensor:
+        anchor_x = rearrange(anchor_x, "B (P R) h w -> B R P h w", P=4)
+        vector_x = anchor_x.softmax(dim=1)
+        vector_x = self.anc2vec(vector_x).squeeze(1)
+        return anchor_x, vector_x
 
 
 # ----------- Backbone Class ----------- #
```
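The key trick here: Anchor2Vec replaces the matrix product `preds_anc.softmax(dim=-1) @ self.reverse_reg` with a frozen Conv3d whose 1×1×1 kernel holds the bin indices 0…reg_max−1, so the distribution-to-distance expectation now runs per head inside the module graph. A minimal equivalence check (my own sketch, not part of the commit; shapes are made up):

```python
import torch
from einops import rearrange

reg_max = 16
anchor_x = torch.randn(2, 4 * reg_max, 20, 20)  # B x (4 * reg_max) x h x w

# Conv3d path, mirroring Anchor2Vec: the frozen weights are the bin indices.
conv = torch.nn.Conv3d(reg_max, 1, kernel_size=1, bias=False)
conv.weight = torch.nn.Parameter(
    torch.arange(reg_max, dtype=torch.float32).view(1, reg_max, 1, 1, 1),
    requires_grad=False,
)
x = rearrange(anchor_x, "B (P R) h w -> B R P h w", P=4)
vector_conv = conv(x.softmax(dim=1)).squeeze(1)  # B x 4 x h x w

# Matrix path, mirroring the removed Anchor2Box (before stride scaling).
probs = rearrange(anchor_x, "B (P R) h w -> B h w P R", P=4).softmax(dim=-1)
vector_mat = (probs @ torch.arange(reg_max, dtype=torch.float32)).permute(0, 3, 1, 2)

assert torch.allclose(vector_conv, vector_mat, atol=1e-5)
```

Both paths compute the expected bin index of each side's distance distribution; the Conv3d form simply keeps the operation inside the nn.Module graph.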
yolo/model/yolo.py CHANGED

```diff
@@ -66,7 +66,7 @@ class YOLO(nn.Module):
 
     def forward(self, x):
         y = {0: x}
-        output = []
+        output = dict()
         for index, layer in enumerate(self.model, start=1):
             if isinstance(layer.source, list):
                 model_input = [y[idx] for idx in layer.source]
@@ -77,7 +77,7 @@ class YOLO(nn.Module):
             if layer.usable:
                 y[index] = x
             if layer.output:
-                output.append(x)
+                output[layer.tags] = x
         return output
 
     def get_out_channels(self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list]):
@@ -131,7 +131,7 @@ def create_model(model_cfg: ModelConfig, class_num: int = 80, weight_path: str =
     logger.info("✅ Success load model")
     if weight_path:
         if os.path.exists(weight_path):
-            model.model.load_state_dict(torch.load(weight_path), ...)
+            model.model.load_state_dict(torch.load(weight_path))
             logger.info("✅ Success load model weight")
         else:
             logger.info(f"🌐 Weight {weight_path} not found, try downloading")
```
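Since `forward` now fills a dict keyed by each output layer's `tags`, downstream code selects branches by name rather than position. A toy illustration of why that matters (tag names come from v9-c.yaml; the tensors are stand-ins):

```python
import torch

# Old: consumers had to know each head's position in the output list.
# New: tag-keyed lookup stays stable even if branches are added or removed.
outputs_as_list = [torch.zeros(1), torch.ones(1)]                 # old: [aux, main]
outputs_as_dict = {"AUX": torch.zeros(1), "Main": torch.ones(1)}  # new

main_old = outputs_as_list[-1]      # breaks if output order changes
main_new = outputs_as_dict["Main"]  # robust: matches `tags: Main` in the YAML
assert torch.equal(main_old, main_new)
```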
yolo/tools/format_converters.py CHANGED

```diff
@@ -17,7 +17,7 @@ def convert_weight(old_state_dict, new_state_dict, model_size: int = 38):
             continue
         _, _, conv_name, conv_idx, *details = weight_name.split(".")
         if conv_name == "cv4" or conv_name == "cv5":
-            layer_idx = ...
+            layer_idx = 38
         else:
             layer_idx = 37
 
```
yolo/tools/loss_functions.py CHANGED

```diff
@@ -2,14 +2,12 @@ from typing import Any, Dict, List, Tuple
 
 import torch
 import torch.nn.functional as F
-from einops import rearrange
 from loguru import logger
 from torch import Tensor, nn
 from torch.nn import BCEWithLogitsLoss
 
-from yolo.config.config import Config
-from yolo.utils.bounding_box_utils import BoxMatcher, ...
-from yolo.utils.module_utils import divide_into_chunks
+from yolo.config.config import Config, LossConfig
+from yolo.utils.bounding_box_utils import BoxMatcher, Vec2Box, calculate_iou
 
 
 class BCELoss(nn.Module):
@@ -40,10 +38,9 @@ class BoxLoss(nn.Module):
 
 
 class DFLoss(nn.Module):
-    def __init__(self, ...):
+    def __init__(self, anchors_norm: Tensor, reg_max: int) -> None:
         super().__init__()
-        ...
-        self.scaler = scaler
+        self.anchors_norm = anchors_norm
         self.reg_max = reg_max
 
     def forward(
@@ -51,8 +48,9 @@ class DFLoss(nn.Module):
     ) -> Any:
         valid_bbox = valid_masks[..., None].expand(-1, -1, 4)
         bbox_lt, bbox_rb = targets_bbox.chunk(2, -1)
-        ...
+        targets_dist = torch.cat(((self.anchors_norm - bbox_lt), (bbox_rb - self.anchors_norm)), -1).clamp(
+            0, self.reg_max - 1.01
+        )
         picked_targets = targets_dist[valid_bbox].view(-1)
         picked_predict = predicts_anc[valid_bbox].view(-1, self.reg_max)
 
@@ -68,42 +66,31 @@ class DFLoss(nn.Module):
 
 
 class YOLOLoss:
-    def __init__(self, cfg: Config) -> None:
-        self.class_num = cfg.class_num
-        self.reg_max = cfg.model.anchor.reg_max
-        self.image_size = list(cfg.image_size)
-        self.strides = cfg.model.anchor.strides
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-        self.anchors, self.scaler = generate_anchors(self.image_size, self.strides)
-        self.anchors = self.anchors.to(device)
-        self.scaler = self.scaler.to(device)
+    def __init__(self, loss_cfg: LossConfig, vec2box: Vec2Box, class_num: int = 80, reg_max: int = 16) -> None:
+        self.class_num = class_num
+        self.vec2box = vec2box
 
         self.cls = BCELoss()
-        self.dfl = DFLoss(...)
+        self.dfl = DFLoss(vec2box.anchor_norm, reg_max)
         self.iou = BoxLoss()
 
-        self.matcher = BoxMatcher(...)
+        self.matcher = BoxMatcher(loss_cfg.matcher, self.class_num, vec2box.anchor_grid)
 
     def separate_anchor(self, anchors):
         """
         separate anchor and bbouding box
         """
         anchors_cls, anchors_box = torch.split(anchors, (self.class_num, 4), dim=-1)
-        anchors_box = anchors_box / self.scaler[None, :, None]
+        anchors_box = anchors_box / self.vec2box.scaler[None, :, None]
         return anchors_cls, anchors_box
 
-    def __call__(
-        ...
-    ) -> Tuple[Tensor, Tensor, Tensor]:
-        # Batch_Size x (Anchor + Class) x H x W
-        # TODO: check datatype, why targets has a little bit error with origin version
-
+    def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
+        predicts_cls, predicts_anc, predicts_box = predicts
         # For each predicted targets, assign a best suitable ground truth box.
-        align_targets, valid_masks = self.matcher(targets, predicts_box)
+        align_targets, valid_masks = self.matcher(targets, (predicts_cls, predicts_box))
 
         targets_cls, targets_bbox = self.separate_anchor(align_targets)
-        ...
+        predicts_box = predicts_box / self.vec2box.scaler[None, :, None]
 
         cls_norm = targets_cls.sum()
         box_norm = targets_cls.sum(-1)[valid_masks]
@@ -111,7 +98,7 @@ class YOLOLoss:
         ## -- CLS -- ##
         loss_cls = self.cls(predicts_cls, targets_cls, cls_norm)
         ## -- IOU -- ##
-        loss_iou = self.iou(...)
+        loss_iou = self.iou(predicts_box, targets_bbox, valid_masks, box_norm, cls_norm)
         ## -- DFL -- ##
         loss_dfl = self.dfl(predicts_anc, targets_bbox, valid_masks, box_norm, cls_norm)
 
@@ -119,19 +106,22 @@ class YOLOLoss:
 
 
 class DualLoss:
-    def __init__(self, cfg: Config) -> None:
-        self.loss = YOLOLoss(cfg)
-        self.aux_rate = cfg.task.loss.aux
+    def __init__(self, cfg: Config, vec2box) -> None:
+        loss_cfg = cfg.task.loss
+        self.loss = YOLOLoss(loss_cfg, vec2box, class_num=cfg.class_num, reg_max=cfg.model.anchor.reg_max)
+
+        self.aux_rate = loss_cfg.aux
 
-        self.iou_rate = cfg.task.loss.objective["BoxLoss"]
-        self.dfl_rate = cfg.task.loss.objective["DFLoss"]
-        self.cls_rate = cfg.task.loss.objective["BCELoss"]
+        self.iou_rate = loss_cfg.objective["BoxLoss"]
+        self.dfl_rate = loss_cfg.objective["DFLoss"]
+        self.cls_rate = loss_cfg.objective["BCELoss"]
 
-    def __call__(self, ...):
+    def __call__(
+        self, aux_predicts: List[Tensor], main_predicts: List[Tensor], targets: Tensor
+    ) -> Tuple[Tensor, Dict[str, Tensor]]:
         # TODO: Need Refactor this region, make it flexible!
-        aux_iou, aux_dfl, aux_cls = self.loss(...)
-        main_iou, main_dfl, main_cls = self.loss(...)
+        aux_iou, aux_dfl, aux_cls = self.loss(aux_predicts, targets)
+        main_iou, main_dfl, main_cls = self.loss(main_predicts, targets)
 
         loss_dict = {
             "BoxLoss": self.iou_rate * (aux_iou * self.aux_rate + main_iou),
@@ -142,7 +132,7 @@ class DualLoss:
         return loss_sum, loss_dict
 
 
-def get_loss_function(cfg: Config) -> DualLoss:
-    loss_function = DualLoss(cfg)
+def get_loss_function(cfg: Config, vec2box) -> DualLoss:
+    loss_function = DualLoss(cfg, vec2box)
     logger.info("✅ Success load loss function")
     return loss_function
```
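For reference, DFLoss's new target construction measures left-top / right-bottom distances from each pre-normalized anchor point and clamps them to `reg_max - 1.01`, keeping every target strictly below the last bin so the distribution-focal cross-entropy always has a valid upper neighbor. A toy reproduction with made-up numbers:

```python
import torch

reg_max = 16
anchors_norm = torch.tensor([[[5.0, 5.0]]])             # one anchor at (5, 5), stride units
targets_bbox = torch.tensor([[[2.0, 1.0, 9.0, 30.0]]])  # matched box, xyXY, stride units

bbox_lt, bbox_rb = targets_bbox.chunk(2, -1)
targets_dist = torch.cat(
    ((anchors_norm - bbox_lt), (bbox_rb - anchors_norm)), -1
).clamp(0, reg_max - 1.01)
print(targets_dist)  # [[[3.00, 4.00, 4.00, 14.99]]] -- bottom distance 25 clamped
```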
yolo/tools/solver.py CHANGED

```diff
@@ -10,7 +10,7 @@ from yolo.model.yolo import YOLO
 from yolo.tools.data_loader import StreamDataLoader, create_dataloader
 from yolo.tools.drawer import draw_bboxes
 from yolo.tools.loss_functions import get_loss_function
-from yolo.utils.bounding_box_utils import bbox_nms, calculate_map
+from yolo.utils.bounding_box_utils import Vec2Box, bbox_nms, calculate_map
 from yolo.utils.logging_utils import ProgressTracker
 from yolo.utils.model_utils import (
     ExponentialMovingAverage,
@@ -20,13 +20,14 @@ from yolo.utils.model_utils import (
 
 
 class ModelTrainer:
-    def __init__(self, cfg: Config, model: YOLO, save_path: str, device):
+    def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, save_path: str, device):
         train_cfg: TrainConfig = cfg.task
         self.model = model
+        self.vec2box = vec2box
         self.device = device
         self.optimizer = create_optimizer(model, train_cfg.optimizer)
         self.scheduler = create_scheduler(self.optimizer, train_cfg.scheduler)
-        self.loss_fn = get_loss_function(cfg)
+        self.loss_fn = get_loss_function(cfg, vec2box)
         self.progress = ProgressTracker(cfg.name, save_path, cfg.use_wandb)
         self.num_epochs = cfg.task.epoch
 
@@ -45,7 +46,9 @@ class ModelTrainer:
 
         with autocast():
             outputs = self.model(data)
-            loss, loss_item = self.loss_fn(outputs, targets)
+            aux_predicts = self.vec2box(outputs["AUX"])
+            main_predicts = self.vec2box(outputs["Main"])
+            loss, loss_item = self.loss_fn(aux_predicts, main_predicts, targets)
 
         self.scaler.scale(loss).backward()
         self.scaler.step(self.optimizer)
@@ -96,9 +99,10 @@ class ModelTrainer:
 
 
 class ModelTester:
-    def __init__(self, cfg: Config, model: YOLO, save_path: str, device):
+    def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, save_path: str, device):
         self.model = model
         self.device = device
+        self.vec2box = vec2box
         self.progress = ProgressTracker(cfg, save_path, cfg.use_wandb)
 
         self.nms = cfg.task.nms
@@ -112,8 +116,9 @@ class ModelTester:
         for idx, images in enumerate(dataloader):
             images = images.to(self.device)
             with torch.no_grad():
-                predicts = self.model(images)
-                nms_out = bbox_nms(predicts, self.nms)
+                outputs = self.model(images)
+                outputs = self.vec2box(outputs["Main"])
+                nms_out = bbox_nms(outputs[0], outputs[2], self.nms)
                 draw_bboxes(
                     images[0],
                     nms_out[0],
```
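Note the tuple order coming out of `Vec2Box.__call__`: index 0 is class logits, index 1 the raw distance distributions, index 2 the decoded xyXY boxes, which is why the tester calls `bbox_nms(outputs[0], outputs[2], self.nms)`. A shape-level sketch (sizes are my own hypotheticals, e.g. 8400 anchors for a 640×640 input):

```python
import torch

B, HW, C, reg_max = 2, 8400, 80, 16
preds_cls = torch.randn(B, HW, C)           # outputs[0]: per-anchor class logits
preds_anc = torch.randn(B, HW, reg_max, 4)  # outputs[1]: distance distributions (loss only)
preds_box = torch.rand(B, HW, 4) * 640      # outputs[2]: decoded xyXY boxes

# bbox_nms now consumes logits and boxes as two tensors instead of splitting
# one concatenated [class | box] tensor:
assert preds_cls.shape[:2] == preds_box.shape[:2]
```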
yolo/utils/bounding_box_utils.py CHANGED

```diff
@@ -106,12 +106,23 @@ def transform_bbox(bbox: Tensor, indicator="xywh -> xyxy"):
     return bbox.to(dtype=data_type)
 
 
-def generate_anchors(image_size: List[int], strides: List[int]):
+def generate_anchors(image_size: List[int], anchors_list: List[Tuple[int]]):
+    """
+    Find the anchor maps for each w, h.
+
+    Args:
+        anchors_list List[[w1, h1], [w2, h2], ...]: the anchor num for each predicted anchor
+
+    Returns:
+        all_anchors [HW x 2]:
+        all_scalers [HW]: The index of the best targets for each anchors
+    """
     W, H = image_size
     anchors = []
     scaler = []
-    for stride in strides:
-        anchor_num = (W // stride) * (H // stride)
+    for anchor_wh in anchors_list:
+        stride = W // anchor_wh[0]
+        anchor_num = anchor_wh[0] * anchor_wh[1]
         scaler.append(torch.full((anchor_num,), stride))
         shift = stride // 2
         x = torch.arange(0, W, stride) + shift
@@ -207,13 +218,13 @@ class BoxMatcher:
         unique_indices = target_matrix.argmax(dim=1)
         return unique_indices[..., None]
 
-    def __call__(self, target: Tensor, predict: Tensor) -> Tuple[Tensor, Tensor]:
+    def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tensor]:
         """
         1. For each anchor prediction, find the highest suitability targets
         2. Select the targets
         2. Noramlize the class probilities of targets
         """
         predict_cls, predict_bbox = predict
         target_cls, target_bbox = target.split([1, 4], dim=-1)  # B x N x (C B) -> B x N x C, B x N x B
         target_cls = target_cls.long().clamp(0)
 
@@ -251,9 +262,37 @@ class BoxMatcher:
         return torch.cat([align_cls, align_bbox], dim=-1), valid_mask.bool()
 
 
-def bbox_nms(predicts: Tensor, nms_cfg: NMSConfig):
+class Vec2Box:
+    def __init__(self, model, image_size, device):
+        dummy_input = torch.zeros(1, 3, *image_size).to(device)
+        dummy_output = model(dummy_input)
+        anchors_num = []
+        for predict_head in dummy_output["Main"]:
+            _, _, *anchor_num = predict_head[2].shape
+            anchors_num.append(anchor_num)
+        anchor_grid, scaler = generate_anchors(image_size, anchors_num)
+        self.anchor_grid, self.scaler = anchor_grid.to(device), scaler.to(device)
+        self.anchor_norm = (anchor_grid / scaler[:, None])[None].to(device)
+
+    def __call__(self, predicts):
+        preds_cls, preds_anc, preds_box = [], [], []
+        for layer_output in predicts:
+            pred_cls, pred_anc, pred_box = layer_output
+            preds_cls.append(rearrange(pred_cls, "B C h w -> B (h w) C"))
+            preds_anc.append(rearrange(pred_anc, "B A R h w -> B (h w) R A"))
+            preds_box.append(rearrange(pred_box, "B X h w -> B (h w) X"))
+        preds_cls = torch.concat(preds_cls, dim=1)
+        preds_anc = torch.concat(preds_anc, dim=1)
+        preds_box = torch.concat(preds_box, dim=1)
+
+        pred_LTRB = preds_box * self.scaler.view(1, -1, 1)
+        lt, rb = pred_LTRB.chunk(2, dim=-1)
+        preds_box = torch.cat([self.anchor_grid - lt, self.anchor_grid + rb], dim=-1)
+        return preds_cls, preds_anc, preds_box
+
+
+def bbox_nms(cls_dist: Tensor, bbox: Tensor, nms_cfg: NMSConfig):
     # TODO change function to class or set 80 to class_num instead of a number
-    cls_dist, bbox = torch.split(predicts, [80, 4], dim=-1)
     cls_dist = cls_dist.sigmoid()
 
     # filter class by confidence
@@ -266,7 +305,7 @@ def bbox_nms(predicts: Tensor, nms_cfg: NMSConfig):
     batch_idx, *_ = torch.where(valid_mask)
     nms_idx = batched_nms(valid_box, valid_cls, batch_idx, nms_cfg.min_iou)
     predicts_nms = []
-    for idx in range(predicts.size(0)):
+    for idx in range(cls_dist.size(0)):
         instance_idx = nms_idx[idx == batch_idx[nms_idx]]
 
         predict_nms = torch.cat(
```
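The decode itself is now a plain rescale-and-offset, since the softmax expectation already happened inside Anchor2Vec. A one-anchor worked example of the last three lines of `Vec2Box.__call__` (numbers invented):

```python
import torch

anchor_grid = torch.tensor([[4.0, 4.0]])            # anchor center, pixels
scaler = torch.tensor([8.0])                        # stride of that anchor's head
preds_box = torch.tensor([[[1.0, 1.0, 2.0, 2.0]]])  # B x HW x LTRB, stride units

pred_LTRB = preds_box * scaler.view(1, -1, 1)       # back to pixels: (8, 8, 16, 16)
lt, rb = pred_LTRB.chunk(2, dim=-1)
box = torch.cat([anchor_grid - lt, anchor_grid + rb], dim=-1)
print(box)  # [[[-4., -4., 20., 20.]]] -- xyXY in image coordinates
```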