henry000 committed
Commit 9912678 · 2 Parent(s): 42eab9c 3e08dd8

🔀 [Merge] branch 'TRAIN' into TEST

examples/example_train.py CHANGED
@@ -9,8 +9,7 @@ project_root = Path(__file__).resolve().parent.parent
 sys.path.append(str(project_root))
 
 from yolo.config.config import Config
-from yolo.model.yolo import get_model
-from yolo.tools.log_helper import custom_logger
+from yolo.tools.log_helper import custom_logger, get_valid_folder
 from yolo.tools.trainer import Trainer
 from yolo.utils.dataloader import get_dataloader
 from yolo.utils.get_dataset import prepare_dataset
@@ -18,18 +17,17 @@ from yolo.utils.get_dataset import prepare_dataset
 
 @hydra.main(config_path="../yolo/config", config_name="config", version_base=None)
 def main(cfg: Config):
+    custom_logger()
+    save_path = get_valid_folder(cfg.hyper.general, cfg.name)
     if cfg.download.auto:
         prepare_dataset(cfg.download)
 
     dataloader = get_dataloader(cfg)
-    model = get_model(cfg)
     # TODO: get_device or rank, for DDP mode
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-    trainer = Trainer(model, cfg, device)
-    trainer.train(dataloader, 10)
+    trainer = Trainer(cfg, save_path, device)
+    trainer.train(dataloader, cfg.hyper.train.epoch)
 
 
 if __name__ == "__main__":
-    custom_logger()
     main()
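
Note: the entry point now configures logging and resolves an experiment folder itself before handing everything to Trainer. A minimal sketch of inspecting the same config tree outside the @hydra.main decorator, via Hydra's compose API (the config path and keys are taken from this file; the override value is illustrative):

from hydra import compose, initialize

# Sketch: build the config example_train.py sees, overriding one field.
with initialize(config_path="../yolo/config", version_base=None):
    cfg = compose(config_name="config", overrides=["hyper.train.epoch=100"])
print(cfg.hyper.train.epoch)  # 100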
yolo/config/config.py CHANGED
@@ -25,11 +25,10 @@ class Download:
 @dataclass
 class DataLoaderConfig:
     batch_size: int
+    class_num: int
+    image_size: List[int]
     shuffle: bool
-    num_workers: int
     pin_memory: bool
-    image_size: List[int]
-    class_num: int
 
 
 @dataclass
@@ -54,6 +53,7 @@ class SchedulerArgs:
 class SchedulerConfig:
     type: str
     args: SchedulerArgs
+    warmup: Dict[str, Union[str, int, float]]
 
 
 @dataclass
@@ -85,8 +85,22 @@ class TrainConfig:
     loss: LossConfig
 
 
+@dataclass
+class GeneralConfig:
+    out_path: str
+    task: str
+    device: Union[str, int, List[int]]
+    cpu_num: int
+    use_wandb: bool
+    lucky_number: 10
+    exist_ok: bool
+    resume_train: bool
+    use_TensorBoard: bool
+
+
 @dataclass
 class HyperConfig:
+    general: GeneralConfig
     data: DataLoaderConfig
     train: TrainConfig
 
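Note: two details in the committed schema are worth flagging. `lucky_number: 10` annotates the field with the literal 10 rather than int, and the YAML below spells the device key `deivce`, so strict structured-config validation (e.g. OmegaConf.structured) would likely reject this pair as committed. A plain load still works; a minimal sketch, assuming the default hyper file on disk:

from omegaconf import OmegaConf

# Sketch: untyped load of the new hyper config; attribute access mirrors the YAML tree.
hyper = OmegaConf.load("yolo/config/hyper/default.yaml")
print(hyper.general.cpu_num)    # 16
print(hyper.general.use_wandb)  # False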
yolo/config/hyper/default.yaml CHANGED
@@ -1,17 +1,27 @@
+general:
+  out_path: runs
+  task: train
+  deivce: [0]
+  cpu_num: 16
+  use_wandb: False
+  lucky_number: 10
+  exist_ok: True
+  resume_train: False
+  use_TensorBoard: False
 data:
   batch_size: 16
-  shuffle: True
-  num_workers: 16
-  pin_memory: True
   class_num: 80
   image_size: [640, 640]
+  shuffle: True
+  pin_memory: True
 train:
-  epoch: 10
+  epoch: 500
   optimizer:
-    type: Adam
+    type: SGD
     args:
-      lr: 0.001
-      weight_decay: 0.0001
+      lr: 0.01
+      weight_decay: 0.0005
+      momentum: 0.937
   loss:
     objective:
       BCELoss: 0.5
@@ -26,10 +36,13 @@ train:
       iou: 6.0
       cls: 0.5
   scheduler:
-    type: StepLR
+    type: LinearLR
+    warmup:
+      epochs: 3.0
     args:
-      step_size: 10
-      gamma: 0.1
+      total_iters: ${hyper.train.epoch}
+      start_factor: 1
+      end_factor: 0.01
   ema:
     enabled: true
     decay: 0.995
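
Note: `${hyper.train.epoch}` is an OmegaConf interpolation, so `total_iters` tracks the epoch count rather than duplicating it. A minimal sketch, assuming this file is mounted under a top-level `hyper` key as in the repo's Hydra layout:

from omegaconf import OmegaConf

# Sketch: mount the file under `hyper` so the interpolation can resolve against the root.
cfg = OmegaConf.create({"hyper": OmegaConf.load("yolo/config/hyper/default.yaml")})
assert cfg.hyper.train.scheduler.args.total_iters == 500  # follows hyper.train.epoch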
yolo/tools/log_helper.py CHANGED
@@ -11,6 +11,7 @@ Example:
     custom_logger()
 """
 
+import os
 import sys
 from typing import Dict, List
 
@@ -22,19 +23,20 @@ from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn
 from rich.table import Table
 from torch import Tensor
 
-from yolo.config.config import Config, YOLOLayer
+from yolo.config.config import Config, GeneralConfig, YOLOLayer
 
 
 def custom_logger():
     logger.remove()
     logger.add(
         sys.stderr,
-        format="<fg #003385>[{time:MM/DD HH:mm:ss}]</fg #003385><level>{level: ^8}</level>| <level>{message}</level>",
+        colorize=True,
+        format="<fg #003385>[{time:MM/DD HH:mm:ss}]</> <level>{level: ^8}</level>| <level>{message}</level>",
     )
 
 
 class CustomProgress:
-    def __init__(self, cfg: Config, use_wandb: bool = False):
+    def __init__(self, cfg: Config, save_path: str, use_wandb: bool = False):
         self.progress = Progress(
             TextColumn("[progress.description]{task.description}"),
             BarColumn(bar_width=None),
@@ -44,18 +46,19 @@ class CustomProgress:
         self.use_wandb = use_wandb
         if self.use_wandb:
            wandb.errors.term._log = custom_wandb_log
-            self.wandb = wandb.init(project="YOLO", resume="allow", mode="online", dir="runs", name=cfg.name)
+            self.wandb = wandb.init(
+                project="YOLO", resume="allow", mode="online", dir=save_path, id=None, name=cfg.name
+            )
 
     def start_train(self, num_epochs: int):
         self.task_epoch = self.progress.add_task("[cyan]Epochs [white]| Loss | Box | DFL | BCE |", total=num_epochs)
 
-    def one_epoch(self):
-        self.progress.update(self.task_epoch, advance=1)
-
-    def finish_epoch(self):
-        self.wandb.finish()
-
-    def start_batch(self, num_batches):
+    def start_one_epoch(self, num_batches, optimizer, epoch_idx):
+        if self.use_wandb:
+            lr_values = [params["lr"] for params in optimizer.param_groups]
+            lr_names = ["bias", "norm", "conv"]
+            for lr_name, lr_value in zip(lr_names, lr_values):
+                self.wandb.log({f"Learning Rate/{lr_name}": lr_value}, step=epoch_idx)
         self.batch_task = self.progress.add_task("[green]Batches", total=num_batches)
 
     def one_batch(self, loss_dict: Dict[str, Tensor]):
@@ -69,15 +72,19 @@ class CustomProgress:
 
         self.progress.update(self.batch_task, advance=1, description=f"[green]Batches [white]{loss_str}")
 
-    def finish_batch(self):
+    def finish_one_epoch(self):
         self.progress.remove_task(self.batch_task)
+        self.progress.update(self.task_epoch, advance=1)
+
+    def finish_train(self):
+        self.wandb.finish()
 
 
 def custom_wandb_log(string="", level=int, newline=True, repeat=True, prefix=True, silent=False):
     if silent:
         return
     for line in string.split("\n"):
-        logger.opt(raw=not newline).info("🌐 " + line)
+        logger.opt(raw=not newline, colors=True).info("🌐 " + line)
 
 
 def log_model(model: List[YOLOLayer]):
@@ -99,3 +106,25 @@ def log_model(model: List[YOLOLayer]):
             channels = "-"
         table.add_row(str(idx), layer.layer_type, layer.tags, f"{layer_param:,}", channels)
     console.print(table)
+
+
+def get_valid_folder(general_cfg: GeneralConfig, exp_name):
+    base_path = os.path.join(general_cfg.out_path, general_cfg.task)
+    save_path = os.path.join(base_path, exp_name)
+
+    if not general_cfg.exist_ok:
+        index = 1
+        old_exp_name = exp_name
+        while os.path.isdir(save_path):
+            exp_name = f"{old_exp_name}{index}"
+            save_path = os.path.join(base_path, exp_name)
+            index += 1
+        if index > 1:
+            logger.opt(colors=True).warning(
+                f"🔀 Experiment directory exists! Changed <red>{old_exp_name}</> to <green>{exp_name}</>"
+            )
+
+    os.makedirs(save_path, exist_ok=True)
+    logger.opt(colors=True).info(f"📄 Created log folder: <u><fg #808080>{save_path}</></>")
+    logger.add(os.path.join(save_path, "output.log"), backtrace=True, diagnose=True)
+    return save_path
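
Note: get_valid_folder builds out_path/task/exp_name and, with exist_ok: False, appends an index until the name is free; each call also attaches a loguru file sink inside the chosen folder. A hypothetical usage sketch (SimpleNamespace stands in for a GeneralConfig instance):

from types import SimpleNamespace
from yolo.tools.log_helper import get_valid_folder

# Hypothetical config carrying only the fields get_valid_folder reads.
general = SimpleNamespace(out_path="runs", task="train", exist_ok=False)
first = get_valid_folder(general, "v9-exp")   # creates runs/train/v9-exp
second = get_valid_folder(general, "v9-exp")  # name taken, so creates runs/train/v9-exp1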
yolo/tools/model_helper.py CHANGED
@@ -2,9 +2,10 @@ from typing import Any, Dict, Type
 
 import torch
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler
+from torch.optim.lr_scheduler import LambdaLR, SequentialLR, _LRScheduler
 
 from yolo.config.config import OptimizerConfig, SchedulerConfig
+from yolo.model.yolo import YOLO
 
 
 class EMA:
@@ -31,21 +32,38 @@ class EMA:
             self.shadow[name].copy_(param.data)
 
 
-def get_optimizer(model_parameters, optim_cfg: OptimizerConfig) -> Optimizer:
+def get_optimizer(model: YOLO, optim_cfg: OptimizerConfig) -> Optimizer:
     """Create an optimizer for the given model parameters based on the configuration.
 
     Returns:
         An instance of the optimizer configured according to the provided settings.
     """
     optimizer_class: Type[Optimizer] = getattr(torch.optim, optim_cfg.type)
+
+    bias_params = [p for name, p in model.named_parameters() if "bias" in name]
+    norm_params = [p for name, p in model.named_parameters() if "weight" in name and "bn" in name]
+    conv_params = [p for name, p in model.named_parameters() if "weight" in name and "bn" not in name]
+
+    model_parameters = [
+        {"params": bias_params, "nestrov": True, "momentum": 0.937},
+        {"params": conv_params, "weight_decay": 0.0},
+        {"params": norm_params, "weight_decay": 1e-5},
+    ]
     return optimizer_class(model_parameters, **optim_cfg.args)
 
 
-def get_scheduler(optimizer: Optimizer, schedul_cfg: SchedulerConfig) -> _LRScheduler:
+def get_scheduler(optimizer: Optimizer, schedule_cfg: SchedulerConfig) -> _LRScheduler:
     """Create a learning rate scheduler for the given optimizer based on the configuration.
 
     Returns:
         An instance of the scheduler configured according to the provided settings.
     """
-    scheduler_class: Type[_LRScheduler] = getattr(torch.optim.lr_scheduler, schedul_cfg.type)
-    return scheduler_class(optimizer, **schedul_cfg.args)
+    scheduler_class: Type[_LRScheduler] = getattr(torch.optim.lr_scheduler, schedule_cfg.type)
+    schedule = scheduler_class(optimizer, **schedule_cfg.args)
+    if hasattr(schedule_cfg, "warmup"):
+        wepoch = schedule_cfg.warmup.epochs
+        lambda1 = lambda epoch: 0.1 + 0.9 * (epoch + 1 / wepoch) if epoch < wepoch else 1
+        lambda2 = lambda epoch: 10 - 9 * (epoch + 1 / wepoch) if epoch < wepoch else 1
+        warmup_schedule = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2, lambda1])
+        schedule = SequentialLR(optimizer, schedulers=[warmup_schedule, schedule], milestones=[2])
+    return schedule
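
Note: as committed, "nestrov" is a misspelling of SGD's nesterov flag (unknown param-group keys are carried along silently, so it has no effect), and the warmup lambdas parse as epoch + (1 / wepoch) rather than the usual (epoch + 1) / wepoch ramp. A standalone sketch of the SequentialLR warmup pattern with a conventional linear ramp (illustrative values, not the repo's exact curve):

import torch
from torch.optim.lr_scheduler import LambdaLR, LinearLR, SequentialLR

# Three param groups stand in for the bias/conv/norm split made by get_optimizer.
groups = [{"params": [torch.nn.Parameter(torch.zeros(1))]} for _ in range(3)]
optimizer = torch.optim.SGD(groups, lr=0.01)

warmup_epochs = 3
warmup = LambdaLR(optimizer, lr_lambda=[lambda e: (e + 1) / warmup_epochs] * 3)
main = LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=500)
scheduler = SequentialLR(optimizer, schedulers=[warmup, main], milestones=[warmup_epochs])

for epoch in range(5):  # per-group learning rates ramp up for 3 epochs, then decay linearly
    print(epoch, [round(g["lr"], 5) for g in optimizer.param_groups])
    scheduler.step()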
yolo/tools/trainer.py CHANGED
@@ -6,22 +6,23 @@ from torch import Tensor
 from torch.cuda.amp import GradScaler, autocast
 
 from yolo.config.config import Config, TrainConfig
-from yolo.model.yolo import YOLO
+from yolo.model.yolo import get_model
 from yolo.tools.log_helper import CustomProgress
 from yolo.tools.model_helper import EMA, get_optimizer, get_scheduler
 from yolo.utils.loss import get_loss_function
 
 
 class Trainer:
-    def __init__(self, model: YOLO, cfg: Config, device):
+    def __init__(self, cfg: Config, save_path: str, device):
         train_cfg: TrainConfig = cfg.hyper.train
+        model = get_model(cfg)
 
         self.model = model.to(device)
         self.device = device
-        self.optimizer = get_optimizer(model.parameters(), train_cfg.optimizer)
+        self.optimizer = get_optimizer(model, train_cfg.optimizer)
         self.scheduler = get_scheduler(self.optimizer, train_cfg.scheduler)
         self.loss_fn = get_loss_function(cfg)
-        self.progress = CustomProgress(cfg, use_wandb=True)
+        self.progress = CustomProgress(cfg, save_path, use_wandb=True)
 
         if getattr(train_cfg.ema, "enabled", False):
             self.ema = EMA(model, decay=train_cfg.ema.decay)
@@ -46,7 +47,6 @@ class Trainer:
     def train_one_epoch(self, dataloader):
         self.model.train()
         total_loss = 0
-        self.progress.start_batch(len(dataloader))
 
         for data, targets in dataloader:
             loss, loss_each = self.train_one_batch(data, targets)
@@ -57,7 +57,6 @@ class Trainer:
         if self.scheduler:
             self.scheduler.step()
 
-        self.progress.finish_batch()
         return total_loss / len(dataloader)
 
     def save_checkpoint(self, epoch: int, filename="checkpoint.pt"):
@@ -79,8 +78,9 @@ class Trainer:
         self.progress.start_train(num_epochs)
         for epoch in range(num_epochs):
 
+            self.progress.start_one_epoch(len(dataloader), self.optimizer, epoch)
             epoch_loss = self.train_one_epoch(dataloader)
-            self.progress.one_epoch()
+            self.progress.finish_one_epoch()
 
             logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
             if (epoch + 1) % 5 == 0:
yolo/utils/dataloader.py CHANGED
@@ -160,7 +160,7 @@ class YoloDataLoader(DataLoader):
             dataset,
             batch_size=hyper.batch_size,
             shuffle=hyper.shuffle,
-            num_workers=hyper.num_workers,
+            num_workers=config.hyper.general.cpu_num,
             pin_memory=hyper.pin_memory,
             collate_fn=self.collate_fn,
         )
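
Note: the worker count now comes from the shared hyper.general.cpu_num setting rather than a per-dataloader num_workers field, so a single knob governs CPU usage across tools.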