✨ [Add] new EMA! for torch lightning
Files changed:
- yolo/tools/solver.py +2 -1
- yolo/utils/logging_utils.py +3 -0
- yolo/utils/model_utils.py +30 -22
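This commit swaps the previous EMA helper in `yolo/utils/model_utils.py` for a PyTorch Lightning `Callback` that keeps a shadow copy of the model parameters and updates it after every training batch with `decay_factor = decay * (1 - exp(-step / tau))`, so the effective decay warms up from roughly 0 toward `decay`. A minimal sketch of that update rule (the scalar values below are illustrative, not taken from the repository):

```python
# Minimal sketch of the EMA update rule used by the new callback.
# `param` stands in for a single raw weight and `ema` for its shadow copy; real values are tensors.
from math import exp

def lerp(start, end, step, total=1):
    return start + (end - start) * step / total

decay, tau = 0.9999, 500
param, ema = 1.0, 0.0

for step in range(1, 4):
    decay_factor = decay * (1 - exp(-step / tau))   # warm-up: ramps from ~0 toward `decay`
    ema = lerp(param, ema, decay_factor)            # ema <- (1 - factor) * param + factor * ema
    print(step, round(decay_factor, 4), round(ema, 4))
```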
yolo/tools/solver.py (CHANGED)

@@ -33,6 +33,7 @@ class ValidateModel(BaseModel):
         self.metric = MeanAveragePrecision(iou_type="bbox", box_format="xyxy")
         self.metric.warn_on_many_detections = False
         self.val_loader = create_dataloader(self.validation_cfg.data, self.cfg.dataset, self.validation_cfg.task)
+        self.ema = self.model

     def setup(self, stage):
         self.vec2box = create_converter(
@@ -45,7 +46,7 @@ class ValidateModel(BaseModel):

     def validation_step(self, batch, batch_idx):
         batch_size, images, targets, rev_tensor, img_paths = batch
-        predicts = self.post_process(self(images), image_size=images.shape[2:])
+        predicts = self.post_process(self.ema(images), image_size=images.shape[2:])
         batch_metrics = self.metric(
             [to_metrics_format(predict) for predict in predicts], [to_metrics_format(target) for target in targets]
         )
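Validation now always runs inference through `self.ema`. Because `__init__` aliases `self.ema = self.model`, validation still works when the EMA callback is not registered; when it is, the callback's `setup` (see `model_utils.py` below) replaces the attribute with a deep copy whose weights are overwritten with the shadow parameters before each validation run. A tiny sketch of that alias-then-swap pattern (`TinyModule` is a hypothetical stand-in, not `ValidateModel`):

```python
# Hypothetical stand-in illustrating the alias-then-swap pattern; not repository code.
from copy import deepcopy

import torch.nn as nn

class TinyModule:
    def __init__(self):
        self.model = nn.Linear(4, 2)
        self.ema = self.model          # fallback: without the EMA callback, `ema` is just the raw model

module = TinyModule()
module.ema = deepcopy(module.model)    # roughly what EMA.setup() does on the LightningModule
assert module.ema is not module.model  # validation now runs on an independent copy of the weights
```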
yolo/utils/logging_utils.py (CHANGED)

@@ -38,6 +38,7 @@ from typing_extensions import override
 from yolo.config.config import Config, YOLOLayer
 from yolo.model.yolo import YOLO
 from yolo.utils.logger import logger
+from yolo.utils.model_utils import EMA
 from yolo.utils.solver_utils import make_ap_table


@@ -255,6 +256,8 @@ def setup(cfg: Config):

     progress, loggers = [], []

+    if hasattr(cfg.task, "ema") and cfg.task.ema.enable:
+        progress.append(EMA(cfg.task.ema.decay))
     if quite:
         logger.setLevel(logging.ERROR)
     return progress, loggers, save_path
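`setup` registers the callback only when the task config carries an `ema` block with `enable` set, so existing configs without that block are unaffected. A hedged sketch of the config shape this check appears to expect (field names inferred from the diff; the decay value is illustrative):

```python
# Sketch of the config shape the new check expects; values are illustrative, not from the repo.
from omegaconf import OmegaConf

cfg = OmegaConf.create({"task": {"ema": {"enable": True, "decay": 0.995}}})

callbacks = []
if hasattr(cfg.task, "ema") and cfg.task.ema.enable:
    callbacks.append(("EMA", cfg.task.ema.decay))  # stand-in for progress.append(EMA(cfg.task.ema.decay))
print(callbacks)
```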
yolo/utils/model_utils.py (CHANGED)

@@ -1,11 +1,16 @@
 import os
+from copy import deepcopy
+from math import exp
 from pathlib import Path
 from typing import List, Optional, Type, Union

 import torch
 import torch.distributed as dist
+from lightning import LightningModule, Trainer
+from lightning.pytorch.callbacks import Callback
+from lightning.pytorch.utilities import rank_zero_only
 from omegaconf import ListConfig
-from torch import Tensor
+from torch import Tensor, no_grad
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LambdaLR, SequentialLR, _LRScheduler

@@ -31,28 +36,31 @@ def lerp(start: float, end: float, step: Union[int, float], total: int = 1):
     return start + (end - start) * step / total


-class ...  # removed: the previous EMA implementation (its body is not recoverable from this view)
+class EMA(Callback):
+    def __init__(self, decay: float = 0.9999, tau: float = 500):
+        super().__init__()
+        logger.info(":chart_with_upwards_trend: Enable Model EMA")
+        self.decay = decay
+        self.tau = tau
+        self.step = 0
+
+    def setup(self, trainer, pl_module, stage):
+        pl_module.ema = deepcopy(pl_module.model)
+        self.ema_parameters = [param.clone().detach().to(pl_module.device) for param in pl_module.parameters()]
+
+    def on_validation_start(self, trainer: "Trainer", pl_module: "LightningModule"):
+        for param, ema_param in zip(pl_module.ema.parameters(), self.ema_parameters):
+            param.data.copy_(ema_param)
+            if dist.is_initialized():
+                dist.broadcast(param, src=0)
+
+    @rank_zero_only
+    @no_grad()
+    def on_train_batch_end(self, trainer: "Trainer", pl_module: "LightningModule", *args, **kwargs) -> None:
+        self.step += 1
+        decay_factor = self.decay * (1 - exp(-self.step / self.tau))
+        for param, ema_param in zip(pl_module.parameters(), self.ema_parameters):
+            ema_param.data.copy_(lerp(param.detach(), ema_param, decay_factor))


 def create_optimizer(model: YOLO, optim_cfg: OptimizerConfig) -> Optimizer:
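The new class plugs into Lightning like any other callback: shadow updates happen on rank zero only (`@rank_zero_only`), and the EMA weights are copied into `pl_module.ema` and broadcast to the other ranks at the start of validation. A hedged usage sketch, assuming the repository layout (`TrainModel` is a placeholder for whatever `LightningModule` exposes a `.model` attribute; it is not part of this diff):

```python
# Hedged usage sketch; assumes the repository layout and a LightningModule that exposes `.model`.
from lightning import Trainer

from yolo.utils.model_utils import EMA

trainer = Trainer(max_epochs=10, callbacks=[EMA(decay=0.9999, tau=500)])
# trainer.fit(TrainModel(cfg))  # EMA.setup() deep-copies pl_module.model into pl_module.ema;
#                               # on_train_batch_end then keeps the shadow parameters updated.
```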