🔨 [Add] the PostProccess class to convert predicts
Browse files- yolo/config/config.py +85 -0
- yolo/utils/model_utils.py +42 -4
yolo/config/config.py
CHANGED
@@ -142,6 +142,7 @@ class Config:
|
|
142 |
|
143 |
class_num: int
|
144 |
class_list: List[str]
|
|
|
145 |
image_size: List[int]
|
146 |
|
147 |
out_path: str
|
@@ -164,3 +165,87 @@ class YOLOLayer(nn.Module):
|
|
164 |
|
165 |
def __post_init__(self):
|
166 |
super().__init__()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
|
143 |
class_num: int
|
144 |
class_list: List[str]
|
145 |
+
class_idx_id: List[int]
|
146 |
image_size: List[int]
|
147 |
|
148 |
out_path: str
|
|
|
165 |
|
166 |
def __post_init__(self):
|
167 |
super().__init__()
|
168 |
+
|
169 |
+
|
170 |
+
# Lookup table: position i holds the category id reported for class index i.
# The ids run from 1 to 90 with gaps, leaving 80 entries in total.
# NOTE(review): the values look like the COCO detection category ids - confirm.
IDX_TO_ID = [
    *range(1, 12),  # 1-11
    *range(13, 26),  # 13-25
    27,
    28,
    *range(31, 45),  # 31-44
    *range(46, 66),  # 46-65
    67,
    70,
    *range(72, 83),  # 72-82
    *range(84, 91),  # 84-90
]
|
yolo/utils/model_utils.py
CHANGED
@@ -1,17 +1,18 @@
|
|
1 |
import os
|
2 |
-
from
|
|
|
3 |
|
4 |
import torch
|
5 |
import torch.distributed as dist
|
6 |
from loguru import logger
|
7 |
from omegaconf import ListConfig
|
8 |
-
from torch import
|
9 |
-
from torch.nn.parallel import DistributedDataParallel as DDP
|
10 |
from torch.optim import Optimizer
|
11 |
from torch.optim.lr_scheduler import LambdaLR, SequentialLR, _LRScheduler
|
12 |
|
13 |
-
from yolo.config.config import OptimizerConfig, SchedulerConfig
|
14 |
from yolo.model.yolo import YOLO
|
|
|
15 |
|
16 |
|
17 |
class ExponentialMovingAverage:
|
@@ -93,3 +94,40 @@ def get_device(device_spec: Union[str, int, List[int]]) -> torch.device:
|
|
93 |
device_spec = initialize_distributed()
|
94 |
device = torch.device(device_spec)
|
95 |
return device, ddp_flag
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
+
from pathlib import Path
|
3 |
+
from typing import List, Optional, Type, Union
|
4 |
|
5 |
import torch
|
6 |
import torch.distributed as dist
|
7 |
from loguru import logger
|
8 |
from omegaconf import ListConfig
|
9 |
+
from torch import Tensor
|
|
|
10 |
from torch.optim import Optimizer
|
11 |
from torch.optim.lr_scheduler import LambdaLR, SequentialLR, _LRScheduler
|
12 |
|
13 |
+
from yolo.config.config import IDX_TO_ID, NMSConfig, OptimizerConfig, SchedulerConfig
|
14 |
from yolo.model.yolo import YOLO
|
15 |
+
from yolo.utils.bounding_box_utils import bbox_nms, transform_bbox
|
16 |
|
17 |
|
18 |
class ExponentialMovingAverage:
|
|
|
94 |
device_spec = initialize_distributed()
|
95 |
device = torch.device(device_spec)
|
96 |
return device, ddp_flag
|
97 |
+
|
98 |
+
|
99 |
+
class PostProccess:
    """Turn raw model output into NMS-filtered box predictions.

    Calling an instance decodes the ``"Main"`` output with ``vec2box``,
    optionally maps the decoded boxes back through ``rev_tensor`` and then
    hands the result to ``bbox_nms`` together with the stored NMS settings.
    """

    def __init__(self, vec2box, nms_cfg: NMSConfig) -> None:
        """Keep the decoder callable and the NMS configuration.

        Args:
            vec2box: Callable applied to ``predict["Main"]``. It must return
                three values; the first (class predictions) and the third
                (boxes) are used, the second is ignored.
            nms_cfg: Settings object forwarded unchanged to ``bbox_nms``.
        """
        self.vec2box = vec2box
        self.nms = nms_cfg

    def __call__(self, predict, rev_tensor: Optional[Tensor] = None):
        """Decode ``predict``, undo the input transform and apply NMS.

        Args:
            predict: Mapping holding the model output under the key ``"Main"``.
            rev_tensor: Optional 2-D tensor with one row per image. Column 0 is
                used as a divisor and the remaining columns as offsets that are
                subtracted from the boxes first. When omitted or ``None`` the
                boxes are passed to NMS untouched.

        Returns:
            Whatever ``bbox_nms`` returns for the (rescaled) predictions.
        """
        pred_class, _, pred_bbox = self.vec2box(predict["Main"])
        if rev_tensor is not None:
            # Extra singleton axes let one row of rev_tensor broadcast over all
            # boxes of its image.
            # NOTE(review): assumes pred_bbox is (batch, boxes, coords) - confirm.
            shift = rev_tensor[:, None, 1:]
            scale = rev_tensor[:, 0:1, None]
            pred_bbox = (pred_bbox - shift) / scale
        return bbox_nms(pred_class, pred_bbox, self.nms)
|
115 |
+
|
116 |
+
|
117 |
+
def predicts_to_json(img_paths, predicts):
    """Flatten a batch of per-image predictions into a list of detection dicts.

    Args:
        img_paths: One path per image. The file stem must be an integer; it
            becomes the ``image_id`` of every detection of that image.
        predicts: One 2-D tensor per image, aligned with ``img_paths``. Each
            row is unpacked as ``(class index, *box, confidence)`` and columns
            1-4 are converted with ``transform_bbox(..., "xyxy -> xywh")``.
            The caller's tensors are left unmodified.

    Returns:
        A list with one dict per box, holding the keys ``image_id``,
        ``category_id`` (looked up in ``IDX_TO_ID``), ``bbox`` and ``score``.

    Raises:
        ValueError: If an image that has at least one box has a non-integer
            file stem.
    """
    batch_json = []
    for img_path, bboxes in zip(img_paths, predicts):
        if len(bboxes) == 0:
            # Nothing to report; also keeps the file name unparsed for images
            # without detections, as before.
            continue
        image_id = int(Path(img_path).stem)
        # Convert on a copy: assigning into the slice would otherwise rewrite
        # the caller's predictions in place.
        bboxes = bboxes.clone()
        bboxes[:, 1:5] = transform_bbox(bboxes[:, 1:5], "xyxy -> xywh")
        for cls, *pos, conf in bboxes:
            batch_json.append(
                {
                    "image_id": image_id,
                    "category_id": IDX_TO_ID[int(cls)],
                    "bbox": [float(p) for p in pos],
                    "score": float(conf),
                }
            )
    return batch_json
|