henry000 commited on
Commit
afa32b4
Β·
1 Parent(s): 1fe2937

πŸ’¬ [Update] logging tools for quite mode

Browse files
yolo/model/yolo.py CHANGED
@@ -6,10 +6,8 @@ from loguru import logger
6
  from omegaconf import ListConfig, OmegaConf
7
  from torch import nn
8
 
9
- from yolo.config.config import Config, ModelConfig, YOLOLayer
10
  from yolo.tools.dataset_preparation import prepare_weight
11
- from yolo.tools.drawer import draw_model
12
- from yolo.utils.logging_utils import log_model_structure
13
  from yolo.utils.module_utils import get_layer_map
14
 
15
 
@@ -138,6 +136,4 @@ def create_model(model_cfg: ModelConfig, weight_path: Optional[str], class_num:
138
  model.model.load_state_dict(torch.load(weight_path), strict=False)
139
  logger.info("βœ… Success load model weight")
140
 
141
- log_model_structure(model.model)
142
- draw_model(model=model)
143
  return model
 
6
  from omegaconf import ListConfig, OmegaConf
7
  from torch import nn
8
 
9
+ from yolo.config.config import ModelConfig, YOLOLayer
10
  from yolo.tools.dataset_preparation import prepare_weight
 
 
11
  from yolo.utils.module_utils import get_layer_map
12
 
13
 
 
136
  model.model.load_state_dict(torch.load(weight_path), strict=False)
137
  logger.info("βœ… Success load model weight")
138
 
 
 
139
  return model
yolo/tools/solver.py CHANGED
@@ -13,10 +13,10 @@ from torch.utils.data import DataLoader
13
  from yolo.config.config import Config, TrainConfig, ValidationConfig
14
  from yolo.model.yolo import YOLO
15
  from yolo.tools.data_loader import StreamDataLoader, create_dataloader
16
- from yolo.tools.drawer import draw_bboxes
17
  from yolo.tools.loss_functions import create_loss_function
18
  from yolo.utils.bounding_box_utils import Vec2Box, bbox_nms, calculate_map
19
- from yolo.utils.logging_utils import ProgressLogger
20
  from yolo.utils.model_utils import (
21
  ExponentialMovingAverage,
22
  create_optimizer,
@@ -25,7 +25,7 @@ from yolo.utils.model_utils import (
25
 
26
 
27
  class ModelTrainer:
28
- def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, progress: ProgressLogger, device):
29
  train_cfg: TrainConfig = cfg.task
30
  self.model = model if not use_ddp else DDP(model, device_ids=[device])
31
  self.use_ddp = use_ddp
@@ -37,7 +37,13 @@ class ModelTrainer:
37
  self.progress = progress
38
  self.num_epochs = cfg.task.epoch
39
 
40
- self.validation_dataloader = create_dataloader(cfg.task.validation.data, cfg.dataset, cfg.task.validation.task)
 
 
 
 
 
 
41
  self.validator = ModelValidator(cfg.task.validation, model, vec2box, progress, device)
42
 
43
  if getattr(train_cfg.ema, "enabled", False):
 
13
  from yolo.config.config import Config, TrainConfig, ValidationConfig
14
  from yolo.model.yolo import YOLO
15
  from yolo.tools.data_loader import StreamDataLoader, create_dataloader
16
+ from yolo.tools.drawer import draw_bboxes, draw_model
17
  from yolo.tools.loss_functions import create_loss_function
18
  from yolo.utils.bounding_box_utils import Vec2Box, bbox_nms, calculate_map
19
+ from yolo.utils.logging_utils import ProgressLogger, log_model_structure
20
  from yolo.utils.model_utils import (
21
  ExponentialMovingAverage,
22
  create_optimizer,
 
25
 
26
 
27
  class ModelTrainer:
28
+ def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, progress: ProgressLogger, device, use_ddp: bool):
29
  train_cfg: TrainConfig = cfg.task
30
  self.model = model if not use_ddp else DDP(model, device_ids=[device])
31
  self.use_ddp = use_ddp
 
37
  self.progress = progress
38
  self.num_epochs = cfg.task.epoch
39
 
40
+ if not progress.quite_mode:
41
+ log_model_structure(model.model)
42
+ draw_model(model=model)
43
+
44
+ self.validation_dataloader = create_dataloader(
45
+ cfg.task.validation.data, cfg.dataset, cfg.task.validation.task, use_ddp
46
+ )
47
  self.validator = ModelValidator(cfg.task.validation, model, vec2box, progress, device)
48
 
49
  if getattr(train_cfg.ema, "enabled", False):
yolo/utils/logging_utils.py CHANGED
@@ -40,7 +40,9 @@ def custom_logger(quite: bool = False):
40
 
41
  class ProgressLogger:
42
  def __init__(self, cfg: Config, exp_name: str):
43
- custom_logger(getattr(cfg, "quite", False))
 
 
44
  self.save_path = validate_log_directory(cfg, exp_name=cfg.name)
45
 
46
  self.progress = Progress(
@@ -53,7 +55,7 @@ class ProgressLogger:
53
  if self.use_wandb:
54
  wandb.errors.term._log = custom_wandb_log
55
  self.wandb = wandb.init(
56
- project="YOLO", resume="allow", mode="online", dir=save_path, id=None, name=exp_name
57
  )
58
 
59
  def start_train(self, num_epochs: int):
 
40
 
41
  class ProgressLogger:
42
  def __init__(self, cfg: Config, exp_name: str):
43
+ local_rank = int(os.getenv("LOCAL_RANK", "0"))
44
+ self.quite_mode = local_rank or getattr(cfg, "quite", False)
45
+ custom_logger(self.quite_mode)
46
  self.save_path = validate_log_directory(cfg, exp_name=cfg.name)
47
 
48
  self.progress = Progress(
 
55
  if self.use_wandb:
56
  wandb.errors.term._log = custom_wandb_log
57
  self.wandb = wandb.init(
58
+ project="YOLO", resume="allow", mode="online", dir=self.save_path, id=None, name=exp_name
59
  )
60
 
61
  def start_train(self, num_epochs: int):