Merge branch 'TRAIN' into TEST

- yolo/config/config.py +1 -1
- yolo/config/task/train.yaml +1 -1
- yolo/config/task/validation.yaml +1 -1
- yolo/tools/data_loader.py +9 -1
- yolo/tools/solver.py +9 -5
- yolo/utils/bounding_box_utils.py +5 -4
- yolo/utils/dataset_utils.py +8 -1
- yolo/utils/logging_utils.py +3 -1
- yolo/utils/model_utils.py +53 -21
yolo/config/config.py

```diff
@@ -97,7 +97,7 @@ class SchedulerConfig:
 
 @dataclass
 class EMAConfig:
-
+    enable: bool
     decay: float
 
 
```
yolo/config/task/train.yaml

```diff
@@ -50,5 +50,5 @@ scheduler:
     end_factor: 0.01
 
 ema:
-
+  enable: true
   decay: 0.995
```
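Taken together with the new `enable: bool` field on `EMAConfig` above, this flag lets a task config switch the EMA callback on or off without code changes. A minimal sketch of how the flag is read, mirroring the gate added to `setup()` in `yolo/utils/logging_utils.py` below (the config literal here is made up for illustration):

```python
from omegaconf import OmegaConf

# Hypothetical task-config fragment mirroring train.yaml above.
task_cfg = OmegaConf.create({"ema": {"enable": True, "decay": 0.995}})

# Same gate as in logging_utils.setup(): the callback is registered only
# when the task defines `ema` and sets enable to true.
if hasattr(task_cfg, "ema") and task_cfg.ema.enable:
    print(f"EMA enabled with decay={task_cfg.ema.decay}")
```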
yolo/config/task/validation.yaml

```diff
@@ -7,7 +7,7 @@ data:
   shuffle: False
   pin_memory: True
   data_augment: {}
-  dynamic_shape:
+  dynamic_shape: False
 nms:
   min_confidence: 0.0001
   min_iou: 0.7
```
yolo/tools/data_loader.py

```diff
@@ -56,7 +56,15 @@ class YoloDataset(Dataset):
             data = self.filter_data(dataset_path, phase_name, self.dynamic_shape)
             torch.save(data, cache_path)
         else:
-
+            try:
+                data = torch.load(cache_path, weights_only=False)
+            except Exception as e:
+                logger.error(
+                    f":rotating_light: Failed to load the cache at '{cache_path}'.\n"
+                    ":rotating_light: This may be caused by using a cache from a different YOLO version.\n"
+                    ":rotating_light: Please clean the cache and try running again."
+                )
+                raise e
         logger.info(f":package: Loaded {phase_name} cache")
         return data
 
```
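The new guard fails loudly with an actionable message instead of crashing deeper in the pipeline when a stale or incompatible cache is hit; `weights_only=False` is also needed on recent PyTorch releases (2.6+), where `torch.load` defaults to `weights_only=True` and would reject the pickled dataset records. A minimal sketch of the suggested remedy, deleting the stale cache file; the cache path layout here is an assumption, not taken from this diff:

```python
from pathlib import Path

def clear_stale_cache(dataset_path: str, phase_name: str) -> None:
    """Remove a cached annotation file so the next run rebuilds it."""
    cache_path = Path(dataset_path) / f"{phase_name}.cache"  # assumed layout
    if cache_path.exists():
        cache_path.unlink()
        print(f"Removed stale cache: {cache_path}")
```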
yolo/tools/solver.py

```diff
@@ -33,6 +33,7 @@ class ValidateModel(BaseModel):
         self.metric = MeanAveragePrecision(iou_type="bbox", box_format="xyxy")
         self.metric.warn_on_many_detections = False
         self.val_loader = create_dataloader(self.validation_cfg.data, self.cfg.dataset, self.validation_cfg.task)
+        self.ema = self.model
 
     def setup(self, stage):
         self.vec2box = create_converter(
@@ -46,7 +47,7 @@ class ValidateModel(BaseModel):
     def validation_step(self, batch, batch_idx):
         batch_size, images, targets, rev_tensor, img_paths = batch
         H, W = images.shape[2:]
-        predicts = self.post_process(self(images), image_size=[W, H])
+        predicts = self.post_process(self.ema(images), image_size=[W, H])
         batch_metrics = self.metric(
             [to_metrics_format(predict) for predict in predicts], [to_metrics_format(target) for target in targets]
         )
@@ -56,7 +57,6 @@ class ValidateModel(BaseModel):
                 "map": batch_metrics["map"],
                 "map_50": batch_metrics["map_50"],
             },
-            on_step=True,
             batch_size=batch_size,
         )
         return predicts
@@ -64,9 +64,11 @@ class ValidateModel(BaseModel):
     def on_validation_epoch_end(self):
         epoch_metrics = self.metric.compute()
         del epoch_metrics["classes"]
-        self.log_dict(epoch_metrics, prog_bar=True, rank_zero_only=True)
+        self.log_dict(epoch_metrics, prog_bar=True, sync_dist=True, rank_zero_only=True)
         self.log_dict(
-            {"PyCOCO/AP @ .5:.95": epoch_metrics["map"], "PyCOCO/AP @ .5": epoch_metrics["map_50"]},
+            {"PyCOCO/AP @ .5:.95": epoch_metrics["map"], "PyCOCO/AP @ .5": epoch_metrics["map_50"]},
+            sync_dist=True,
+            rank_zero_only=True,
         )
         self.metric.reset()
 
@@ -85,7 +87,9 @@ class TrainModel(ValidateModel):
         return self.train_loader
 
     def on_train_epoch_start(self):
-        self.trainer.optimizers[0].next_epoch(len(self.train_loader))
+        self.trainer.optimizers[0].next_epoch(
+            ceil(len(self.train_loader) / self.trainer.world_size), self.current_epoch
+        )
         self.vec2box.update(self.cfg.image_size)
 
     def training_step(self, batch, batch_idx):
```
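Two details worth noting here: `self.ema = self.model` makes validation fall back to the raw weights whenever the EMA callback is not installed (the callback overwrites `pl_module.ema` in its `setup`), and `next_epoch` now receives the per-rank batch count, since under DDP each process only steps through its shard of the loader. A quick sketch of that count (numbers invented for illustration):

```python
from math import ceil

total_batches = 8000   # assumed len(train_loader) over the whole dataset
world_size = 4         # assumed number of DDP processes

# Each rank advances the warmup schedule once per batch it actually sees,
# so the per-epoch step budget must be the per-rank count.
steps_per_rank = ceil(total_batches / world_size)  # -> 2000
```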
yolo/utils/bounding_box_utils.py

```diff
@@ -212,19 +212,20 @@ class BoxMatcher:
         topk_masks = topk_targets > 0
         return topk_targets, topk_masks
 
-    def filter_duplicates(self,
+    def filter_duplicates(self, iou_mat: Tensor, topk_mask: Tensor, grid_mask: Tensor):
         """
         Filter the maximum suitability target index of each anchor.
 
         Args:
-
+            iou_mat [batch x targets x anchors]: The suitability for each targets-anchors
 
         Returns:
             unique_indices [batch x anchors x 1]: The index of the best targets for each anchors
         """
         duplicates = (topk_mask.sum(1, keepdim=True) > 1).repeat([1, topk_mask.size(1), 1])
-        max_idx = F.one_hot(
+        max_idx = F.one_hot(iou_mat.argmax(1), topk_mask.size(1)).permute(0, 2, 1)
         topk_mask = torch.where(duplicates, max_idx, topk_mask)
+        topk_mask &= grid_mask
         unique_indices = topk_mask.argmax(dim=1)
         return unique_indices[..., None], topk_mask.sum(1), topk_mask
 
@@ -278,7 +279,7 @@ class BoxMatcher:
         topk_targets, topk_mask = self.filter_topk(target_matrix, topk=self.topk)
 
         # delete one anchor pred assigned to multiple gts
-        unique_indices, valid_mask, topk_mask = self.filter_duplicates(iou_mat, topk_mask)
+        unique_indices, valid_mask, topk_mask = self.filter_duplicates(iou_mat, topk_mask, grid_mask)
 
         align_bbox = torch.gather(target_bbox, 1, unique_indices.repeat(1, 1, 4))
         align_cls = torch.gather(target_cls, 1, unique_indices).squeeze(-1)
```
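The reworked `filter_duplicates` resolves anchors claimed by several ground truths by keeping, per anchor, the target with the highest IoU, and then masks out anchors that fall outside a target's grid cells via `grid_mask`. A toy-sized sketch of the one-hot resolution step, with shapes and values invented for illustration:

```python
import torch
import torch.nn.functional as F

# iou_mat: [batch x targets x anchors]
iou_mat = torch.tensor([[[0.6, 0.2],    # target 0 vs anchors 0, 1
                         [0.4, 0.7]]])  # target 1 vs anchors 0, 1
topk_mask = torch.tensor([[[True, True],
                           [True, True]]])  # both targets claim both anchors

# Anchors claimed by more than one target are "duplicates".
duplicates = (topk_mask.sum(1, keepdim=True) > 1).repeat([1, topk_mask.size(1), 1])
# One-hot over targets: for each anchor, keep only its best-IoU target.
max_idx = F.one_hot(iou_mat.argmax(1), topk_mask.size(1)).permute(0, 2, 1).bool()
resolved = torch.where(duplicates, max_idx, topk_mask)
print(resolved)  # anchor 0 -> target 0 (IoU 0.6), anchor 1 -> target 1 (0.7)
```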
yolo/utils/dataset_utils.py

```diff
@@ -115,7 +115,14 @@ def scale_segmentation(
 
 
 def tensorlize(data):
-    img_paths, bboxes, img_ratios = zip(*data)
+    try:
+        img_paths, bboxes, img_ratios = zip(*data)
+    except ValueError as e:
+        logger.error(
+            ":rotating_light: This may be caused by using an old cache or a cache from another YOLO version.\n"
+            ":rotating_light: Please clean the cache and try running again."
+        )
+        raise e
     max_box = max(bbox.size(0) for bbox in bboxes)
     padded_bbox_list = []
     for bbox in bboxes:
```
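`zip(*data)` transposes the list of per-image records into three parallel tuples; a cache written by a different version, with a different record arity, makes that unpacking raise `ValueError`, which the new guard turns into an actionable message. A minimal sketch under assumed record contents (values invented):

```python
import torch

# Assumed cache records: (img_path, bboxes, ratio) triples.
data = [("img/0.jpg", torch.zeros(2, 5), 1.0),
        ("img/1.jpg", torch.zeros(3, 5), 0.5)]

# Transpose into parallel tuples; a 2-element record here would raise
# "ValueError: not enough values to unpack", the case the guard catches.
img_paths, bboxes, img_ratios = zip(*data)
max_box = max(bbox.size(0) for bbox in bboxes)  # -> 3, used for padding
```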
yolo/utils/logging_utils.py

```diff
@@ -38,6 +38,7 @@ from typing_extensions import override
 from yolo.config.config import Config, YOLOLayer
 from yolo.model.yolo import YOLO
 from yolo.utils.logger import logger
+from yolo.utils.model_utils import EMA
 from yolo.utils.solver_utils import make_ap_table
 
 
@@ -97,7 +98,6 @@ class YOLORichProgressBar(RichProgressBar):
         )
         self.max_result = 0
         self.past_results.clear()
-        self.progress.update(self.task_epoch, advance=-0.5)
 
     @override
     @rank_zero_only
@@ -255,6 +255,8 @@ def setup(cfg: Config):
 
     progress, loggers = [], []
 
+    if hasattr(cfg.task, "ema") and cfg.task.ema.enable:
+        progress.append(EMA(cfg.task.ema.decay))
     if quite:
         logger.setLevel(logging.ERROR)
     return progress, loggers, save_path
```
yolo/utils/model_utils.py

```diff
@@ -1,11 +1,16 @@
 import os
+from copy import deepcopy
+from math import exp
 from pathlib import Path
 from typing import List, Optional, Type, Union
 
 import torch
 import torch.distributed as dist
+from lightning import LightningModule, Trainer
+from lightning.pytorch.callbacks import Callback
+from lightning.pytorch.utilities import rank_zero_only
 from omegaconf import ListConfig
-from torch import Tensor
+from torch import Tensor, no_grad
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LambdaLR, SequentialLR, _LRScheduler
 
@@ -15,28 +20,48 @@ from yolo.utils.bounding_box_utils import Anc2Box, Vec2Box, bbox_nms, transform_
 from yolo.utils.logger import logger
 
 
-
-
-
+def lerp(start: float, end: float, step: Union[int, float], total: int = 1):
+    """
+    Linearly interpolates between start and end values.
+
+    Parameters:
+        start (float): The starting value.
+        end (float): The ending value.
+        step (int): The current step in the interpolation process.
+        total (int): The total number of steps.
+
+    Returns:
+        float: The interpolated value.
+    """
+    return start + (end - start) * step / total
+
+
+class EMA(Callback):
+    def __init__(self, decay: float = 0.9999, tau: float = 500):
+        super().__init__()
+        logger.info(":chart_with_upwards_trend: Enable Model EMA")
         self.decay = decay
-        self.
+        self.tau = tau
+        self.step = 0
 
-    def
-
-
-
-
-        self.shadow[name] = new_average.clone()
+    def setup(self, trainer, pl_module, stage):
+        pl_module.ema = deepcopy(pl_module.model)
+        self.ema_parameters = [param.clone().detach().to(pl_module.device) for param in pl_module.parameters()]
+        for param in pl_module.ema.parameters():
+            param.requires_grad = False
 
-    def
-
-
-
+    def on_validation_start(self, trainer: "Trainer", pl_module: "LightningModule"):
+        for param, ema_param in zip(pl_module.ema.parameters(), self.ema_parameters):
+            param.data.copy_(ema_param)
+            trainer.strategy.broadcast(param)
 
-
-
-
-
+    @rank_zero_only
+    @no_grad()
+    def on_train_batch_end(self, trainer: "Trainer", pl_module: "LightningModule", *args, **kwargs) -> None:
+        self.step += 1
+        decay_factor = self.decay * (1 - exp(-self.step / self.tau))
+        for param, ema_param in zip(pl_module.parameters(), self.ema_parameters):
+            ema_param.data.copy_(lerp(param.detach(), ema_param, decay_factor))
 
 
 def create_optimizer(model: YOLO, optim_cfg: OptimizerConfig) -> Optimizer:
@@ -57,9 +82,15 @@ def create_optimizer(model: YOLO, optim_cfg: OptimizerConfig) -> Optimizer:
         {"params": norm_params, "momentum": 0.8, "weight_decay": 0},
     ]
 
-    def next_epoch(self, batch_num):
+    def next_epoch(self, batch_num, epoch_idx):
         self.min_lr = self.max_lr
         self.max_lr = [param["lr"] for param in self.param_groups]
+        # TODO: load momentum from config instead of a fixed number
+        # 0.937: Start Momentum
+        # 0.8  : Normal Momentum
+        # 3    : The warm up epoch num
+        self.min_mom = lerp(0.937, 0.8, max(epoch_idx, 3), 3)
+        self.max_mom = lerp(0.937, 0.8, max(epoch_idx + 1, 3), 3)
         self.batch_num = batch_num
         self.batch_idx = 0
 
@@ -68,7 +99,8 @@ def create_optimizer(model: YOLO, optim_cfg: OptimizerConfig) -> Optimizer:
         lr_dict = dict()
         for lr_idx, param_group in enumerate(self.param_groups):
             min_lr, max_lr = self.min_lr[lr_idx], self.max_lr[lr_idx]
-            param_group["lr"] = min_lr
+            param_group["lr"] = lerp(min_lr, max_lr, self.batch_idx, self.batch_num)
+            param_group["momentum"] = lerp(self.min_mom, self.max_mom, self.batch_idx, self.batch_num)
             lr_dict[f"LR/{lr_idx}"] = param_group["lr"]
         return lr_dict
 
```
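The callback's effective decay warms up as `decay * (1 - exp(-step / tau))`, so the average tracks the fast-moving early weights closely and only later approaches the configured decay. A quick sketch of the ramp with the defaults above:

```python
from math import exp

decay, tau = 0.9999, 500  # defaults from EMA.__init__ above
for step in (1, 100, 500, 2000):
    print(step, round(decay * (1 - exp(-step / tau)), 4))
# 1 -> 0.002, 100 -> 0.1813, 500 -> 0.6321, 2000 -> 0.9816
```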