henry000 committed
Commit 16c6705 · 1 Parent(s): 6aabc6c

✨ [Add] General config for global settings

examples/example_train.py CHANGED
@@ -9,8 +9,7 @@ project_root = Path(__file__).resolve().parent.parent
 sys.path.append(str(project_root))
 
 from yolo.config.config import Config
-from yolo.model.yolo import get_model
-from yolo.tools.log_helper import custom_logger
+from yolo.tools.log_helper import custom_logger, get_valid_folder
 from yolo.tools.trainer import Trainer
 from yolo.utils.dataloader import get_dataloader
 from yolo.utils.get_dataset import prepare_dataset
@@ -18,18 +17,17 @@ from yolo.utils.get_dataset import prepare_dataset
 
 @hydra.main(config_path="../yolo/config", config_name="config", version_base=None)
 def main(cfg: Config):
+    custom_logger()
+    save_path = get_valid_folder(cfg.hyper.general, cfg.name)
     if cfg.download.auto:
         prepare_dataset(cfg.download)
 
     dataloader = get_dataloader(cfg)
-    model = get_model(cfg)
     # TODO: get_device or rank, for DDP mode
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-    trainer = Trainer(model, cfg, device)
-    trainer.train(dataloader, 10)
+    trainer = Trainer(cfg, save_path, device)
+    trainer.train(dataloader, cfg.hyper.train.epoch)
 
 
 if __name__ == "__main__":
-    custom_logger()
     main()
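With this change main() is driven entirely by the Hydra config: logging is set up inside the entry point, the run folder comes from cfg.hyper.general, and the epoch count is read from cfg.hyper.train.epoch instead of a hard-coded 10. For tests or notebooks, the same cfg can be built without the CLI via Hydra's compose API. A minimal sketch, assuming Hydra >= 1.2, a caller located next to examples/example_train.py so the relative config_path resolves the same way, and a top-level config.yaml that composes hyper/default.yaml; the override value is only illustrative:

from hydra import compose, initialize

# Build the same config object that @hydra.main would pass to main(), without the CLI.
with initialize(config_path="../yolo/config", version_base=None):
    cfg = compose(config_name="config", overrides=["hyper.train.epoch=5"])

print(cfg.hyper.train.epoch)       # 5 -- the value now forwarded to trainer.train()
print(cfg.hyper.general.out_path)  # "runs" -- consumed by get_valid_folder()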
yolo/config/config.py CHANGED
@@ -25,11 +25,10 @@ class Download:
 @dataclass
 class DataLoaderConfig:
     batch_size: int
+    class_num: int
+    image_size: List[int]
     shuffle: bool
-    num_workers: int
     pin_memory: bool
-    image_size: List[int]
-    class_num: int
 
 
 @dataclass
@@ -85,8 +84,22 @@ class TrainConfig:
     loss: LossConfig
 
 
+@dataclass
+class GeneralConfig:
+    out_path: str
+    task: str
+    device: Union[str, int, List[int]]
+    cpu_num: int
+    use_wandb: bool
+    lucky_number: int
+    exist_ok: bool
+    resume_train: bool
+    use_TensorBoard: bool
+
+
 @dataclass
 class HyperConfig:
+    general: GeneralConfig
     data: DataLoaderConfig
     train: TrainConfig
 
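GeneralConfig only declares the schema; the values come from the YAML group below and are read back with dotted attribute access. A small self-contained sketch (plain OmegaConf rather than Hydra, not part of the commit) showing that access pattern, with keys mirroring hyper/default.yaml:

from omegaconf import OmegaConf

# Stand-in for the hyper/general block added in this commit.
cfg = OmegaConf.create(
    """
    hyper:
      general:
        out_path: runs
        task: train
        cpu_num: 16
        exist_ok: True
    """
)

# Dotted access, exactly how the rest of the code reads these settings.
print(cfg.hyper.general.out_path)  # runs
print(cfg.hyper.general.cpu_num)   # 16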
 
yolo/config/hyper/default.yaml CHANGED
@@ -1,10 +1,19 @@
+general:
+  out_path: runs
+  task: train
+  device: [0]
+  cpu_num: 16
+  use_wandb: False
+  lucky_number: 10
+  exist_ok: True
+  resume_train: False
+  use_TensorBoard: False
 data:
   batch_size: 16
-  shuffle: True
-  num_workers: 16
-  pin_memory: True
   class_num: 80
   image_size: [640, 640]
+  shuffle: True
+  pin_memory: True
 train:
   epoch: 10
   optimizer:
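Because these defaults live in a regular OmegaConf/Hydra YAML group, they can be overridden without editing the file. A minimal sketch (not part of the commit) that mimics a CLI override by merging a dotlist onto the loaded defaults; it assumes it is run from the repository root so the relative path resolves, and the override values are arbitrary:

from omegaconf import OmegaConf

defaults = OmegaConf.load("yolo/config/hyper/default.yaml")

# Equivalent in spirit to a Hydra CLI override such as hyper.general.cpu_num=8
overrides = OmegaConf.from_dotlist(["general.cpu_num=8", "general.use_wandb=True"])
merged = OmegaConf.merge(defaults, overrides)

print(merged.general.cpu_num)    # 8
print(merged.general.use_wandb)  # True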
yolo/tools/log_helper.py CHANGED
@@ -11,6 +11,7 @@ Example:
     custom_logger()
 """
 
+import os
 import sys
 from typing import Dict, List
 
@@ -22,19 +23,20 @@ from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn
 from rich.table import Table
 from torch import Tensor
 
-from yolo.config.config import Config, YOLOLayer
+from yolo.config.config import Config, GeneralConfig, YOLOLayer
 
 
 def custom_logger():
     logger.remove()
     logger.add(
         sys.stderr,
-        format="<fg #003385>[{time:MM/DD HH:mm:ss}]</fg #003385><level>{level: ^8}</level>| <level>{message}</level>",
+        colorize=True,
+        format="<fg #003385>[{time:MM/DD HH:mm:ss}]</> <level>{level: ^8}</level>| <level>{message}</level>",
     )
 
 
 class CustomProgress:
-    def __init__(self, cfg: Config, use_wandb: bool = False):
+    def __init__(self, cfg: Config, save_path: str, use_wandb: bool = False):
         self.progress = Progress(
             TextColumn("[progress.description]{task.description}"),
             BarColumn(bar_width=None),
@@ -44,18 +46,19 @@ class CustomProgress:
         self.use_wandb = use_wandb
         if self.use_wandb:
             wandb.errors.term._log = custom_wandb_log
-            self.wandb = wandb.init(project="YOLO", resume="allow", mode="online", dir="runs", name=cfg.name)
+            self.wandb = wandb.init(
+                project="YOLO", resume="allow", mode="online", dir=save_path, id=None, name=cfg.name
+            )
 
     def start_train(self, num_epochs: int):
         self.task_epoch = self.progress.add_task("[cyan]Epochs [white]| Loss | Box | DFL | BCE |", total=num_epochs)
 
-    def one_epoch(self):
-        self.progress.update(self.task_epoch, advance=1)
-
-    def finish_epoch(self):
-        self.wandb.finish()
-
-    def start_batch(self, num_batches):
+    def start_one_epoch(self, num_batches, optimizer, epoch_idx):
+        if self.use_wandb:
+            lr_values = [params["lr"] for params in optimizer.param_groups]
+            lr_names = ["bias", "norm", "conv"]
+            for lr_name, lr_value in zip(lr_names, lr_values):
+                self.wandb.log({f"Learning Rate/{lr_name}": lr_value}, step=epoch_idx)
         self.batch_task = self.progress.add_task("[green]Batches", total=num_batches)
 
     def one_batch(self, loss_dict: Dict[str, Tensor]):
@@ -69,15 +72,19 @@ class CustomProgress:
 
         self.progress.update(self.batch_task, advance=1, description=f"[green]Batches [white]{loss_str}")
 
-    def finish_batch(self):
+    def finish_one_epoch(self):
         self.progress.remove_task(self.batch_task)
+        self.progress.update(self.task_epoch, advance=1)
+
+    def finish_train(self):
+        self.wandb.finish()
 
 
 def custom_wandb_log(string="", level=int, newline=True, repeat=True, prefix=True, silent=False):
     if silent:
         return
     for line in string.split("\n"):
-        logger.opt(raw=not newline).info("🌐 " + line)
+        logger.opt(raw=not newline, colors=True).info("🌐 " + line)
 
 
 def log_model(model: List[YOLOLayer]):
@@ -99,3 +106,25 @@ def log_model(model: List[YOLOLayer]):
             channels = "-"
         table.add_row(str(idx), layer.layer_type, layer.tags, f"{layer_param:,}", channels)
     console.print(table)
+
+
+def get_valid_folder(general_cfg: GeneralConfig, exp_name):
+    base_path = os.path.join(general_cfg.out_path, general_cfg.task)
+    save_path = os.path.join(base_path, exp_name)
+
+    if not general_cfg.exist_ok:
+        index = 1
+        old_exp_name = exp_name
+        while os.path.isdir(save_path):
+            exp_name = f"{old_exp_name}{index}"
+            save_path = os.path.join(base_path, exp_name)
+            index += 1
+        if index > 1:
+            logger.opt(colors=True).warning(
+                f"🔀 Experiment directory exists! Changed <red>{old_exp_name}</> to <green>{exp_name}</>"
+            )
+
+    os.makedirs(save_path, exist_ok=True)
+    logger.opt(colors=True).info(f"📄 Created log folder: <u><fg #808080>{save_path}</></>")
+    logger.add(os.path.join(save_path, "output.log"), backtrace=True, diagnose=True)
+    return save_path
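A short usage sketch for the new get_valid_folder helper (not from the commit; it assumes the yolo package is importable and will create folders under runs/train/). With exist_ok set to False, a second run with the same experiment name is redirected to an auto-numbered folder, and an output.log sink is attached inside it:

from omegaconf import OmegaConf

from yolo.tools.log_helper import custom_logger, get_valid_folder

# A DictConfig quacks like GeneralConfig for the attributes the helper touches.
general = OmegaConf.create({"out_path": "runs", "task": "train", "exist_ok": False})

custom_logger()
first = get_valid_folder(general, "demo")   # e.g. runs/train/demo
second = get_valid_folder(general, "demo")  # e.g. runs/train/demo1 (auto-suffixed)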
yolo/tools/trainer.py CHANGED
@@ -6,22 +6,23 @@ from torch import Tensor
 from torch.cuda.amp import GradScaler, autocast
 
 from yolo.config.config import Config, TrainConfig
-from yolo.model.yolo import YOLO
+from yolo.model.yolo import get_model
 from yolo.tools.log_helper import CustomProgress
 from yolo.tools.model_helper import EMA, get_optimizer, get_scheduler
 from yolo.utils.loss import get_loss_function
 
 
 class Trainer:
-    def __init__(self, model: YOLO, cfg: Config, device):
+    def __init__(self, cfg: Config, save_path: str, device):
         train_cfg: TrainConfig = cfg.hyper.train
+        model = get_model(cfg)
 
         self.model = model.to(device)
         self.device = device
         self.optimizer = get_optimizer(model.parameters(), train_cfg.optimizer)
         self.scheduler = get_scheduler(self.optimizer, train_cfg.scheduler)
         self.loss_fn = get_loss_function(cfg)
-        self.progress = CustomProgress(cfg, use_wandb=True)
+        self.progress = CustomProgress(cfg, save_path, use_wandb=True)
 
         if getattr(train_cfg.ema, "enabled", False):
             self.ema = EMA(model, decay=train_cfg.ema.decay)
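Trainer now owns model construction: callers pass only the config, the run folder, and the device, and get_model(cfg) is called inside __init__. A toy, self-contained illustration of that constructor pattern (stand-in names, not the project's classes):

import torch
from torch import nn


def get_model(in_channels: int) -> nn.Module:
    # Stand-in for yolo.model.yolo.get_model(cfg).
    return nn.Conv2d(in_channels, 16, kernel_size=3)


class MiniTrainer:
    def __init__(self, in_channels: int, save_path: str, device: torch.device):
        model = get_model(in_channels)  # built inside, as in the new Trainer.__init__
        self.model = model.to(device)
        self.save_path = save_path
        self.device = device


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
trainer = MiniTrainer(3, "runs/train/demo", device)  # the old call passed a model instance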
yolo/utils/dataloader.py CHANGED
@@ -160,7 +160,7 @@ class YoloDataLoader(DataLoader):
             dataset,
             batch_size=hyper.batch_size,
             shuffle=hyper.shuffle,
-            num_workers=hyper.num_workers,
+            num_workers=config.hyper.general.cpu_num,
             pin_memory=hyper.pin_memory,
             collate_fn=self.collate_fn,
         )
  )