henry000 committed
Commit c4cd90a · 1 Parent(s): 46ebaf7

✨ [Add] new EMA! for torch lightning

yolo/tools/solver.py CHANGED
@@ -33,6 +33,7 @@ class ValidateModel(BaseModel):
         self.metric = MeanAveragePrecision(iou_type="bbox", box_format="xyxy")
         self.metric.warn_on_many_detections = False
         self.val_loader = create_dataloader(self.validation_cfg.data, self.cfg.dataset, self.validation_cfg.task)
+        self.ema = self.model
 
     def setup(self, stage):
         self.vec2box = create_converter(
@@ -45,7 +46,7 @@ class ValidateModel(BaseModel):
 
     def validation_step(self, batch, batch_idx):
         batch_size, images, targets, rev_tensor, img_paths = batch
-        predicts = self.post_process(self(images), image_size=images.shape[2:])
+        predicts = self.post_process(self.ema(images), image_size=images.shape[2:])
         batch_metrics = self.metric(
             [to_metrics_format(predict) for predict in predicts], [to_metrics_format(target) for target in targets]
         )
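
This change reads best together with the callback added below: validation_step now always runs inference through self.ema, which starts out as a plain alias of the live model, and the EMA callback's setup later replaces it with a deep copy holding the averaged weights. Because the alias is assigned in __init__, validation keeps working unchanged when the callback is not registered. A minimal sketch of that alias-then-swap pattern (TinyModule and enable_ema are hypothetical names, not the repo's classes):

from copy import deepcopy

import torch.nn as nn


class TinyModule:
    """Hypothetical stand-in for ValidateModel: a model plus an EMA slot."""

    def __init__(self, model: nn.Module):
        self.model = model
        self.ema = self.model  # default: alias, so validation works without EMA

    def validation_forward(self, x):
        return self.ema(x)  # mirrors validation_step calling self.ema(images)


def enable_ema(module: TinyModule):
    """Mirrors EMA.setup: replace the alias with an independent weight copy."""
    module.ema = deepcopy(module.model)


m = TinyModule(nn.Linear(4, 2))
assert m.ema is m.model       # EMA disabled: raw weights are evaluated
enable_ema(m)
assert m.ema is not m.model   # EMA enabled: a separate copy is evaluated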
yolo/utils/logging_utils.py CHANGED
@@ -38,6 +38,7 @@ from typing_extensions import override
 from yolo.config.config import Config, YOLOLayer
 from yolo.model.yolo import YOLO
 from yolo.utils.logger import logger
+from yolo.utils.model_utils import EMA
 from yolo.utils.solver_utils import make_ap_table
 
 
@@ -255,6 +256,8 @@ def setup(cfg: Config):
 
     progress, loggers = [], []
 
+    if hasattr(cfg.task, "ema") and cfg.task.ema.enable:
+        progress.append(EMA(cfg.task.ema.decay))
     if quite:
         logger.setLevel(logging.ERROR)
     return progress, loggers, save_path
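
setup only registers the callback when the task config carries an ema block with enable set; the hasattr guard keeps task configs without that block working. A hedged sketch of the config shape this gate expects (the exact YAML keys are an assumption, not taken from the repo's config files):

from omegaconf import OmegaConf

# Hypothetical task config: setup() only reads ema.enable and ema.decay.
task_cfg = OmegaConf.create({"ema": {"enable": True, "decay": 0.9999}})
OmegaConf.set_struct(task_cfg, True)  # Hydra dataclass-backed configs are struct-mode

callbacks = []
if hasattr(task_cfg, "ema") and task_cfg.ema.enable:
    callbacks.append(("EMA", task_cfg.ema.decay))  # stands in for EMA(decay)

legacy_cfg = OmegaConf.create({})
OmegaConf.set_struct(legacy_cfg, True)
assert not hasattr(legacy_cfg, "ema")  # guard skips configs without an ema block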
yolo/utils/model_utils.py CHANGED
@@ -1,11 +1,16 @@
 import os
+from copy import deepcopy
+from math import exp
 from pathlib import Path
 from typing import List, Optional, Type, Union
 
 import torch
 import torch.distributed as dist
+from lightning import LightningModule, Trainer
+from lightning.pytorch.callbacks import Callback
+from lightning.pytorch.utilities import rank_zero_only
 from omegaconf import ListConfig
-from torch import Tensor
+from torch import Tensor, no_grad
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LambdaLR, SequentialLR, _LRScheduler
 
@@ -31,28 +36,31 @@ def lerp(start: float, end: float, step: Union[int, float], total: int = 1):
     return start + (end - start) * step / total
 
 
-class ExponentialMovingAverage:
-    def __init__(self, model: torch.nn.Module, decay: float):
-        self.model = model
+class EMA(Callback):
+    def __init__(self, decay: float = 0.9999, tau: float = 500):
+        super().__init__()
+        logger.info(":chart_with_upwards_trend: Enable Model EMA")
         self.decay = decay
-        self.shadow = {name: param.clone().detach() for name, param in model.named_parameters()}
-
-    def update(self):
-        """Update the shadow parameters using the current model parameters."""
-        for name, param in self.model.named_parameters():
-            assert name in self.shadow, "All model parameters should have a corresponding shadow parameter."
-            new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
-            self.shadow[name] = new_average.clone()
-
-    def apply_shadow(self):
-        """Apply the shadow parameters to the model."""
-        for name, param in self.model.named_parameters():
-            param.data.copy_(self.shadow[name])
-
-    def restore(self):
-        """Restore the original parameters from the shadow."""
-        for name, param in self.model.named_parameters():
-            self.shadow[name].copy_(param.data)
+        self.tau = tau
+        self.step = 0
+
+    def setup(self, trainer, pl_module, stage):
+        pl_module.ema = deepcopy(pl_module.model)
+        self.ema_parameters = [param.clone().detach().to(pl_module.device) for param in pl_module.parameters()]
+
+    def on_validation_start(self, trainer: "Trainer", pl_module: "LightningModule"):
+        for param, ema_param in zip(pl_module.ema.parameters(), self.ema_parameters):
+            param.data.copy_(ema_param)
+            if dist.is_initialized():
+                dist.broadcast(param, src=0)
+
+    @rank_zero_only
+    @no_grad()
+    def on_train_batch_end(self, trainer: "Trainer", pl_module: "LightningModule", *args, **kwargs) -> None:
+        self.step += 1
+        decay_factor = self.decay * (1 - exp(-self.step / self.tau))
+        for param, ema_param in zip(pl_module.parameters(), self.ema_parameters):
+            ema_param.data.copy_(lerp(param.detach(), ema_param, decay_factor))
 
 
 def create_optimizer(model: YOLO, optim_cfg: OptimizerConfig) -> Optimizer:
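
The new callback replaces the manual update/apply_shadow/restore workflow with Lightning hooks, and it warms the decay up over the first steps: decay_factor = decay * (1 - exp(-step / tau)) is near 0 early, so the average tracks the raw weights closely, and saturates toward decay, so the average then moves slowly. A standalone numeric sketch of that schedule with the class defaults, reusing the lerp formula from this file:

from math import exp


def lerp(start: float, end: float, step: float, total: int = 1):
    # same formula as the repo's lerp helper
    return start + (end - start) * step / total


def ema_update(param: float, ema_param: float, step: int,
               decay: float = 0.9999, tau: float = 500) -> float:
    # keep decay_factor of the old average, blend in (1 - decay_factor) of the new weights
    decay_factor = decay * (1 - exp(-step / tau))
    return lerp(param, ema_param, decay_factor)


for step in (1, 100, 500, 2000):
    factor = 0.9999 * (1 - exp(-step / 500))
    print(f"step {step:5d}: decay_factor ≈ {factor:.4f}")
# ramps ≈ 0.0020 → 0.1813 → 0.6321 → 0.9816, then saturates near 0.9999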