henry000 commited on
Commit
b2baf14
·
1 Parent(s): ce50967

✨ [New] validation code!

Browse files
yolo/tools/drawer.py CHANGED
@@ -37,7 +37,7 @@ def draw_bboxes(
37
  font = ImageFont.load_default(30)
38
 
39
  for bbox in bboxes:
40
- class_id, x_min, y_min, x_max, y_max = bbox
41
  if scaled_bbox:
42
  x_min = x_min * width
43
  x_max = x_max * width
 
37
  font = ImageFont.load_default(30)
38
 
39
  for bbox in bboxes:
40
+ class_id, x_min, y_min, x_max, y_max, *conf = bbox
41
  if scaled_bbox:
42
  x_min = x_min * width
43
  x_max = x_max * width
yolo/tools/solver.py CHANGED
@@ -30,6 +30,12 @@ class ModelTrainer:
30
  self.progress = ProgressTracker(cfg.name, save_path, cfg.use_wandb)
31
  self.num_epochs = cfg.task.epoch
32
 
 
 
 
 
 
 
33
  if getattr(train_cfg.ema, "enabled", False):
34
  self.ema = ExponentialMovingAverage(model, decay=train_cfg.ema.decay)
35
  else:
@@ -89,9 +95,7 @@ class ModelTrainer:
89
  epoch_loss = self.train_one_epoch(dataloader)
90
  self.progress.finish_one_epoch()
91
 
92
- logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
93
- if (epoch + 1) % 5 == 0:
94
- self.save_checkpoint(epoch, f"checkpoint_epoch_{epoch+1}.pth")
95
 
96
 
97
  class ModelTester:
@@ -100,7 +104,7 @@ class ModelTester:
100
  self.device = device
101
  self.progress = ProgressTracker(cfg, save_path, cfg.use_wandb)
102
 
103
- self.anchor2box = AnchorBoxConverter(cfg, device)
104
  self.nms = cfg.task.nms
105
  self.save_path = save_path
106
 
@@ -125,3 +129,45 @@ class ModelTester:
125
  else:
126
  raise e
127
  dataloader.stop()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  self.progress = ProgressTracker(cfg.name, save_path, cfg.use_wandb)
31
  self.num_epochs = cfg.task.epoch
32
 
33
+ validation_dataloader = create_dataloader(cfg.task.validation.data, cfg.dataset, cfg.task.validation.task)
34
+ anchor2box = AnchorBoxConverter(cfg.model, cfg.image_size, device)
35
+ self.validator = ModelValidator(
36
+ cfg.task.validation, model, save_path, device, self.progress, anchor2box, validation_dataloader
37
+ )
38
+
39
  if getattr(train_cfg.ema, "enabled", False):
40
  self.ema = ExponentialMovingAverage(model, decay=train_cfg.ema.decay)
41
  else:
 
95
  epoch_loss = self.train_one_epoch(dataloader)
96
  self.progress.finish_one_epoch()
97
 
98
+ self.validator.solve()
 
 
99
 
100
 
101
  class ModelTester:
 
104
  self.device = device
105
  self.progress = ProgressTracker(cfg, save_path, cfg.use_wandb)
106
 
107
+ self.anchor2box = AnchorBoxConverter(cfg.model, cfg.image_size, device)
108
  self.nms = cfg.task.nms
109
  self.save_path = save_path
110
 
 
129
  else:
130
  raise e
131
  dataloader.stop()
132
+
133
+
134
+ class ModelValidator:
135
+ def __init__(
136
+ self,
137
+ validation_cfg: ValidationConfig,
138
+ model: YOLO,
139
+ save_path: str,
140
+ device,
141
+ progress: ProgressTracker,
142
+ anchor2box,
143
+ validation_dataloader,
144
+ ):
145
+ self.model = model
146
+ self.device = device
147
+ self.progress = progress
148
+ self.save_path = save_path
149
+
150
+ self.anchor2box = anchor2box
151
+ self.nms = validation_cfg.nms
152
+ self.validdataloader = validation_dataloader
153
+
154
+ def solve(self):
155
+ # logger.info("🧪 Start Validation!")
156
+ self.model.eval()
157
+
158
+ iou_thresholds = torch.arange(0.5, 1.0, 0.05)
159
+ map_all = []
160
+ self.progress.start_one_epoch(len(self.validdataloader))
161
+ for data, targets in self.validdataloader:
162
+ data, targets = data.to(self.device), targets.to(self.device)
163
+ with torch.no_grad():
164
+ raw_output = self.model(data)
165
+ predict, _ = self.anchor2box(raw_output[0][3:], with_logits=True)
166
+
167
+ nms_out = bbox_nms(predict, self.nms)
168
+ for idx, predict in enumerate(nms_out):
169
+ map_value = calculate_map(predict, targets[idx], iou_thresholds)
170
+ map_all.append(map_value[0])
171
+ self.progress.one_batch(mapp=torch.Tensor(map_all).mean())
172
+
173
+ self.progress.finish_one_epoch()
yolo/utils/bounding_box_utils.py CHANGED
@@ -297,6 +297,7 @@ def bbox_nms(predicts: Tensor, nms_cfg: NMSConfig):
297
  cls_val, cls_idx = cls_dist.max(dim=-1, keepdim=True)
298
  valid_mask = cls_val > nms_cfg.min_confidence
299
  valid_cls = cls_idx[valid_mask].float()
 
300
  valid_box = bbox[valid_mask.repeat(1, 1, 4)].view(-1, 4)
301
 
302
  batch_idx, *_ = torch.where(valid_mask)
@@ -305,7 +306,52 @@ def bbox_nms(predicts: Tensor, nms_cfg: NMSConfig):
305
  for idx in range(predicts.size(0)):
306
  instance_idx = nms_idx[idx == batch_idx[nms_idx]]
307
 
308
- predict_nms = torch.cat([valid_cls[instance_idx][:, None], valid_box[instance_idx]], dim=-1)
 
 
309
 
310
  predicts_nms.append(predict_nms)
311
  return predicts_nms
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
297
  cls_val, cls_idx = cls_dist.max(dim=-1, keepdim=True)
298
  valid_mask = cls_val > nms_cfg.min_confidence
299
  valid_cls = cls_idx[valid_mask].float()
300
+ valid_con = cls_val[valid_mask].float()
301
  valid_box = bbox[valid_mask.repeat(1, 1, 4)].view(-1, 4)
302
 
303
  batch_idx, *_ = torch.where(valid_mask)
 
306
  for idx in range(predicts.size(0)):
307
  instance_idx = nms_idx[idx == batch_idx[nms_idx]]
308
 
309
+ predict_nms = torch.cat(
310
+ [valid_cls[instance_idx][:, None], valid_con[instance_idx][:, None], valid_box[instance_idx]], dim=-1
311
+ )
312
 
313
  predicts_nms.append(predict_nms)
314
  return predicts_nms
315
+
316
+
317
+ def calculate_map(predictions, ground_truths, iou_thresholds):
318
+ # TODO: Refactor this block
319
+ device = predictions.device
320
+ n_preds = predictions.size(0)
321
+ n_gts = (ground_truths[:, 0] != -1).sum()
322
+ ground_truths = ground_truths[:n_gts]
323
+ aps = []
324
+
325
+ ious = calculate_iou(predictions[:, 2:], ground_truths[:, 1:]) # [n_preds, n_gts]
326
+
327
+ for threshold in iou_thresholds:
328
+ tp = torch.zeros(n_preds, device=device)
329
+ fp = torch.zeros(n_preds, device=device)
330
+
331
+ max_iou, max_indices = torch.max(ious, dim=1)
332
+ above_threshold = max_iou >= threshold
333
+ matched_classes = predictions[:, 0] == ground_truths[max_indices, 0]
334
+ tp[above_threshold & matched_classes] = 1
335
+ fp[above_threshold & ~matched_classes] = 1
336
+ fp[max_iou < threshold] = 1
337
+
338
+ _, indices = torch.sort(predictions[:, 1], descending=True)
339
+ tp = tp[indices]
340
+ fp = fp[indices]
341
+
342
+ tp_cumsum = torch.cumsum(tp, dim=0)
343
+ fp_cumsum = torch.cumsum(fp, dim=0)
344
+
345
+ precision = tp_cumsum / (tp_cumsum + fp_cumsum + 1e-6)
346
+ recall = tp_cumsum / (n_gts + 1e-6)
347
+
348
+ recall_thresholds = torch.arange(0, 1, 0.1)
349
+ precision_at_recall = torch.zeros_like(recall_thresholds)
350
+ for i, r in enumerate(recall_thresholds):
351
+ precision_at_recall[i] = precision[recall >= r].max().item() if torch.any(recall >= r) else 0
352
+
353
+ ap = precision_at_recall.mean()
354
+ aps.append(ap)
355
+
356
+ mean_ap = torch.mean(torch.stack(aps))
357
+ return mean_ap, aps