π [Move] parse_predict to Anchor2Box class
Browse files- yolo/tools/bbox_helper.py +42 -1
- yolo/utils/loss.py +4 -44
yolo/tools/bbox_helper.py
CHANGED
@@ -3,9 +3,10 @@ from typing import List, Tuple
|
|
3 |
|
4 |
import torch
|
5 |
import torch.nn.functional as F
|
|
|
6 |
from torch import Tensor
|
7 |
|
8 |
-
from yolo.config.config import MatcherConfig
|
9 |
|
10 |
|
11 |
def calculate_iou(bbox1, bbox2, metrics="iou") -> Tensor:
|
@@ -122,6 +123,46 @@ def make_anchor(image_size: List[int], strides: List[int], device):
|
|
122 |
return all_anchors, all_scalers
|
123 |
|
124 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
class BoxMatcher:
|
126 |
def __init__(self, cfg: MatcherConfig, class_num: int, anchors: Tensor) -> None:
|
127 |
self.class_num = class_num
|
|
|
3 |
|
4 |
import torch
|
5 |
import torch.nn.functional as F
|
6 |
+
from einops import rearrange
|
7 |
from torch import Tensor
|
8 |
|
9 |
+
from yolo.config.config import Config, MatcherConfig
|
10 |
|
11 |
|
12 |
def calculate_iou(bbox1, bbox2, metrics="iou") -> Tensor:
|
|
|
123 |
return all_anchors, all_scalers
|
124 |
|
125 |
|
126 |
+
class Anchor2Box:
|
127 |
+
def __init__(self, cfg: Config, device: torch.device) -> None:
|
128 |
+
self.reg_max = cfg.model.anchor.reg_max
|
129 |
+
self.class_num = cfg.hyper.data.class_num
|
130 |
+
self.image_size = list(cfg.hyper.data.image_size)
|
131 |
+
self.strides = cfg.model.anchor.strides
|
132 |
+
|
133 |
+
self.scale_up = torch.tensor(self.image_size * 2, device=device)
|
134 |
+
self.anchors, self.scaler = make_anchor(self.image_size, self.strides, device)
|
135 |
+
self.reverse_reg = torch.arange(self.reg_max, dtype=torch.float32, device=device)
|
136 |
+
|
137 |
+
def __call__(self, predicts: List[Tensor], with_logits=False) -> Tensor:
|
138 |
+
"""
|
139 |
+
args:
|
140 |
+
[B x AnchorClass x h1 x w1, B x AnchorClass x h2 x w2, B x AnchorClass x h3 x w3] // AnchorClass = 4 * 16 + 80
|
141 |
+
return:
|
142 |
+
[B x HW x ClassBbox] // HW = h1*w1 + h2*w2 + h3*w3, ClassBox = 80 + 4 (xyXY)
|
143 |
+
"""
|
144 |
+
preds = []
|
145 |
+
for pred in predicts:
|
146 |
+
preds.append(rearrange(pred, "B AC h w -> B (h w) AC")) # B x AC x h x w-> B x hw x AC
|
147 |
+
preds = torch.concat(preds, dim=1) # -> B x (H W) x AC
|
148 |
+
|
149 |
+
preds_anc, preds_cls = torch.split(preds, (self.reg_max * 4, self.class_num), dim=-1)
|
150 |
+
preds_anc = rearrange(preds_anc, "B hw (P R)-> B hw P R", P=4)
|
151 |
+
if with_logits:
|
152 |
+
preds_cls = preds_cls.sigmoid()
|
153 |
+
|
154 |
+
pred_LTRB = preds_anc.softmax(dim=-1) @ self.reverse_reg * self.scaler.view(1, -1, 1)
|
155 |
+
|
156 |
+
lt, rb = pred_LTRB.chunk(2, dim=-1)
|
157 |
+
pred_minXY = self.anchors - lt
|
158 |
+
pred_maxXY = self.anchors + rb
|
159 |
+
preds_box = torch.cat([pred_minXY, pred_maxXY], dim=-1)
|
160 |
+
|
161 |
+
predicts = torch.cat([preds_cls, preds_box], dim=-1)
|
162 |
+
|
163 |
+
return predicts, preds_anc
|
164 |
+
|
165 |
+
|
166 |
class BoxMatcher:
|
167 |
def __init__(self, cfg: MatcherConfig, class_num: int, anchors: Tensor) -> None:
|
168 |
self.class_num = class_num
|
yolo/utils/loss.py
CHANGED
@@ -8,12 +8,7 @@ from torch import Tensor, nn
|
|
8 |
from torch.nn import BCEWithLogitsLoss
|
9 |
|
10 |
from yolo.config.config import Config
|
11 |
-
from yolo.tools.bbox_helper import
|
12 |
-
BoxMatcher,
|
13 |
-
calculate_iou,
|
14 |
-
make_anchor,
|
15 |
-
transform_bbox,
|
16 |
-
)
|
17 |
from yolo.tools.module_helper import make_chunk
|
18 |
|
19 |
|
@@ -90,42 +85,7 @@ class YOLOLoss:
|
|
90 |
self.iou = BoxLoss()
|
91 |
|
92 |
self.matcher = BoxMatcher(cfg.hyper.train.loss.matcher, self.class_num, self.anchors)
|
93 |
-
|
94 |
-
def parse_predicts(self, predicts: List[Tensor]) -> Tensor:
|
95 |
-
"""
|
96 |
-
args:
|
97 |
-
[B x AnchorClass x h1 x w1, B x AnchorClass x h2 x w2, B x AnchorClass x h3 x w3] // AnchorClass = 4 * 16 + 80
|
98 |
-
return:
|
99 |
-
[B x HW x ClassBbox] // HW = h1*w1 + h2*w2 + h3*w3, ClassBox = 80 + 4 (xyXY)
|
100 |
-
"""
|
101 |
-
preds = []
|
102 |
-
for pred in predicts:
|
103 |
-
preds.append(rearrange(pred, "B AC h w -> B (h w) AC")) # B x AC x h x w-> B x hw x AC
|
104 |
-
preds = torch.concat(preds, dim=1) # -> B x (H W) x AC
|
105 |
-
|
106 |
-
preds_anc, preds_cls = torch.split(preds, (self.reg_max * 4, self.class_num), dim=-1)
|
107 |
-
preds_anc = rearrange(preds_anc, "B hw (P R)-> B hw P R", P=4)
|
108 |
-
|
109 |
-
pred_LTRB = preds_anc.softmax(dim=-1) @ self.reverse_reg * self.scaler.view(1, -1, 1)
|
110 |
-
|
111 |
-
lt, rb = pred_LTRB.chunk(2, dim=-1)
|
112 |
-
pred_minXY = self.anchors - lt
|
113 |
-
pred_maxXY = self.anchors + rb
|
114 |
-
predicts = torch.cat([preds_cls, pred_minXY, pred_maxXY], dim=-1)
|
115 |
-
|
116 |
-
return predicts, preds_anc
|
117 |
-
|
118 |
-
def parse_targets(self, targets: Tensor, batch_size: int = 16) -> List[Tensor]:
|
119 |
-
"""
|
120 |
-
return List:
|
121 |
-
"""
|
122 |
-
targets[:, 2:] = transform_bbox(targets[:, 2:], "xycwh -> xyxy") * self.scale_up
|
123 |
-
bbox_num = targets[:, 0].int().bincount()
|
124 |
-
batch_targets = torch.zeros(batch_size, bbox_num.max(), 5, device=targets.device)
|
125 |
-
for instance_idx, bbox_num in enumerate(bbox_num):
|
126 |
-
instance_targets = targets[targets[:, 0] == instance_idx]
|
127 |
-
batch_targets[instance_idx, :bbox_num] = instance_targets[:, 1:].detach()
|
128 |
-
return batch_targets
|
129 |
|
130 |
def separate_anchor(self, anchors):
|
131 |
"""
|
@@ -138,10 +98,10 @@ class YOLOLoss:
|
|
138 |
def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
|
139 |
# Batch_Size x (Anchor + Class) x H x W
|
140 |
# TODO: check datatype, why targets has a little bit error with origin version
|
141 |
-
predicts, predicts_anc = self.
|
142 |
|
|
|
143 |
align_targets, valid_masks = self.matcher(targets, predicts)
|
144 |
-
# calculate loss between with instance and predict
|
145 |
|
146 |
targets_cls, targets_bbox = self.separate_anchor(align_targets)
|
147 |
predicts_cls, predicts_bbox = self.separate_anchor(predicts)
|
|
|
8 |
from torch.nn import BCEWithLogitsLoss
|
9 |
|
10 |
from yolo.config.config import Config
|
11 |
+
from yolo.tools.bbox_helper import Anchor2Box, BoxMatcher, calculate_iou, make_anchor
|
|
|
|
|
|
|
|
|
|
|
12 |
from yolo.tools.module_helper import make_chunk
|
13 |
|
14 |
|
|
|
85 |
self.iou = BoxLoss()
|
86 |
|
87 |
self.matcher = BoxMatcher(cfg.hyper.train.loss.matcher, self.class_num, self.anchors)
|
88 |
+
self.box_converter = Anchor2Box(cfg, device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
|
90 |
def separate_anchor(self, anchors):
|
91 |
"""
|
|
|
98 |
def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
|
99 |
# Batch_Size x (Anchor + Class) x H x W
|
100 |
# TODO: check datatype, why targets has a little bit error with origin version
|
101 |
+
predicts, predicts_anc = self.box_converter(predicts)
|
102 |
|
103 |
+
# For each predicted targets, assign a best suitable ground truth box.
|
104 |
align_targets, valid_masks = self.matcher(targets, predicts)
|
|
|
105 |
|
106 |
targets_cls, targets_bbox = self.separate_anchor(align_targets)
|
107 |
predicts_cls, predicts_bbox = self.separate_anchor(predicts)
|