henry000 commited on
Commit
d58a9b6
·
1 Parent(s): 70a7f92

:truck: [Rename] Tracker to Logger, handle all log

Browse files
yolo/lazy.py CHANGED
@@ -13,13 +13,12 @@ from yolo.tools.data_loader import create_dataloader
13
  from yolo.tools.solver import ModelTester, ModelTrainer
14
  from yolo.utils.bounding_box_utils import Vec2Box
15
  from yolo.utils.deploy_utils import FastModelLoader
16
- from yolo.utils.logging_utils import custom_logger, validate_log_directory
17
 
18
 
19
  @hydra.main(config_path="config", config_name="config", version_base=None)
20
  def main(cfg: Config):
21
- custom_logger()
22
- save_path = validate_log_directory(cfg, exp_name=cfg.name)
23
  dataloader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task)
24
  device = torch.device(cfg.device)
25
  if getattr(cfg.task, "fast_inference", False):
@@ -31,11 +30,11 @@ def main(cfg: Config):
31
  vec2box = Vec2Box(model, cfg.image_size, device)
32
 
33
  if cfg.task.task == "train":
34
- trainer = ModelTrainer(cfg, model, vec2box, save_path, device)
35
  trainer.solve(dataloader)
36
 
37
  if cfg.task.task == "inference":
38
- tester = ModelTester(cfg, model, vec2box, save_path, device)
39
  tester.solve(dataloader)
40
 
41
 
 
13
  from yolo.tools.solver import ModelTester, ModelTrainer
14
  from yolo.utils.bounding_box_utils import Vec2Box
15
  from yolo.utils.deploy_utils import FastModelLoader
16
+ from yolo.utils.logging_utils import ProgressLogger
17
 
18
 
19
  @hydra.main(config_path="config", config_name="config", version_base=None)
20
  def main(cfg: Config):
21
+ progress = ProgressLogger(cfg, exp_name=cfg.name)
 
22
  dataloader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task)
23
  device = torch.device(cfg.device)
24
  if getattr(cfg.task, "fast_inference", False):
 
30
  vec2box = Vec2Box(model, cfg.image_size, device)
31
 
32
  if cfg.task.task == "train":
33
+ trainer = ModelTrainer(cfg, model, vec2box, progress, device)
34
  trainer.solve(dataloader)
35
 
36
  if cfg.task.task == "inference":
37
+ tester = ModelTester(cfg, model, vec2box, progress, device)
38
  tester.solve(dataloader)
39
 
40
 
yolo/tools/solver.py CHANGED
@@ -13,7 +13,7 @@ from yolo.tools.data_loader import StreamDataLoader, create_dataloader
13
  from yolo.tools.drawer import draw_bboxes
14
  from yolo.tools.loss_functions import create_loss_function
15
  from yolo.utils.bounding_box_utils import Vec2Box, bbox_nms, calculate_map
16
- from yolo.utils.logging_utils import ProgressTracker
17
  from yolo.utils.model_utils import (
18
  ExponentialMovingAverage,
19
  create_optimizer,
@@ -22,7 +22,7 @@ from yolo.utils.model_utils import (
22
 
23
 
24
  class ModelTrainer:
25
- def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, save_path: str, device):
26
  train_cfg: TrainConfig = cfg.task
27
  self.model = model
28
  self.vec2box = vec2box
@@ -30,11 +30,11 @@ class ModelTrainer:
30
  self.optimizer = create_optimizer(model, train_cfg.optimizer)
31
  self.scheduler = create_scheduler(self.optimizer, train_cfg.scheduler)
32
  self.loss_fn = create_loss_function(cfg, vec2box)
33
- self.progress = ProgressTracker(cfg.name, save_path, cfg.use_wandb)
34
  self.num_epochs = cfg.task.epoch
35
 
36
  self.validation_dataloader = create_dataloader(cfg.task.validation.data, cfg.dataset, cfg.task.validation.task)
37
- self.validator = ModelValidator(cfg.task.validation, model, vec2box, save_path, device, self.progress)
38
 
39
  if getattr(train_cfg.ema, "enabled", False):
40
  self.ema = ExponentialMovingAverage(model, decay=train_cfg.ema.decay)
@@ -101,14 +101,15 @@ class ModelTrainer:
101
 
102
 
103
  class ModelTester:
104
- def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, save_path: str, device):
105
  self.model = model
106
  self.device = device
107
  self.vec2box = vec2box
108
- self.progress = ProgressTracker(cfg, save_path, cfg.use_wandb)
109
 
110
  self.nms = cfg.task.nms
111
- self.save_path = save_path
 
112
  self.save_predict = getattr(cfg.task, "save_predict", None)
113
  self.idx2label = cfg.class_list
114
 
@@ -158,16 +159,13 @@ class ModelValidator:
158
  validation_cfg: ValidationConfig,
159
  model: YOLO,
160
  vec2box: Vec2Box,
161
- save_path: str,
162
  device,
163
- # TODO: think Progress?
164
- progress: ProgressTracker,
165
  ):
166
  self.model = model
167
  self.vec2box = vec2box
168
  self.device = device
169
  self.progress = progress
170
- self.save_path = save_path
171
 
172
  self.nms = validation_cfg.nms
173
 
 
13
  from yolo.tools.drawer import draw_bboxes
14
  from yolo.tools.loss_functions import create_loss_function
15
  from yolo.utils.bounding_box_utils import Vec2Box, bbox_nms, calculate_map
16
+ from yolo.utils.logging_utils import ProgressLogger
17
  from yolo.utils.model_utils import (
18
  ExponentialMovingAverage,
19
  create_optimizer,
 
22
 
23
 
24
  class ModelTrainer:
25
+ def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, progress: ProgressLogger, device):
26
  train_cfg: TrainConfig = cfg.task
27
  self.model = model
28
  self.vec2box = vec2box
 
30
  self.optimizer = create_optimizer(model, train_cfg.optimizer)
31
  self.scheduler = create_scheduler(self.optimizer, train_cfg.scheduler)
32
  self.loss_fn = create_loss_function(cfg, vec2box)
33
+ self.progress = progress
34
  self.num_epochs = cfg.task.epoch
35
 
36
  self.validation_dataloader = create_dataloader(cfg.task.validation.data, cfg.dataset, cfg.task.validation.task)
37
+ self.validator = ModelValidator(cfg.task.validation, model, vec2box, progress, device, self.progress)
38
 
39
  if getattr(train_cfg.ema, "enabled", False):
40
  self.ema = ExponentialMovingAverage(model, decay=train_cfg.ema.decay)
 
101
 
102
 
103
  class ModelTester:
104
+ def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, progress: ProgressLogger, device):
105
  self.model = model
106
  self.device = device
107
  self.vec2box = vec2box
108
+ self.progress = progress
109
 
110
  self.nms = cfg.task.nms
111
+ self.save_path = os.path.join(progress.save_path, "images")
112
+ os.makedirs(self.save_path, exist_ok=True)
113
  self.save_predict = getattr(cfg.task, "save_predict", None)
114
  self.idx2label = cfg.class_list
115
 
 
159
  validation_cfg: ValidationConfig,
160
  model: YOLO,
161
  vec2box: Vec2Box,
 
162
  device,
163
+ progress: ProgressLogger,
 
164
  ):
165
  self.model = model
166
  self.vec2box = vec2box
167
  self.device = device
168
  self.progress = progress
 
169
 
170
  self.nms = validation_cfg.nms
171
 
yolo/utils/logging_utils.py CHANGED
@@ -38,15 +38,18 @@ def custom_logger(quite: bool = False):
38
  )
39
 
40
 
41
- class ProgressTracker:
42
- def __init__(self, exp_name: str, save_path: str, use_wandb: bool = False):
 
 
 
43
  self.progress = Progress(
44
  TextColumn("[progress.description]{task.description}"),
45
  BarColumn(bar_width=None),
46
  TextColumn("{task.completed:.0f}/{task.total:.0f}"),
47
  TimeRemainingColumn(),
48
  )
49
- self.use_wandb = use_wandb
50
  if self.use_wandb:
51
  wandb.errors.term._log = custom_wandb_log
52
  self.wandb = wandb.init(
 
38
  )
39
 
40
 
41
+ class ProgressLogger:
42
+ def __init__(self, cfg: Config, exp_name: str):
43
+ custom_logger(getattr(cfg, "quite", False))
44
+ self.save_path = validate_log_directory(cfg, exp_name=cfg.name)
45
+
46
  self.progress = Progress(
47
  TextColumn("[progress.description]{task.description}"),
48
  BarColumn(bar_width=None),
49
  TextColumn("{task.completed:.0f}/{task.total:.0f}"),
50
  TimeRemainingColumn(),
51
  )
52
+ self.use_wandb = cfg.use_wandb
53
  if self.use_wandb:
54
  wandb.errors.term._log = custom_wandb_log
55
  self.wandb = wandb.init(