π¬ [Update] logging tools for quite mode
Browse files- yolo/model/yolo.py +1 -5
- yolo/tools/solver.py +10 -4
- yolo/utils/logging_utils.py +4 -2
yolo/model/yolo.py
CHANGED
@@ -6,10 +6,8 @@ from loguru import logger
|
|
6 |
from omegaconf import ListConfig, OmegaConf
|
7 |
from torch import nn
|
8 |
|
9 |
-
from yolo.config.config import
|
10 |
from yolo.tools.dataset_preparation import prepare_weight
|
11 |
-
from yolo.tools.drawer import draw_model
|
12 |
-
from yolo.utils.logging_utils import log_model_structure
|
13 |
from yolo.utils.module_utils import get_layer_map
|
14 |
|
15 |
|
@@ -138,6 +136,4 @@ def create_model(model_cfg: ModelConfig, weight_path: Optional[str], class_num:
|
|
138 |
model.model.load_state_dict(torch.load(weight_path), strict=False)
|
139 |
logger.info("β
Success load model weight")
|
140 |
|
141 |
-
log_model_structure(model.model)
|
142 |
-
draw_model(model=model)
|
143 |
return model
|
|
|
6 |
from omegaconf import ListConfig, OmegaConf
|
7 |
from torch import nn
|
8 |
|
9 |
+
from yolo.config.config import ModelConfig, YOLOLayer
|
10 |
from yolo.tools.dataset_preparation import prepare_weight
|
|
|
|
|
11 |
from yolo.utils.module_utils import get_layer_map
|
12 |
|
13 |
|
|
|
136 |
model.model.load_state_dict(torch.load(weight_path), strict=False)
|
137 |
logger.info("β
Success load model weight")
|
138 |
|
|
|
|
|
139 |
return model
|
yolo/tools/solver.py
CHANGED
@@ -13,10 +13,10 @@ from torch.utils.data import DataLoader
|
|
13 |
from yolo.config.config import Config, TrainConfig, ValidationConfig
|
14 |
from yolo.model.yolo import YOLO
|
15 |
from yolo.tools.data_loader import StreamDataLoader, create_dataloader
|
16 |
-
from yolo.tools.drawer import draw_bboxes
|
17 |
from yolo.tools.loss_functions import create_loss_function
|
18 |
from yolo.utils.bounding_box_utils import Vec2Box, bbox_nms, calculate_map
|
19 |
-
from yolo.utils.logging_utils import ProgressLogger
|
20 |
from yolo.utils.model_utils import (
|
21 |
ExponentialMovingAverage,
|
22 |
create_optimizer,
|
@@ -25,7 +25,7 @@ from yolo.utils.model_utils import (
|
|
25 |
|
26 |
|
27 |
class ModelTrainer:
|
28 |
-
def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, progress: ProgressLogger, device):
|
29 |
train_cfg: TrainConfig = cfg.task
|
30 |
self.model = model if not use_ddp else DDP(model, device_ids=[device])
|
31 |
self.use_ddp = use_ddp
|
@@ -37,7 +37,13 @@ class ModelTrainer:
|
|
37 |
self.progress = progress
|
38 |
self.num_epochs = cfg.task.epoch
|
39 |
|
40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
self.validator = ModelValidator(cfg.task.validation, model, vec2box, progress, device)
|
42 |
|
43 |
if getattr(train_cfg.ema, "enabled", False):
|
|
|
13 |
from yolo.config.config import Config, TrainConfig, ValidationConfig
|
14 |
from yolo.model.yolo import YOLO
|
15 |
from yolo.tools.data_loader import StreamDataLoader, create_dataloader
|
16 |
+
from yolo.tools.drawer import draw_bboxes, draw_model
|
17 |
from yolo.tools.loss_functions import create_loss_function
|
18 |
from yolo.utils.bounding_box_utils import Vec2Box, bbox_nms, calculate_map
|
19 |
+
from yolo.utils.logging_utils import ProgressLogger, log_model_structure
|
20 |
from yolo.utils.model_utils import (
|
21 |
ExponentialMovingAverage,
|
22 |
create_optimizer,
|
|
|
25 |
|
26 |
|
27 |
class ModelTrainer:
|
28 |
+
def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, progress: ProgressLogger, device, use_ddp: bool):
|
29 |
train_cfg: TrainConfig = cfg.task
|
30 |
self.model = model if not use_ddp else DDP(model, device_ids=[device])
|
31 |
self.use_ddp = use_ddp
|
|
|
37 |
self.progress = progress
|
38 |
self.num_epochs = cfg.task.epoch
|
39 |
|
40 |
+
if not progress.quite_mode:
|
41 |
+
log_model_structure(model.model)
|
42 |
+
draw_model(model=model)
|
43 |
+
|
44 |
+
self.validation_dataloader = create_dataloader(
|
45 |
+
cfg.task.validation.data, cfg.dataset, cfg.task.validation.task, use_ddp
|
46 |
+
)
|
47 |
self.validator = ModelValidator(cfg.task.validation, model, vec2box, progress, device)
|
48 |
|
49 |
if getattr(train_cfg.ema, "enabled", False):
|
yolo/utils/logging_utils.py
CHANGED
@@ -40,7 +40,9 @@ def custom_logger(quite: bool = False):
|
|
40 |
|
41 |
class ProgressLogger:
|
42 |
def __init__(self, cfg: Config, exp_name: str):
|
43 |
-
|
|
|
|
|
44 |
self.save_path = validate_log_directory(cfg, exp_name=cfg.name)
|
45 |
|
46 |
self.progress = Progress(
|
@@ -53,7 +55,7 @@ class ProgressLogger:
|
|
53 |
if self.use_wandb:
|
54 |
wandb.errors.term._log = custom_wandb_log
|
55 |
self.wandb = wandb.init(
|
56 |
-
project="YOLO", resume="allow", mode="online", dir=save_path, id=None, name=exp_name
|
57 |
)
|
58 |
|
59 |
def start_train(self, num_epochs: int):
|
|
|
40 |
|
41 |
class ProgressLogger:
|
42 |
def __init__(self, cfg: Config, exp_name: str):
|
43 |
+
local_rank = int(os.getenv("LOCAL_RANK", "0"))
|
44 |
+
self.quite_mode = local_rank or getattr(cfg, "quite", False)
|
45 |
+
custom_logger(self.quite_mode)
|
46 |
self.save_path = validate_log_directory(cfg, exp_name=cfg.name)
|
47 |
|
48 |
self.progress = Progress(
|
|
|
55 |
if self.use_wandb:
|
56 |
wandb.errors.term._log = custom_wandb_log
|
57 |
self.wandb = wandb.init(
|
58 |
+
project="YOLO", resume="allow", mode="online", dir=self.save_path, id=None, name=exp_name
|
59 |
)
|
60 |
|
61 |
def start_train(self, num_epochs: int):
|