henry000 committed
Commit 2ab865c · 2 parents: c3ee284 89a6526

🔀 [Merge] branch 'TRAIN' into TEST

yolo/config/config.py CHANGED
@@ -97,7 +97,7 @@ class SchedulerConfig:
 
 
 @dataclass
 class EMAConfig:
-    enabled: bool
+    enable: bool
     decay: float
 
 
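Note: the rename from `enabled` to `enable` must stay in sync with the YAML key in task/train.yaml below and with the `cfg.task.ema.enable` guard in logging_utils.py. A minimal sketch (hypothetical, not from this commit) of how the structured config ties the dataclass field to the YAML key:

# Minimal sketch, assuming OmegaConf structured configs as used in this repo:
# the YAML key is validated against the dataclass field, so a stale
# `enabled:` key would no longer merge once the field is `enable`.
from dataclasses import dataclass
from omegaconf import OmegaConf

@dataclass
class EMAConfig:
    enable: bool = True   # must match the key in task/train.yaml
    decay: float = 0.995

cfg = OmegaConf.merge(OmegaConf.structured(EMAConfig), OmegaConf.create({"enable": True, "decay": 0.995}))
assert cfg.enable and cfg.decay == 0.995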
 
yolo/config/task/train.yaml CHANGED
@@ -50,5 +50,5 @@ scheduler:
     end_factor: 0.01
 
 ema:
-  enabled: true
+  enable: true
   decay: 0.995
yolo/config/task/validation.yaml CHANGED
@@ -7,7 +7,7 @@ data:
   shuffle: False
   pin_memory: True
   data_augment: {}
-  dynamic_shape: True
+  dynamic_shape: False
 nms:
   min_confidence: 0.0001
   min_iou: 0.7
yolo/tools/data_loader.py CHANGED
@@ -56,7 +56,15 @@ class YoloDataset(Dataset):
             data = self.filter_data(dataset_path, phase_name, self.dynamic_shape)
             torch.save(data, cache_path)
         else:
-            data = torch.load(cache_path, weights_only=False)
+            try:
+                data = torch.load(cache_path, weights_only=False)
+            except Exception as e:
+                logger.error(
+                    f":rotating_light: Failed to load the cache at '{cache_path}'.\n"
+                    ":rotating_light: This may be caused by using a cache from a different YOLO version.\n"
+                    ":rotating_light: Please clean the cache and try running again."
+                )
+                raise e
         logger.info(f":package: Loaded {phase_name} cache")
         return data
 
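A quick sketch (hypothetical file name, not from this commit) of the failure mode the new try/except converts into an actionable message — any unreadable or incompatible cache surfaces as an exception from torch.load, which previously propagated as a bare traceback:

import torch
from pathlib import Path

# Write garbage where a cache is expected; torch.load raises on it.
Path("broken.cache").write_bytes(b"not a torch checkpoint")
try:
    torch.load("broken.cache", weights_only=False)
except Exception as e:
    # The dataset now logs a "clean the cache and rerun" hint before re-raising.
    print(f"cache unreadable: {type(e).__name__}")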
 
yolo/tools/solver.py CHANGED
@@ -33,6 +33,7 @@ class ValidateModel(BaseModel):
         self.metric = MeanAveragePrecision(iou_type="bbox", box_format="xyxy")
         self.metric.warn_on_many_detections = False
         self.val_loader = create_dataloader(self.validation_cfg.data, self.cfg.dataset, self.validation_cfg.task)
+        self.ema = self.model
 
     def setup(self, stage):
         self.vec2box = create_converter(
@@ -46,7 +47,7 @@ class ValidateModel(BaseModel):
     def validation_step(self, batch, batch_idx):
         batch_size, images, targets, rev_tensor, img_paths = batch
         H, W = images.shape[2:]
-        predicts = self.post_process(self(images), image_size=[W, H])
+        predicts = self.post_process(self.ema(images), image_size=[W, H])
         batch_metrics = self.metric(
             [to_metrics_format(predict) for predict in predicts], [to_metrics_format(target) for target in targets]
         )
@@ -56,7 +57,6 @@ class ValidateModel(BaseModel):
                 "map": batch_metrics["map"],
                 "map_50": batch_metrics["map_50"],
             },
-            on_step=True,
             batch_size=batch_size,
         )
         return predicts
@@ -64,9 +64,11 @@ class ValidateModel(BaseModel):
     def on_validation_epoch_end(self):
         epoch_metrics = self.metric.compute()
         del epoch_metrics["classes"]
-        self.log_dict(epoch_metrics, prog_bar=True, rank_zero_only=True)
+        self.log_dict(epoch_metrics, prog_bar=True, sync_dist=True, rank_zero_only=True)
         self.log_dict(
-            {"PyCOCO/AP @ .5:.95": epoch_metrics["map"], "PyCOCO/AP @ .5": epoch_metrics["map_50"]}, rank_zero_only=True
+            {"PyCOCO/AP @ .5:.95": epoch_metrics["map"], "PyCOCO/AP @ .5": epoch_metrics["map_50"]},
+            sync_dist=True,
+            rank_zero_only=True,
         )
         self.metric.reset()
 
@@ -85,7 +87,9 @@ class TrainModel(ValidateModel):
         return self.train_loader
 
     def on_train_epoch_start(self):
-        self.trainer.optimizers[0].next_epoch(ceil(len(self.train_loader) / self.trainer.world_size))
+        self.trainer.optimizers[0].next_epoch(
+            ceil(len(self.train_loader) / self.trainer.world_size), self.current_epoch
+        )
         self.vec2box.update(self.cfg.image_size)
 
     def training_step(self, batch, batch_idx):
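The `self.ema = self.model` alias is what lets `validation_step` call `self.ema(images)` unconditionally: without the EMA callback, validation runs on the live weights; with it, `pl_module.ema` is rebound to a frozen shadow copy (see model_utils.py below). A minimal sketch of that contract, with simplified names that are not the actual module:

class ValidateModelSketch:
    def __init__(self, model):
        self.model = model
        self.ema = self.model      # fallback: no EMA callback -> validate live weights

    def validation_step(self, images):
        # Transparently uses EMA weights once the callback has rebound self.ema.
        return self.ema(images)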
yolo/utils/bounding_box_utils.py CHANGED
@@ -212,19 +212,20 @@ class BoxMatcher:
         topk_masks = topk_targets > 0
         return topk_targets, topk_masks
 
-    def filter_duplicates(self, target_matrix: Tensor, topk_mask: Tensor):
+    def filter_duplicates(self, iou_mat: Tensor, topk_mask: Tensor, grid_mask: Tensor):
         """
         Filter the maximum suitability target index of each anchor.
 
         Args:
-            target_matrix [batch x targets x anchors]: The suitability for each targets-anchors
+            iou_mat [batch x targets x anchors]: The suitability for each targets-anchors
 
         Returns:
             unique_indices [batch x anchors x 1]: The index of the best targets for each anchors
         """
         duplicates = (topk_mask.sum(1, keepdim=True) > 1).repeat([1, topk_mask.size(1), 1])
-        max_idx = F.one_hot(target_matrix.argmax(1), topk_mask.size(1)).permute(0, 2, 1)
+        max_idx = F.one_hot(iou_mat.argmax(1), topk_mask.size(1)).permute(0, 2, 1)
         topk_mask = torch.where(duplicates, max_idx, topk_mask)
+        topk_mask &= grid_mask
         unique_indices = topk_mask.argmax(dim=1)
         return unique_indices[..., None], topk_mask.sum(1), topk_mask
 
@@ -278,7 +279,7 @@ class BoxMatcher:
         topk_targets, topk_mask = self.filter_topk(target_matrix, topk=self.topk)
 
         # delete one anchor pred assigned to multiple gts
-        unique_indices, valid_mask, topk_mask = self.filter_duplicates(iou_mat, topk_mask)
+        unique_indices, valid_mask, topk_mask = self.filter_duplicates(iou_mat, topk_mask, grid_mask)
 
         align_bbox = torch.gather(target_bbox, 1, unique_indices.repeat(1, 1, 4))
         align_cls = torch.gather(target_cls, 1, unique_indices).squeeze(-1)
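A toy run (made-up numbers) of the new masking step: duplicate assignments are resolved to the best-IoU target as before, and the added `&= grid_mask` then drops any surviving match whose anchor falls outside the target's grid region:

import torch
import torch.nn.functional as F

# 1 batch x 2 targets x 2 anchors; anchor 0 is claimed by both targets (duplicate).
topk_mask = torch.tensor([[[True, True], [True, False]]])
iou_mat = torch.tensor([[[0.9, 0.2], [0.4, 0.1]]])
grid_mask = torch.tensor([[[True, False], [True, True]]])  # anchor 1 lies outside target 0

duplicates = (topk_mask.sum(1, keepdim=True) > 1).repeat([1, topk_mask.size(1), 1])
max_idx = F.one_hot(iou_mat.argmax(1), topk_mask.size(1)).permute(0, 2, 1).bool()
resolved = torch.where(duplicates, max_idx, topk_mask) & grid_mask
print(resolved)  # anchor 0 keeps target 0 (higher IoU); anchor 1's out-of-grid match is masked out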
yolo/utils/dataset_utils.py CHANGED
@@ -115,7 +115,14 @@ def scale_segmentation(
 
 
 def tensorlize(data):
-    img_paths, bboxes, img_ratios = zip(*data)
+    try:
+        img_paths, bboxes, img_ratios = zip(*data)
+    except ValueError as e:
+        logger.error(
+            ":rotating_light: This may be caused by using an old cache or a cache from another YOLO version.\n"
+            ":rotating_light: Please clean the cache and try running again."
+        )
+        raise e
    max_box = max(bbox.size(0) for bbox in bboxes)
    padded_bbox_list = []
    for bbox in bboxes:
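A hypothetical repro of the ValueError being caught here: an older cache layout with two fields per record unpacks into the wrong number of variables:

# Assumed old layout: (img_path, bboxes) without the img_ratio field.
old_cache = [("img0.jpg", [[0, 0.5, 0.5, 0.2, 0.2]])]
try:
    img_paths, bboxes, img_ratios = zip(*old_cache)
except ValueError as e:
    print(f"stale cache detected: {e}")  # not enough values to unpack (expected 3, got 2)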
yolo/utils/logging_utils.py CHANGED
@@ -38,6 +38,7 @@ from typing_extensions import override
 from yolo.config.config import Config, YOLOLayer
 from yolo.model.yolo import YOLO
 from yolo.utils.logger import logger
+from yolo.utils.model_utils import EMA
 from yolo.utils.solver_utils import make_ap_table
 
 
@@ -97,7 +98,6 @@ class YOLORichProgressBar(RichProgressBar):
         )
         self.max_result = 0
         self.past_results.clear()
-        self.progress.update(self.task_epoch, advance=-0.5)
 
     @override
     @rank_zero_only
@@ -255,6 +255,8 @@ def setup(cfg: Config):
 
     progress, loggers = [], []
 
+    if hasattr(cfg.task, "ema") and cfg.task.ema.enable:
+        progress.append(EMA(cfg.task.ema.decay))
    if quite:
        logger.setLevel(logging.ERROR)
    return progress, loggers, save_path
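Only the train task config carries an `ema` block, so the `hasattr` guard keeps validation and inference runs from touching `cfg.task.ema` at all. A small sketch of the guard's behavior, using stand-in objects rather than the real Config:

from types import SimpleNamespace

def build_callbacks(task_cfg):
    callbacks = []
    if hasattr(task_cfg, "ema") and task_cfg.ema.enable:
        callbacks.append(("EMA", task_cfg.ema.decay))  # stands in for EMA(decay)
    return callbacks

train = SimpleNamespace(ema=SimpleNamespace(enable=True, decay=0.995))
validate = SimpleNamespace()  # no ema block in validation.yaml
print(build_callbacks(train), build_callbacks(validate))  # [('EMA', 0.995)] []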
yolo/utils/model_utils.py CHANGED
@@ -1,11 +1,16 @@
 import os
+from copy import deepcopy
+from math import exp
 from pathlib import Path
 from typing import List, Optional, Type, Union
 
 import torch
 import torch.distributed as dist
+from lightning import LightningModule, Trainer
+from lightning.pytorch.callbacks import Callback
+from lightning.pytorch.utilities import rank_zero_only
 from omegaconf import ListConfig
-from torch import Tensor
+from torch import Tensor, no_grad
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LambdaLR, SequentialLR, _LRScheduler
 
@@ -15,28 +20,48 @@ from yolo.utils.bounding_box_utils import Anc2Box, Vec2Box, bbox_nms, transform_
 from yolo.utils.logger import logger
 
 
-class ExponentialMovingAverage:
-    def __init__(self, model: torch.nn.Module, decay: float):
-        self.model = model
+def lerp(start: float, end: float, step: Union[int, float], total: int = 1):
+    """
+    Linearly interpolates between start and end values.
+
+    Parameters:
+        start (float): The starting value.
+        end (float): The ending value.
+        step (int): The current step in the interpolation process.
+        total (int): The total number of steps.
+
+    Returns:
+        float: The interpolated value.
+    """
+    return start + (end - start) * step / total
+
+
+class EMA(Callback):
+    def __init__(self, decay: float = 0.9999, tau: float = 500):
+        super().__init__()
+        logger.info(":chart_with_upwards_trend: Enable Model EMA")
         self.decay = decay
-        self.shadow = {name: param.clone().detach() for name, param in model.named_parameters()}
+        self.tau = tau
+        self.step = 0
 
-    def update(self):
-        """Update the shadow parameters using the current model parameters."""
-        for name, param in self.model.named_parameters():
-            assert name in self.shadow, "All model parameters should have a corresponding shadow parameter."
-            new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
-            self.shadow[name] = new_average.clone()
+    def setup(self, trainer, pl_module, stage):
+        pl_module.ema = deepcopy(pl_module.model)
+        self.ema_parameters = [param.clone().detach().to(pl_module.device) for param in pl_module.parameters()]
+        for param in pl_module.ema.parameters():
+            param.requires_grad = False
 
-    def apply_shadow(self):
-        """Apply the shadow parameters to the model."""
-        for name, param in self.model.named_parameters():
-            param.data.copy_(self.shadow[name])
+    def on_validation_start(self, trainer: "Trainer", pl_module: "LightningModule"):
+        for param, ema_param in zip(pl_module.ema.parameters(), self.ema_parameters):
+            param.data.copy_(ema_param)
+            trainer.strategy.broadcast(param)
 
-    def restore(self):
-        """Restore the original parameters from the shadow."""
-        for name, param in self.model.named_parameters():
-            self.shadow[name].copy_(param.data)
+    @rank_zero_only
+    @no_grad()
+    def on_train_batch_end(self, trainer: "Trainer", pl_module: "LightningModule", *args, **kwargs) -> None:
+        self.step += 1
+        decay_factor = self.decay * (1 - exp(-self.step / self.tau))
+        for param, ema_param in zip(pl_module.parameters(), self.ema_parameters):
+            ema_param.data.copy_(lerp(param.detach(), ema_param, decay_factor))
 
 
 def create_optimizer(model: YOLO, optim_cfg: OptimizerConfig) -> Optimizer:
@@ -57,9 +82,15 @@ def create_optimizer(model: YOLO, optim_cfg: OptimizerConfig) -> Optimizer:
         {"params": norm_params, "momentum": 0.8, "weight_decay": 0},
     ]
 
-    def next_epoch(self, batch_num):
+    def next_epoch(self, batch_num, epoch_idx):
         self.min_lr = self.max_lr
         self.max_lr = [param["lr"] for param in self.param_groups]
+        # TODO: load momentum from config instead of a fixed number
+        #       0.937: start momentum
+        #       0.8  : normal momentum
+        #       3    : number of warm-up epochs
+        self.min_mom = lerp(0.937, 0.8, min(epoch_idx, 3), 3)
+        self.max_mom = lerp(0.937, 0.8, min(epoch_idx + 1, 3), 3)
         self.batch_num = batch_num
         self.batch_idx = 0
 
@@ -68,7 +99,8 @@ def create_optimizer(model: YOLO, optim_cfg: OptimizerConfig) -> Optimizer:
         lr_dict = dict()
         for lr_idx, param_group in enumerate(self.param_groups):
            min_lr, max_lr = self.min_lr[lr_idx], self.max_lr[lr_idx]
-            param_group["lr"] = min_lr + (self.batch_idx) * (max_lr - min_lr) / self.batch_num
+            param_group["lr"] = lerp(min_lr, max_lr, self.batch_idx, self.batch_num)
+            param_group["momentum"] = lerp(self.min_mom, self.max_mom, self.batch_idx, self.batch_num)
            lr_dict[f"LR/{lr_idx}"] = param_group["lr"]
        return lr_dict
 
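The `tau` term warms the effective EMA decay up from near zero toward its nominal value, so early shadow weights track the live model closely instead of staying near initialization. Plugging the commit's defaults (decay=0.9999, tau=500) into the formula:

from math import exp

decay, tau = 0.9999, 500
for step in (1, 100, 500, 2000):
    effective = decay * (1 - exp(-step / tau))
    print(step, round(effective, 4))
# 1 -> 0.002, 100 -> 0.1813, 500 -> 0.6321, 2000 -> 0.9816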
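Likewise, `next_epoch` now ramps SGD momentum from 0.937 down to 0.8 over the first three epochs (per the TODO comment), interpolating per batch between each epoch's `min_mom` and `max_mom`. A worked run of the epoch endpoints, assuming the `min()` clamp shown above:

def lerp(start, end, step, total=1):
    return start + (end - start) * step / total

for epoch in range(5):
    min_mom = lerp(0.937, 0.8, min(epoch, 3), 3)
    max_mom = lerp(0.937, 0.8, min(epoch + 1, 3), 3)
    print(epoch, round(min_mom, 4), "->", round(max_mom, 4))
# 0: 0.937 -> 0.8913, 1: 0.8913 -> 0.8457, 2: 0.8457 -> 0.8, 3+: 0.8 -> 0.8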