:truck: [Rename] Tracker to Logger, handle all log
Browse files- yolo/lazy.py +4 -5
- yolo/tools/solver.py +9 -11
- yolo/utils/logging_utils.py +6 -3
yolo/lazy.py
CHANGED
@@ -13,13 +13,12 @@ from yolo.tools.data_loader import create_dataloader
|
|
13 |
from yolo.tools.solver import ModelTester, ModelTrainer
|
14 |
from yolo.utils.bounding_box_utils import Vec2Box
|
15 |
from yolo.utils.deploy_utils import FastModelLoader
|
16 |
-
from yolo.utils.logging_utils import
|
17 |
|
18 |
|
19 |
@hydra.main(config_path="config", config_name="config", version_base=None)
|
20 |
def main(cfg: Config):
|
21 |
-
|
22 |
-
save_path = validate_log_directory(cfg, exp_name=cfg.name)
|
23 |
dataloader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task)
|
24 |
device = torch.device(cfg.device)
|
25 |
if getattr(cfg.task, "fast_inference", False):
|
@@ -31,11 +30,11 @@ def main(cfg: Config):
|
|
31 |
vec2box = Vec2Box(model, cfg.image_size, device)
|
32 |
|
33 |
if cfg.task.task == "train":
|
34 |
-
trainer = ModelTrainer(cfg, model, vec2box,
|
35 |
trainer.solve(dataloader)
|
36 |
|
37 |
if cfg.task.task == "inference":
|
38 |
-
tester = ModelTester(cfg, model, vec2box,
|
39 |
tester.solve(dataloader)
|
40 |
|
41 |
|
|
|
13 |
from yolo.tools.solver import ModelTester, ModelTrainer
|
14 |
from yolo.utils.bounding_box_utils import Vec2Box
|
15 |
from yolo.utils.deploy_utils import FastModelLoader
|
16 |
+
from yolo.utils.logging_utils import ProgressLogger
|
17 |
|
18 |
|
19 |
@hydra.main(config_path="config", config_name="config", version_base=None)
|
20 |
def main(cfg: Config):
|
21 |
+
progress = ProgressLogger(cfg, exp_name=cfg.name)
|
|
|
22 |
dataloader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task)
|
23 |
device = torch.device(cfg.device)
|
24 |
if getattr(cfg.task, "fast_inference", False):
|
|
|
30 |
vec2box = Vec2Box(model, cfg.image_size, device)
|
31 |
|
32 |
if cfg.task.task == "train":
|
33 |
+
trainer = ModelTrainer(cfg, model, vec2box, progress, device)
|
34 |
trainer.solve(dataloader)
|
35 |
|
36 |
if cfg.task.task == "inference":
|
37 |
+
tester = ModelTester(cfg, model, vec2box, progress, device)
|
38 |
tester.solve(dataloader)
|
39 |
|
40 |
|
yolo/tools/solver.py
CHANGED
@@ -13,7 +13,7 @@ from yolo.tools.data_loader import StreamDataLoader, create_dataloader
|
|
13 |
from yolo.tools.drawer import draw_bboxes
|
14 |
from yolo.tools.loss_functions import create_loss_function
|
15 |
from yolo.utils.bounding_box_utils import Vec2Box, bbox_nms, calculate_map
|
16 |
-
from yolo.utils.logging_utils import
|
17 |
from yolo.utils.model_utils import (
|
18 |
ExponentialMovingAverage,
|
19 |
create_optimizer,
|
@@ -22,7 +22,7 @@ from yolo.utils.model_utils import (
|
|
22 |
|
23 |
|
24 |
class ModelTrainer:
|
25 |
-
def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box,
|
26 |
train_cfg: TrainConfig = cfg.task
|
27 |
self.model = model
|
28 |
self.vec2box = vec2box
|
@@ -30,11 +30,11 @@ class ModelTrainer:
|
|
30 |
self.optimizer = create_optimizer(model, train_cfg.optimizer)
|
31 |
self.scheduler = create_scheduler(self.optimizer, train_cfg.scheduler)
|
32 |
self.loss_fn = create_loss_function(cfg, vec2box)
|
33 |
-
self.progress =
|
34 |
self.num_epochs = cfg.task.epoch
|
35 |
|
36 |
self.validation_dataloader = create_dataloader(cfg.task.validation.data, cfg.dataset, cfg.task.validation.task)
|
37 |
-
self.validator = ModelValidator(cfg.task.validation, model, vec2box,
|
38 |
|
39 |
if getattr(train_cfg.ema, "enabled", False):
|
40 |
self.ema = ExponentialMovingAverage(model, decay=train_cfg.ema.decay)
|
@@ -101,14 +101,15 @@ class ModelTrainer:
|
|
101 |
|
102 |
|
103 |
class ModelTester:
|
104 |
-
def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box,
|
105 |
self.model = model
|
106 |
self.device = device
|
107 |
self.vec2box = vec2box
|
108 |
-
self.progress =
|
109 |
|
110 |
self.nms = cfg.task.nms
|
111 |
-
self.save_path = save_path
|
|
|
112 |
self.save_predict = getattr(cfg.task, "save_predict", None)
|
113 |
self.idx2label = cfg.class_list
|
114 |
|
@@ -158,16 +159,13 @@ class ModelValidator:
|
|
158 |
validation_cfg: ValidationConfig,
|
159 |
model: YOLO,
|
160 |
vec2box: Vec2Box,
|
161 |
-
save_path: str,
|
162 |
device,
|
163 |
-
|
164 |
-
progress: ProgressTracker,
|
165 |
):
|
166 |
self.model = model
|
167 |
self.vec2box = vec2box
|
168 |
self.device = device
|
169 |
self.progress = progress
|
170 |
-
self.save_path = save_path
|
171 |
|
172 |
self.nms = validation_cfg.nms
|
173 |
|
|
|
13 |
from yolo.tools.drawer import draw_bboxes
|
14 |
from yolo.tools.loss_functions import create_loss_function
|
15 |
from yolo.utils.bounding_box_utils import Vec2Box, bbox_nms, calculate_map
|
16 |
+
from yolo.utils.logging_utils import ProgressLogger
|
17 |
from yolo.utils.model_utils import (
|
18 |
ExponentialMovingAverage,
|
19 |
create_optimizer,
|
|
|
22 |
|
23 |
|
24 |
class ModelTrainer:
|
25 |
+
def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, progress: ProgressLogger, device):
|
26 |
train_cfg: TrainConfig = cfg.task
|
27 |
self.model = model
|
28 |
self.vec2box = vec2box
|
|
|
30 |
self.optimizer = create_optimizer(model, train_cfg.optimizer)
|
31 |
self.scheduler = create_scheduler(self.optimizer, train_cfg.scheduler)
|
32 |
self.loss_fn = create_loss_function(cfg, vec2box)
|
33 |
+
self.progress = progress
|
34 |
self.num_epochs = cfg.task.epoch
|
35 |
|
36 |
self.validation_dataloader = create_dataloader(cfg.task.validation.data, cfg.dataset, cfg.task.validation.task)
|
37 |
+
self.validator = ModelValidator(cfg.task.validation, model, vec2box, progress, device, self.progress)
|
38 |
|
39 |
if getattr(train_cfg.ema, "enabled", False):
|
40 |
self.ema = ExponentialMovingAverage(model, decay=train_cfg.ema.decay)
|
|
|
101 |
|
102 |
|
103 |
class ModelTester:
|
104 |
+
def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, progress: ProgressLogger, device):
|
105 |
self.model = model
|
106 |
self.device = device
|
107 |
self.vec2box = vec2box
|
108 |
+
self.progress = progress
|
109 |
|
110 |
self.nms = cfg.task.nms
|
111 |
+
self.save_path = os.path.join(progress.save_path, "images")
|
112 |
+
os.makedirs(self.save_path, exist_ok=True)
|
113 |
self.save_predict = getattr(cfg.task, "save_predict", None)
|
114 |
self.idx2label = cfg.class_list
|
115 |
|
|
|
159 |
validation_cfg: ValidationConfig,
|
160 |
model: YOLO,
|
161 |
vec2box: Vec2Box,
|
|
|
162 |
device,
|
163 |
+
progress: ProgressLogger,
|
|
|
164 |
):
|
165 |
self.model = model
|
166 |
self.vec2box = vec2box
|
167 |
self.device = device
|
168 |
self.progress = progress
|
|
|
169 |
|
170 |
self.nms = validation_cfg.nms
|
171 |
|
yolo/utils/logging_utils.py
CHANGED
@@ -38,15 +38,18 @@ def custom_logger(quite: bool = False):
|
|
38 |
)
|
39 |
|
40 |
|
41 |
-
class
|
42 |
-
def __init__(self,
|
|
|
|
|
|
|
43 |
self.progress = Progress(
|
44 |
TextColumn("[progress.description]{task.description}"),
|
45 |
BarColumn(bar_width=None),
|
46 |
TextColumn("{task.completed:.0f}/{task.total:.0f}"),
|
47 |
TimeRemainingColumn(),
|
48 |
)
|
49 |
-
self.use_wandb = use_wandb
|
50 |
if self.use_wandb:
|
51 |
wandb.errors.term._log = custom_wandb_log
|
52 |
self.wandb = wandb.init(
|
|
|
38 |
)
|
39 |
|
40 |
|
41 |
+
class ProgressLogger:
|
42 |
+
def __init__(self, cfg: Config, exp_name: str):
|
43 |
+
custom_logger(getattr(cfg, "quite", False))
|
44 |
+
self.save_path = validate_log_directory(cfg, exp_name=cfg.name)
|
45 |
+
|
46 |
self.progress = Progress(
|
47 |
TextColumn("[progress.description]{task.description}"),
|
48 |
BarColumn(bar_width=None),
|
49 |
TextColumn("{task.completed:.0f}/{task.total:.0f}"),
|
50 |
TimeRemainingColumn(),
|
51 |
)
|
52 |
+
self.use_wandb = cfg.use_wandb
|
53 |
if self.use_wandb:
|
54 |
wandb.errors.term._log = custom_wandb_log
|
55 |
self.wandb = wandb.init(
|