henry000 committed on
Commit
475302b
·
1 Parent(s): 3441a79

🔨 [Add] the PostProccess class to convert predicts

Browse files
Files changed (2) hide show
  1. yolo/config/config.py +85 -0
  2. yolo/utils/model_utils.py +42 -4
yolo/config/config.py CHANGED
@@ -142,6 +142,7 @@ class Config:
142
 
143
  class_num: int
144
  class_list: List[str]
 
145
  image_size: List[int]
146
 
147
  out_path: str
@@ -164,3 +165,87 @@ class YOLOLayer(nn.Module):
164
 
165
  def __post_init__(self):
166
  super().__init__()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
  class_num: int
144
  class_list: List[str]
145
+ class_idx_id: List[int]
146
  image_size: List[int]
147
 
148
  out_path: str
 
165
 
166
  def __post_init__(self):
167
  super().__init__()
168
+
169
+
170
# Lookup table from class index (list position) to category id: entry ``i`` is
# the id reported for class index ``i``. The ids are 1-based and skip some
# numbers, so the index cannot be used as the id directly.
# NOTE(review): the 80 values look like the COCO detection category ids —
# confirm against the dataset annotations before relying on that.
IDX_TO_ID = [
    1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
    11, 13, 14, 15, 16, 17, 18, 19, 20, 21,
    22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
    35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
    46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
    56, 57, 58, 59, 60, 61, 62, 63, 64, 65,
    67, 70, 72, 73, 74, 75, 76, 77, 78, 79,
    80, 81, 82, 84, 85, 86, 87, 88, 89, 90,
]  # fmt: skip
yolo/utils/model_utils.py CHANGED
@@ -1,17 +1,18 @@
1
  import os
2
- from typing import List, Type, Union
 
3
 
4
  import torch
5
  import torch.distributed as dist
6
  from loguru import logger
7
  from omegaconf import ListConfig
8
- from torch import nn
9
- from torch.nn.parallel import DistributedDataParallel as DDP
10
  from torch.optim import Optimizer
11
  from torch.optim.lr_scheduler import LambdaLR, SequentialLR, _LRScheduler
12
 
13
- from yolo.config.config import OptimizerConfig, SchedulerConfig
14
  from yolo.model.yolo import YOLO
 
15
 
16
 
17
  class ExponentialMovingAverage:
@@ -93,3 +94,40 @@ def get_device(device_spec: Union[str, int, List[int]]) -> torch.device:
93
  device_spec = initialize_distributed()
94
  device = torch.device(device_spec)
95
  return device, ddp_flag
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ from pathlib import Path
3
+ from typing import List, Optional, Type, Union
4
 
5
  import torch
6
  import torch.distributed as dist
7
  from loguru import logger
8
  from omegaconf import ListConfig
9
+ from torch import Tensor
 
10
  from torch.optim import Optimizer
11
  from torch.optim.lr_scheduler import LambdaLR, SequentialLR, _LRScheduler
12
 
13
+ from yolo.config.config import IDX_TO_ID, NMSConfig, OptimizerConfig, SchedulerConfig
14
  from yolo.model.yolo import YOLO
15
+ from yolo.utils.bounding_box_utils import bbox_nms, transform_bbox
16
 
17
 
18
  class ExponentialMovingAverage:
 
94
  device_spec = initialize_distributed()
95
  device = torch.device(device_spec)
96
  return device, ddp_flag
97
+
98
+
99
class PostProccess:
    """Decode a raw prediction, optionally undo the input transform, and run NMS.

    Bundles a ``vec2box`` converter with an NMS configuration so both steps
    can be applied to a model output with a single call.
    """

    def __init__(self, vec2box, nms_cfg: NMSConfig) -> None:
        """Store the converter and the NMS settings.

        Args:
            vec2box: Callable applied to ``predict["Main"]``. It must return a
                3-tuple; the first item is used as the class predictions and
                the last item as the boxes (the middle item is ignored).
            nms_cfg: Settings forwarded unchanged to ``bbox_nms``.
        """
        self.vec2box = vec2box
        self.nms = nms_cfg

    def __call__(self, predict, rev_tensor: Optional[Tensor] = None):
        """Convert ``predict`` into NMS-filtered boxes.

        Args:
            predict: Mapping that holds the model output under the ``"Main"``
                key.
            rev_tensor: Optional 2-D tensor with one row per image. Column 0
                is the value the boxes are divided by and the remaining
                columns are subtracted from the box coordinates first.
                NOTE(review): presumably the resize scale and padding applied
                during preprocessing — confirm against the data pipeline.
                When ``None`` (the default) the boxes are passed to NMS
                exactly as ``vec2box`` produced them.

        Returns:
            Whatever ``bbox_nms`` returns for the batch.
        """
        pred_class, _, pred_bbox = self.vec2box(predict["Main"])
        if rev_tensor is not None:
            # Insert singleton axes so the per-image offsets and divisor
            # broadcast across every box of that image.
            # assumes pred_bbox is (batch, boxes, coords) — TODO confirm
            offset = rev_tensor[:, None, 1:]
            scale = rev_tensor[:, 0:1, None]
            pred_bbox = (pred_bbox - offset) / scale
        return bbox_nms(pred_class, pred_bbox, self.nms)
115
+
116
+
117
def predicts_to_json(img_paths, predicts):
    """Flatten a batch of per-image detections into a list of detection dicts.

    Args:
        img_paths: Image paths, paired positionally with ``predicts``. The
            file stem of each path is parsed with ``int`` and used as the
            ``image_id``.
        predicts: One 2-D tensor per image. Each row is read as
            ``(class index, 4 box values in xyxy order, confidence)``.
            The tensors are not modified.

    Returns:
        A list with one dict per detection, holding ``image_id``,
        ``category_id`` (looked up in ``IDX_TO_ID``), ``bbox`` (the four box
        values after the ``"xyxy -> xywh"`` conversion, as floats) and
        ``score``.

    Raises:
        ValueError: If an image that has detections has a non-integer stem.
        IndexError: If a class index falls outside ``IDX_TO_ID``.
    """
    batch_json = []
    for img_path, bboxes in zip(img_paths, predicts):
        # Convert on a copy: assigning into the slice of the original tensor
        # would leave the caller's predictions in xywh, and a second call
        # would then convert them again.
        bboxes = bboxes.clone()
        bboxes[:, 1:5] = transform_bbox(bboxes[:, 1:5], "xyxy -> xywh")
        for cls, *pos, conf in bboxes:
            bbox = {
                "image_id": int(Path(img_path).stem),
                "category_id": IDX_TO_ID[int(cls)],
                "bbox": [float(p) for p in pos],
                "score": float(conf),
            }
            batch_json.append(bbox)
    return batch_json