henry000 committed on
Commit
c0153fd
·
2 Parent(s): 421f7e2 010502a

🔀 [Merge] branch 'main' into DATASET

Browse files
demo/hf_demo.py CHANGED
@@ -10,8 +10,8 @@ sys.path.append(str(Path(__file__).resolve().parent.parent))
10
  from yolo import (
11
  AugmentationComposer,
12
  NMSConfig,
 
13
  Vec2Box,
14
- bbox_nms,
15
  create_model,
16
  draw_bboxes,
17
  )
@@ -37,7 +37,7 @@ transform = AugmentationComposer([])
37
 
38
 
39
  def predict(model_name, image, nms_confidence, nms_iou):
40
- global DEFAULT_MODEL, model, device, v2b, class_list
41
  if model_name != DEFAULT_MODEL:
42
  model = load_model(model_name, device)
43
  v2b = Vec2Box(model, IMAGE_SIZE, device)
@@ -46,16 +46,15 @@ def predict(model_name, image, nms_confidence, nms_iou):
46
  image_tensor, _, rev_tensor = transform(image)
47
 
48
  image_tensor = image_tensor.to(device)[None]
49
- rev_tensor = rev_tensor.to(device)
 
 
 
50
 
51
  with torch.no_grad():
52
  predict = model(image_tensor)
53
- pred_class, _, pred_bbox = v2b(predict["Main"])
54
-
55
- nms_config = NMSConfig(nms_confidence, nms_iou)
56
 
57
- pred_bbox = pred_bbox / rev_tensor[0] - rev_tensor[None, None, 1:]
58
- pred_bbox = bbox_nms(pred_class, pred_bbox, nms_config)
59
  result_image = draw_bboxes(image, pred_bbox, idx2label=class_list)
60
 
61
  return result_image
 
10
  from yolo import (
11
  AugmentationComposer,
12
  NMSConfig,
13
+ PostProccess,
14
  Vec2Box,
 
15
  create_model,
16
  draw_bboxes,
17
  )
 
37
 
38
 
39
  def predict(model_name, image, nms_confidence, nms_iou):
40
+ global DEFAULT_MODEL, model, device, v2b, class_list, post_proccess
41
  if model_name != DEFAULT_MODEL:
42
  model = load_model(model_name, device)
43
  v2b = Vec2Box(model, IMAGE_SIZE, device)
 
46
  image_tensor, _, rev_tensor = transform(image)
47
 
48
  image_tensor = image_tensor.to(device)[None]
49
+ rev_tensor = rev_tensor.to(device)[None]
50
+
51
+ nms_config = NMSConfig(nms_confidence, nms_iou)
52
+ post_proccess = PostProccess(v2b, nms_config)
53
 
54
  with torch.no_grad():
55
  predict = model(image_tensor)
56
+ pred_bbox = post_proccess(predict, rev_tensor)
 
 
57
 
 
 
58
  result_image = draw_bboxes(image, pred_bbox, idx2label=class_list)
59
 
60
  return result_image
tests/test_tools/{test_module_helper.py → test_module_utils.py} RENAMED
@@ -2,7 +2,6 @@ import sys
2
  from pathlib import Path
3
 
4
  import pytest
5
- import torch
6
  from torch import nn
7
 
8
  project_root = Path(__file__).resolve().parent.parent.parent
 
2
  from pathlib import Path
3
 
4
  import pytest
 
5
  from torch import nn
6
 
7
  project_root = Path(__file__).resolve().parent.parent.parent
tests/test_utils/{test_dataaugment.py → test_data_augmentation.py} RENAMED
@@ -54,7 +54,7 @@ def test_mosaic():
54
 
55
  # Mock parent with image_size and get_more_data method
56
  class MockParent:
57
- image_size = 100
58
 
59
  def get_more_data(self, num_images):
60
  return [(img, boxes) for _ in range(num_images)]
 
54
 
55
  # Mock parent with image_size and get_more_data method
56
  class MockParent:
57
+ image_size = (100, 100)
58
 
59
  def get_more_data(self, num_images):
60
  return [(img, boxes) for _ in range(num_images)]
tests/test_utils/{test_loss.py → test_loss_functions.py} RENAMED
File without changes
yolo/__init__.py CHANGED
@@ -6,6 +6,7 @@ from yolo.tools.solver import ModelTester, ModelTrainer, ModelValidator
6
  from yolo.utils.bounding_box_utils import Vec2Box, bbox_nms
7
  from yolo.utils.deploy_utils import FastModelLoader
8
  from yolo.utils.logging_utils import custom_logger
 
9
 
10
  all = [
11
  "create_model",
@@ -22,4 +23,5 @@ all = [
22
  "ModelTester",
23
  "ModelTrainer",
24
  "ModelValidator",
 
25
  ]
 
6
  from yolo.utils.bounding_box_utils import Vec2Box, bbox_nms
7
  from yolo.utils.deploy_utils import FastModelLoader
8
  from yolo.utils.logging_utils import custom_logger
9
+ from yolo.utils.model_utils import PostProccess
10
 
11
  all = [
12
  "create_model",
 
23
  "ModelTester",
24
  "ModelTrainer",
25
  "ModelValidator",
26
+ "PostProccess",
27
  ]
yolo/config/config.py CHANGED
@@ -142,6 +142,7 @@ class Config:
142
 
143
  class_num: int
144
  class_list: List[str]
 
145
  image_size: List[int]
146
 
147
  out_path: str
@@ -164,3 +165,87 @@ class YOLOLayer(nn.Module):
164
 
165
  def __post_init__(self):
166
  super().__init__()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
  class_num: int
144
  class_list: List[str]
145
+ class_idx_id: List[int]
146
  image_size: List[int]
147
 
148
  out_path: str
 
165
 
166
  def __post_init__(self):
167
  super().__init__()
168
+
169
+
170
+ IDX_TO_ID = [
171
+ 1,
172
+ 2,
173
+ 3,
174
+ 4,
175
+ 5,
176
+ 6,
177
+ 7,
178
+ 8,
179
+ 9,
180
+ 10,
181
+ 11,
182
+ 13,
183
+ 14,
184
+ 15,
185
+ 16,
186
+ 17,
187
+ 18,
188
+ 19,
189
+ 20,
190
+ 21,
191
+ 22,
192
+ 23,
193
+ 24,
194
+ 25,
195
+ 27,
196
+ 28,
197
+ 31,
198
+ 32,
199
+ 33,
200
+ 34,
201
+ 35,
202
+ 36,
203
+ 37,
204
+ 38,
205
+ 39,
206
+ 40,
207
+ 41,
208
+ 42,
209
+ 43,
210
+ 44,
211
+ 46,
212
+ 47,
213
+ 48,
214
+ 49,
215
+ 50,
216
+ 51,
217
+ 52,
218
+ 53,
219
+ 54,
220
+ 55,
221
+ 56,
222
+ 57,
223
+ 58,
224
+ 59,
225
+ 60,
226
+ 61,
227
+ 62,
228
+ 63,
229
+ 64,
230
+ 65,
231
+ 67,
232
+ 70,
233
+ 72,
234
+ 73,
235
+ 74,
236
+ 75,
237
+ 76,
238
+ 77,
239
+ 78,
240
+ 79,
241
+ 80,
242
+ 81,
243
+ 82,
244
+ 84,
245
+ 85,
246
+ 86,
247
+ 87,
248
+ 88,
249
+ 89,
250
+ 90,
251
+ ]
yolo/lazy.py CHANGED
@@ -9,7 +9,7 @@ sys.path.append(str(project_root))
9
  from yolo.config.config import Config
10
  from yolo.model.yolo import create_model
11
  from yolo.tools.data_loader import create_dataloader
12
- from yolo.tools.solver import ModelTester, ModelTrainer
13
  from yolo.utils.bounding_box_utils import Vec2Box
14
  from yolo.utils.deploy_utils import FastModelLoader
15
  from yolo.utils.logging_utils import ProgressLogger
@@ -37,6 +37,10 @@ def main(cfg: Config):
37
  tester = ModelTester(cfg, model, vec2box, progress, device)
38
  tester.solve(dataloader)
39
 
 
 
 
 
40
 
41
  if __name__ == "__main__":
42
  main()
 
9
  from yolo.config.config import Config
10
  from yolo.model.yolo import create_model
11
  from yolo.tools.data_loader import create_dataloader
12
+ from yolo.tools.solver import ModelTester, ModelTrainer, ModelValidator
13
  from yolo.utils.bounding_box_utils import Vec2Box
14
  from yolo.utils.deploy_utils import FastModelLoader
15
  from yolo.utils.logging_utils import ProgressLogger
 
37
  tester = ModelTester(cfg, model, vec2box, progress, device)
38
  tester.solve(dataloader)
39
 
40
+ if cfg.task.task == "validation":
41
+ valider = ModelValidator(cfg.task, model, vec2box, progress, device)
42
+ valider.solve(dataloader)
43
+
44
 
45
  if __name__ == "__main__":
46
  main()
yolo/tools/loss_functions.py CHANGED
@@ -39,9 +39,9 @@ class BoxLoss(nn.Module):
39
 
40
 
41
  class DFLoss(nn.Module):
42
- def __init__(self, anchors_norm: Tensor, reg_max: int) -> None:
43
  super().__init__()
44
- self.anchors_norm = anchors_norm
45
  self.reg_max = reg_max
46
 
47
  def forward(
@@ -72,7 +72,7 @@ class YOLOLoss:
72
  self.vec2box = vec2box
73
 
74
  self.cls = BCELoss()
75
- self.dfl = DFLoss(vec2box.anchor_norm, reg_max)
76
  self.iou = BoxLoss()
77
 
78
  self.matcher = BoxMatcher(loss_cfg.matcher, self.class_num, vec2box.anchor_grid)
 
39
 
40
 
41
  class DFLoss(nn.Module):
42
+ def __init__(self, vec2box: Vec2Box, reg_max: int) -> None:
43
  super().__init__()
44
+ self.anchors_norm = (vec2box.anchor_grid / vec2box.scaler[:, None])[None]
45
  self.reg_max = reg_max
46
 
47
  def forward(
 
72
  self.vec2box = vec2box
73
 
74
  self.cls = BCELoss()
75
+ self.dfl = DFLoss(vec2box, reg_max)
76
  self.iou = BoxLoss()
77
 
78
  self.matcher = BoxMatcher(loss_cfg.matcher, self.class_num, vec2box.anchor_grid)
yolo/tools/solver.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os
2
  import time
3
 
@@ -15,12 +16,14 @@ from yolo.model.yolo import YOLO
15
  from yolo.tools.data_loader import StreamDataLoader, create_dataloader
16
  from yolo.tools.drawer import draw_bboxes, draw_model
17
  from yolo.tools.loss_functions import create_loss_function
18
- from yolo.utils.bounding_box_utils import Vec2Box, bbox_nms, calculate_map
19
  from yolo.utils.logging_utils import ProgressLogger, log_model_structure
20
  from yolo.utils.model_utils import (
21
  ExponentialMovingAverage,
 
22
  create_optimizer,
23
  create_scheduler,
 
24
  )
25
 
26
 
@@ -116,10 +119,9 @@ class ModelTester:
116
  def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, progress: ProgressLogger, device):
117
  self.model = model
118
  self.device = device
119
- self.vec2box = vec2box
120
  self.progress = progress
121
 
122
- self.nms = cfg.task.nms
123
  self.save_path = os.path.join(progress.save_path, "images")
124
  os.makedirs(self.save_path, exist_ok=True)
125
  self.save_predict = getattr(cfg.task, "save_predict", None)
@@ -141,9 +143,8 @@ class ModelTester:
141
  rev_tensor = rev_tensor.to(self.device)
142
  with torch.no_grad():
143
  predicts = self.model(images)
144
- predicts = self.vec2box(predicts["Main"])
145
- nms_out = bbox_nms(predicts[0], predicts[2], self.nms)
146
- img = draw_bboxes(images, nms_out, idx2label=self.idx2label)
147
 
148
  if dataloader.is_stream:
149
  img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
@@ -176,32 +177,29 @@ class ModelValidator:
176
  validation_cfg: ValidationConfig,
177
  model: YOLO,
178
  vec2box: Vec2Box,
179
- device,
180
  progress: ProgressLogger,
 
181
  ):
182
  self.model = model
183
- self.vec2box = vec2box
184
  self.device = device
185
  self.progress = progress
186
 
187
- self.nms = validation_cfg.nms
 
188
 
189
  def solve(self, dataloader):
190
  # logger.info("🧪 Start Validation!")
191
  self.model.eval()
192
- # TODO: choice mAP metrics?
193
- iou_thresholds = torch.arange(0.5, 1.0, 0.05)
194
- map_all = []
195
  self.progress.start_one_epoch(len(dataloader))
196
  for images, targets, rev_tensor, img_paths in dataloader:
197
  images, targets, rev_tensor = images.to(self.device), targets.to(self.device), rev_tensor.to(self.device)
198
  with torch.no_grad():
199
  predicts = self.model(images)
200
- predicts = self.vec2box(predicts["Main"])
201
- nms_out = bbox_nms(predicts[0], predicts[2], self.nms)
202
- for idx, predict in enumerate(nms_out):
203
- map_value = calculate_map(predict, targets[idx], iou_thresholds)
204
- map_all.append(map_value[0])
205
- self.progress.one_batch(mapp=torch.Tensor(map_all).mean())
206
 
 
207
  self.progress.finish_one_epoch()
 
 
 
1
+ import json
2
  import os
3
  import time
4
 
 
16
  from yolo.tools.data_loader import StreamDataLoader, create_dataloader
17
  from yolo.tools.drawer import draw_bboxes, draw_model
18
  from yolo.tools.loss_functions import create_loss_function
19
+ from yolo.utils.bounding_box_utils import Vec2Box
20
  from yolo.utils.logging_utils import ProgressLogger, log_model_structure
21
  from yolo.utils.model_utils import (
22
  ExponentialMovingAverage,
23
+ PostProccess,
24
  create_optimizer,
25
  create_scheduler,
26
+ predicts_to_json,
27
  )
28
 
29
 
 
119
  def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, progress: ProgressLogger, device):
120
  self.model = model
121
  self.device = device
 
122
  self.progress = progress
123
 
124
+ self.post_proccess = PostProccess(vec2box, cfg.task.nms)
125
  self.save_path = os.path.join(progress.save_path, "images")
126
  os.makedirs(self.save_path, exist_ok=True)
127
  self.save_predict = getattr(cfg.task, "save_predict", None)
 
143
  rev_tensor = rev_tensor.to(self.device)
144
  with torch.no_grad():
145
  predicts = self.model(images)
146
+ predicts = self.post_proccess(predicts, rev_tensor)
147
+ img = draw_bboxes(origin_frame, predicts, idx2label=self.idx2label)
 
148
 
149
  if dataloader.is_stream:
150
  img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
 
177
  validation_cfg: ValidationConfig,
178
  model: YOLO,
179
  vec2box: Vec2Box,
 
180
  progress: ProgressLogger,
181
+ device,
182
  ):
183
  self.model = model
 
184
  self.device = device
185
  self.progress = progress
186
 
187
+ self.post_proccess = PostProccess(vec2box, validation_cfg.nms)
188
+ self.json_path = os.path.join(self.progress.save_path, f"predict.json")
189
 
190
  def solve(self, dataloader):
191
  # logger.info("🧪 Start Validation!")
192
  self.model.eval()
193
+ predict_json = []
 
 
194
  self.progress.start_one_epoch(len(dataloader))
195
  for images, targets, rev_tensor, img_paths in dataloader:
196
  images, targets, rev_tensor = images.to(self.device), targets.to(self.device), rev_tensor.to(self.device)
197
  with torch.no_grad():
198
  predicts = self.model(images)
199
+ predicts = self.post_proccess(predicts, rev_tensor)
200
+ self.progress.one_batch()
 
 
 
 
201
 
202
+ predict_json.extend(predicts_to_json(img_paths, predicts))
203
  self.progress.finish_one_epoch()
204
+ with open(self.json_path, "w") as f:
205
+ json.dump(predict_json, f)
yolo/utils/bounding_box_utils.py CHANGED
@@ -108,12 +108,13 @@ def transform_bbox(bbox: Tensor, indicator="xywh -> xyxy"):
108
  return bbox.to(dtype=data_type)
109
 
110
 
111
- def generate_anchors(image_size: List[int], anchors_list: List[Tuple[int]]):
112
  """
113
  Find the anchor maps for each w, h.
114
 
115
  Args:
116
- anchors_list List[[w1, h1], [w2, h2], ...]: the anchor num for each predicted anchor
 
117
 
118
  Returns:
119
  all_anchors [HW x 2]:
@@ -122,15 +123,14 @@ def generate_anchors(image_size: List[int], anchors_list: List[Tuple[int]]):
122
  W, H = image_size
123
  anchors = []
124
  scaler = []
125
- for anchor_wh in anchors_list:
126
- stride = W // anchor_wh[0]
127
- anchor_num = anchor_wh[0] * anchor_wh[1]
128
  scaler.append(torch.full((anchor_num,), stride))
129
  shift = stride // 2
130
- x = torch.arange(0, W, stride) + shift
131
- y = torch.arange(0, H, stride) + shift
132
- anchor_x, anchor_y = torch.meshgrid(x, y, indexing="ij")
133
- anchor = torch.stack([anchor_y.flatten(), anchor_x.flatten()], dim=-1)
134
  anchors.append(anchor)
135
  all_anchors = torch.cat(anchors, dim=0)
136
  all_scalers = torch.cat(scaler, dim=0)
@@ -172,6 +172,7 @@ class BoxMatcher:
172
  Returns:
173
  [batch x targets x anchors]: The probabilities from `pred_cls` corresponding to the class indices specified in `target_cls`.
174
  """
 
175
  target_cls = target_cls.expand(-1, -1, 8400)
176
  predict_cls = predict_cls.transpose(1, 2)
177
  cls_probabilities = torch.gather(predict_cls, 1, target_cls)
@@ -266,24 +267,34 @@ class BoxMatcher:
266
 
267
  class Vec2Box:
268
  def __init__(self, model: YOLO, image_size, device):
269
- if getattr(model, "strides", None) is None:
270
- logger.info("🧸 Found no anchor, Make a dummy test for auto-anchor size")
271
- dummy_input = torch.zeros(1, 3, *image_size).to(device)
272
- dummy_output = model(dummy_input)
273
- anchors_num = []
274
- for predict_head in dummy_output["Main"]:
275
- _, _, *anchor_num = predict_head[2].shape
276
- anchors_num.append(anchor_num)
277
  else:
278
- logger.info(f"🈢 Found anchor {model.strides}")
279
- anchors_num = [[image_size[0] // stride, image_size[0] // stride] for stride in model.strides]
280
 
 
281
  if not isinstance(model, YOLO):
282
  device = torch.device("cpu")
283
 
284
- anchor_grid, scaler = generate_anchors(image_size, anchors_num)
285
  self.anchor_grid, self.scaler = anchor_grid.to(device), scaler.to(device)
286
- self.anchor_norm = (anchor_grid / scaler[:, None])[None].to(device)
 
 
 
 
 
 
 
 
 
 
 
 
287
 
288
  def __call__(self, predicts):
289
  preds_cls, preds_anc, preds_box = [], [], []
 
108
  return bbox.to(dtype=data_type)
109
 
110
 
111
+ def generate_anchors(image_size: List[int], strides: List[int]):
112
  """
113
  Find the anchor maps for each w, h.
114
 
115
  Args:
116
+ image_size List: the image size of augmented image size
117
+ strides List[8, 16, 32, ...]: the stride size for each predicted layer
118
 
119
  Returns:
120
  all_anchors [HW x 2]:
 
123
  W, H = image_size
124
  anchors = []
125
  scaler = []
126
+ for stride in strides:
127
+ anchor_num = W // stride * H // stride
 
128
  scaler.append(torch.full((anchor_num,), stride))
129
  shift = stride // 2
130
+ h = torch.arange(0, H, stride) + shift
131
+ w = torch.arange(0, W, stride) + shift
132
+ anchor_h, anchor_w = torch.meshgrid(h, w, indexing="ij")
133
+ anchor = torch.stack([anchor_w.flatten(), anchor_h.flatten()], dim=-1)
134
  anchors.append(anchor)
135
  all_anchors = torch.cat(anchors, dim=0)
136
  all_scalers = torch.cat(scaler, dim=0)
 
172
  Returns:
173
  [batch x targets x anchors]: The probabilities from `pred_cls` corresponding to the class indices specified in `target_cls`.
174
  """
175
+ # TODO: Turn 8400 to HW
176
  target_cls = target_cls.expand(-1, -1, 8400)
177
  predict_cls = predict_cls.transpose(1, 2)
178
  cls_probabilities = torch.gather(predict_cls, 1, target_cls)
 
267
 
268
  class Vec2Box:
269
  def __init__(self, model: YOLO, image_size, device):
270
+ self.device = device
271
+
272
+ if getattr(model, "strides"):
273
+ logger.info(f"🈢 Found stride of model {model.strides}")
274
+ self.strides = model.strides
 
 
 
275
  else:
276
+ logger.info("🧸 Found no stride of model, performed a dummy test for auto-anchor size")
277
+ self.strides = self.create_auto_anchor(model, image_size)
278
 
279
+ # TODO: this is a exception of onnx, remove it when onnx device if fixed
280
  if not isinstance(model, YOLO):
281
  device = torch.device("cpu")
282
 
283
+ anchor_grid, scaler = generate_anchors(image_size, self.strides)
284
  self.anchor_grid, self.scaler = anchor_grid.to(device), scaler.to(device)
285
+
286
+ def create_auto_anchor(self, model: YOLO, image_size):
287
+ dummy_input = torch.zeros(1, 3, *image_size).to(self.device)
288
+ dummy_output = model(dummy_input)
289
+ strides = []
290
+ for predict_head in dummy_output["Main"]:
291
+ _, _, *anchor_num = predict_head[2].shape
292
+ strides.append(image_size[1] // anchor_num[1])
293
+ return strides
294
+
295
+ def update(self, image_size):
296
+ anchor_grid, scaler = generate_anchors(image_size, self.strides)
297
+ self.anchor_grid, self.scaler = anchor_grid.to(self.device), scaler.to(self.device)
298
 
299
  def __call__(self, predicts):
300
  preds_cls, preds_anc, preds_box = [], [], []
yolo/utils/logging_utils.py CHANGED
@@ -72,9 +72,9 @@ class ProgressLogger:
72
  self.wandb.log({f"Learning Rate/{lr_name}": lr_value}, step=epoch_idx)
73
  self.batch_task = self.progress.add_task("[green]Batches", total=num_batches)
74
 
75
- def one_batch(self, loss_dict: Dict[str, Tensor] = None, mapp=None):
76
  if loss_dict is None:
77
- self.progress.update(self.batch_task, advance=1, description=f"[green]Batches [white]{mapp:.2%}")
78
  return
79
  if self.use_wandb:
80
  for loss_name, loss_value in loss_dict.items():
 
72
  self.wandb.log({f"Learning Rate/{lr_name}": lr_value}, step=epoch_idx)
73
  self.batch_task = self.progress.add_task("[green]Batches", total=num_batches)
74
 
75
+ def one_batch(self, loss_dict: Dict[str, Tensor] = None):
76
  if loss_dict is None:
77
+ self.progress.update(self.batch_task, advance=1, description=f"[green]Validating")
78
  return
79
  if self.use_wandb:
80
  for loss_name, loss_value in loss_dict.items():
yolo/utils/model_utils.py CHANGED
@@ -1,17 +1,18 @@
1
  import os
2
- from typing import List, Type, Union
 
3
 
4
  import torch
5
  import torch.distributed as dist
6
  from loguru import logger
7
  from omegaconf import ListConfig
8
- from torch import nn
9
- from torch.nn.parallel import DistributedDataParallel as DDP
10
  from torch.optim import Optimizer
11
  from torch.optim.lr_scheduler import LambdaLR, SequentialLR, _LRScheduler
12
 
13
- from yolo.config.config import OptimizerConfig, SchedulerConfig
14
  from yolo.model.yolo import YOLO
 
15
 
16
 
17
  class ExponentialMovingAverage:
@@ -93,3 +94,40 @@ def get_device(device_spec: Union[str, int, List[int]]) -> torch.device:
93
  device_spec = initialize_distributed()
94
  device = torch.device(device_spec)
95
  return device, ddp_flag
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ from pathlib import Path
3
+ from typing import List, Optional, Type, Union
4
 
5
  import torch
6
  import torch.distributed as dist
7
  from loguru import logger
8
  from omegaconf import ListConfig
9
+ from torch import Tensor
 
10
  from torch.optim import Optimizer
11
  from torch.optim.lr_scheduler import LambdaLR, SequentialLR, _LRScheduler
12
 
13
+ from yolo.config.config import IDX_TO_ID, NMSConfig, OptimizerConfig, SchedulerConfig
14
  from yolo.model.yolo import YOLO
15
+ from yolo.utils.bounding_box_utils import bbox_nms, transform_bbox
16
 
17
 
18
  class ExponentialMovingAverage:
 
94
  device_spec = initialize_distributed()
95
  device = torch.device(device_spec)
96
  return device, ddp_flag
97
+
98
+
99
+ class PostProccess:
100
+ """
101
+ TODO: function document
102
+ scale back the prediction and do nms for pred_bbox
103
+ """
104
+
105
+ def __init__(self, vec2box, nms_cfg: NMSConfig) -> None:
106
+ self.vec2box = vec2box
107
+ self.nms = nms_cfg
108
+
109
+ def __call__(self, predict, rev_tensor: Optional[Tensor]):
110
+ pred_class, _, pred_bbox = self.vec2box(predict["Main"])
111
+ if rev_tensor is not None:
112
+ pred_bbox = (pred_bbox - rev_tensor[:, None, 1:]) / rev_tensor[:, 0:1, None]
113
+ pred_bbox = bbox_nms(pred_class, pred_bbox, self.nms)
114
+ return pred_bbox
115
+
116
+
117
+ def predicts_to_json(img_paths, predicts):
118
+ """
119
+ TODO: function document
120
+ turn a batch of imagepath and predicts(n x 6 for each image) to a List of diction(Detection output)
121
+ """
122
+ batch_json = []
123
+ for img_path, bboxes in zip(img_paths, predicts):
124
+ bboxes[:, 1:5] = transform_bbox(bboxes[:, 1:5], "xyxy -> xywh")
125
+ for cls, *pos, conf in bboxes:
126
+ bbox = {
127
+ "image_id": int(Path(img_path).stem),
128
+ "category_id": IDX_TO_ID[int(cls)],
129
+ "bbox": [float(p) for p in pos],
130
+ "score": float(conf),
131
+ }
132
+ batch_json.append(bbox)
133
+ return batch_json