Manan Goel committed
Commit e578635 · Parent(s): 1ee5e0e

feat(engine): support wandb logger (#1144)
Files changed:
- docs/quick_run.md +20 -0
- tools/train.py +7 -0
- yolox/core/trainer.py +33 -4
- yolox/utils/__init__.py +1 -1
- yolox/utils/logger.py +123 -1
docs/quick_run.md CHANGED

@@ -56,6 +56,26 @@ python tools/train.py -n yolox-s -d 8 -b 64 --fp16 -o [--cache]
 * --fp16: mixed precision training
 * --cache: caching imgs into RAM to accelarate training, which need large system RAM.
 
+**Weights & Biases for Logging**
+
+To use W&B for logging, install wandb in your environment and log in to your W&B account using
+
+```shell
+pip install wandb
+wandb login
+```
+
+Log in to your W&B account
+
+To start logging metrics to W&B during training add the flag `--logger` to the previous command and use the prefix "wandb-" to specify arguments for initializing the wandb run.
+
+```shell
+python tools/train.py -n yolox-s -d 8 -b 64 --fp16 -o [--cache] --logger wandb wandb-project <project name>
+                         yolox-m
+                         yolox-l
+                         yolox-x
+```
+
 **Multi Machine Training**
 
 We also support multi-nodes training. Just add the following args:
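The `wandb-` prefixed pairs map onto the keyword arguments of the `WandbLogger` constructor introduced in `yolox/utils/logger.py` below (`project`, `name`, `id`, `entity`, `save_dir`). As a hypothetical sketch (placeholder values, not from the commit), a fuller invocation could look like:

```shell
# Placeholder project/entity/id values; each "wandb-<key> <value>" pair
# becomes WandbLogger(<key>=<value>).
python tools/train.py -n yolox-s -d 8 -b 64 --fp16 -o --logger wandb \
    wandb-project my-yolox-runs wandb-entity my-team wandb-id run-001
```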
tools/train.py CHANGED

@@ -80,6 +80,13 @@ def make_parser():
         action="store_true",
         help="occupy GPU memory first for training.",
     )
+    parser.add_argument(
+        "-l",
+        "--logger",
+        type=str,
+        help="Logger to be used for metrics",
+        default="tensorboard"
+    )
     parser.add_argument(
         "opts",
         help="Modify config options using the command-line",
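For context on how `wandb-project <project name>` reaches the trainer: `--logger` is consumed by argparse, while the unrecognized `wandb-*` tokens fall through into the catch-all `opts` positional shown above. A minimal sketch, assuming `opts` is declared with `nargs=argparse.REMAINDER` as in YOLOX's `make_parser`:

```python
import argparse

parser = argparse.ArgumentParser("sketch of make_parser")
parser.add_argument(
    "-l", "--logger", type=str,
    help="Logger to be used for metrics", default="tensorboard",
)
# Catch-all positional; everything not matched above lands here.
parser.add_argument("opts", default=None, nargs=argparse.REMAINDER)

args = parser.parse_args(["--logger", "wandb", "wandb-project", "my-yolox-runs"])
print(args.logger)  # "wandb"
print(args.opts)    # ["wandb-project", "my-yolox-runs"]
```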
yolox/core/trainer.py CHANGED

@@ -15,6 +15,7 @@ from yolox.data import DataPrefetcher
 from yolox.utils import (
     MeterBuffer,
     ModelEMA,
+    WandbLogger,
     all_reduce_norm,
     get_local_rank,
     get_model_info,
@@ -173,9 +174,18 @@ class Trainer:
         self.evaluator = self.exp.get_evaluator(
             batch_size=self.args.batch_size, is_distributed=self.is_distributed
         )
-        # Tensorboard
+        # Tensorboard and Wandb loggers
         if self.rank == 0:
-            self.tblogger = SummaryWriter(os.path.join(self.file_name, "tensorboard"))
+            if self.args.logger == "tensorboard":
+                self.tblogger = SummaryWriter(os.path.join(self.file_name, "tensorboard"))
+            elif self.args.logger == "wandb":
+                wandb_params = dict()
+                for k, v in zip(self.args.opts[0::2], self.args.opts[1::2]):
+                    if k.startswith("wandb-"):
+                        wandb_params.update({k.lstrip("wandb-"): v})
+                self.wandb_logger = WandbLogger(config=vars(self.exp), **wandb_params)
+            else:
+                raise ValueError("logger must be either 'tensorboard' or 'wandb'")
 
         logger.info("Training start...")
         logger.info("\n{}".format(model))
@@ -184,6 +194,9 @@ class Trainer:
         logger.info(
             "Training of experiment is done and the best AP is {:.2f}".format(self.best_ap * 100)
         )
+        if self.rank == 0:
+            if self.args.logger == "wandb":
+                self.wandb_logger.finish()
 
     def before_epoch(self):
         logger.info("---> start train epoch{}".format(self.epoch + 1))
@@ -246,6 +259,12 @@ class Trainer:
                 )
                 + (", size: {:d}, {}".format(self.input_size[0], eta_str))
             )
+
+            if self.rank == 0:
+                if self.args.logger == "wandb":
+                    self.wandb_logger.log_metrics({k: v.latest for k, v in loss_meter.items()})
+                    self.wandb_logger.log_metrics({"lr": self.meter["lr"].latest})
+
             self.meter.clear_meters()
 
         # random resizing
@@ -309,8 +328,15 @@ class Trainer:
 
         self.model.train()
         if self.rank == 0:
-            self.tblogger.add_scalar("val/COCOAP50", ap50, self.epoch + 1)
-            self.tblogger.add_scalar("val/COCOAP50_95", ap50_95, self.epoch + 1)
+            if self.args.logger == "tensorboard":
+                self.tblogger.add_scalar("val/COCOAP50", ap50, self.epoch + 1)
+                self.tblogger.add_scalar("val/COCOAP50_95", ap50_95, self.epoch + 1)
+            if self.args.logger == "wandb":
+                self.wandb_logger.log_metrics({
+                    "val/COCOAP50": ap50,
+                    "val/COCOAP50_95": ap50_95,
+                    "epoch": self.epoch + 1,
+                })
             logger.info("\n" + summary)
         synchronize()
 
@@ -334,3 +360,6 @@ class Trainer:
             self.file_name,
             ckpt_name,
         )
+
+        if self.args.logger == "wandb":
+            self.wandb_logger.save_checkpoint(self.file_name, ckpt_name, update_best_ckpt)
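The `wandb_params` extraction in the second hunk pairs up consecutive tokens from `args.opts` and keeps only the `wandb-` prefixed keys. A standalone sketch of that loop (example values are hypothetical):

```python
# Leftover command-line tokens, as argparse would leave them in args.opts.
opts = ["wandb-project", "my-yolox-runs", "wandb-entity", "my-team", "max_epoch", "300"]

wandb_params = dict()
# opts[0::2] are the keys, opts[1::2] the values; zip pairs them up.
for k, v in zip(opts[0::2], opts[1::2]):
    if k.startswith("wandb-"):
        # Caveat: str.lstrip strips *characters* from the set {w, a, n, d, b, -},
        # not a literal prefix. "wandb-project" -> "project" and
        # "wandb-entity" -> "entity" work, but "wandb-name" would become "me".
        wandb_params.update({k.lstrip("wandb-"): v})

print(wandb_params)  # {'project': 'my-yolox-runs', 'entity': 'my-team'}
```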
yolox/utils/__init__.py CHANGED

@@ -8,7 +8,7 @@ from .checkpoint import load_ckpt, save_checkpoint
 from .demo_utils import *
 from .dist import *
 from .ema import *
-from .logger import setup_logger
+from .logger import WandbLogger, setup_logger
 from .lr_scheduler import LRScheduler
 from .metric import *
 from .model_utils import *
yolox/utils/logger.py CHANGED

@@ -7,11 +7,14 @@ import os
 import sys
 from loguru import logger
 
+import torch
+
 
 def get_caller_name(depth=0):
     """
     Args:
-        depth (int): Depth of caller conext, use 0 for caller depth. Default value: 0.
+        depth (int): Depth of caller conext, use 0 for caller depth.
+            Default value: 0.
 
     Returns:
         str: module name of the caller
@@ -93,3 +96,122 @@ def setup_logger(save_dir, distributed_rank=0, filename="log.txt", mode="a"):
 
     # redirect stdout/stderr to loguru
     redirect_sys_output("INFO")
+
+
+class WandbLogger(object):
+    """
+    Log training runs, datasets, models, and predictions to Weights & Biases.
+    This logger sends information to W&B at wandb.ai.
+    By default, this information includes hyperparameters,
+    system configuration and metrics, model metrics,
+    and basic data metrics and analyses.
+
+    For more information, please refer to:
+    https://docs.wandb.ai/guides/track
+    """
+    def __init__(self,
+                 project=None,
+                 name=None,
+                 id=None,
+                 entity=None,
+                 save_dir=None,
+                 config=None,
+                 **kwargs):
+        """
+        Args:
+            project (str): wandb project name.
+            name (str): wandb run name.
+            id (str): wandb run id.
+            entity (str): wandb entity name.
+            save_dir (str): save directory.
+            config (dict): config dict.
+            **kwargs: other kwargs.
+        """
+        try:
+            import wandb
+            self.wandb = wandb
+        except ModuleNotFoundError:
+            raise ModuleNotFoundError(
+                "wandb is not installed."
+                "Please install wandb using pip install wandb"
+            )
+
+        self.project = project
+        self.name = name
+        self.id = id
+        self.save_dir = save_dir
+        self.config = config
+        self.kwargs = kwargs
+        self.entity = entity
+        self._run = None
+        self._wandb_init = dict(
+            project=self.project,
+            name=self.name,
+            id=self.id,
+            entity=self.entity,
+            dir=self.save_dir,
+            resume="allow"
+        )
+        self._wandb_init.update(**kwargs)
+
+        _ = self.run
+
+        if self.config:
+            self.run.config.update(self.config)
+        self.run.define_metric("epoch")
+        self.run.define_metric("val/", step_metric="epoch")
+
+    @property
+    def run(self):
+        if self._run is None:
+            if self.wandb.run is not None:
+                logger.info(
+                    "There is a wandb run already in progress "
+                    "and newly created instances of `WandbLogger` will reuse"
+                    " this run. If this is not desired, call `wandb.finish()`"
+                    "before instantiating `WandbLogger`."
+                )
+                self._run = self.wandb.run
+            else:
+                self._run = self.wandb.init(**self._wandb_init)
+        return self._run
+
+    def log_metrics(self, metrics, step=None):
+        """
+        Args:
+            metrics (dict): metrics dict.
+            step (int): step number.
+        """
+
+        for k, v in metrics.items():
+            if isinstance(v, torch.Tensor):
+                metrics[k] = v.item()
+
+        if step is not None:
+            self.run.log(metrics, step=step)
+        else:
+            self.run.log(metrics)
+
+    def save_checkpoint(self, save_dir, model_name, is_best):
+        """
+        Args:
+            save_dir (str): save directory.
+            model_name (str): model name.
+            is_best (bool): whether the model is the best model.
+        """
+        filename = os.path.join(save_dir, model_name + "_ckpt.pth")
+        artifact = self.wandb.Artifact(
+            name=f"model-{self.run.id}",
+            type="model"
+        )
+        artifact.add_file(filename, name="model_ckpt.pth")
+
+        aliases = ["latest"]
+
+        if is_best:
+            aliases.append("best")
+
+        self.run.log_artifact(artifact, aliases=aliases)
+
+    def finish(self):
+        self.run.finish()
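Putting the new class together, a hypothetical standalone use of `WandbLogger` outside the trainer (assumes `wandb` is installed and `wandb login` has been run; names and values below are placeholders, not from the commit):

```python
from yolox.utils import WandbLogger

# Creates (or reuses) a wandb run; config is stored as run hyperparameters.
wandb_logger = WandbLogger(
    project="my-yolox-runs",
    entity="my-team",
    config={"max_epoch": 300, "input_size": (640, 640)},
)

# Scalars go through log_metrics; torch.Tensor values are converted
# with .item() before being sent to wandb.
wandb_logger.log_metrics({"total_loss": 7.3, "lr": 0.01})
wandb_logger.log_metrics({"val/COCOAP50": 0.42, "epoch": 1})

# save_checkpoint uploads <save_dir>/<model_name>_ckpt.pth as a model
# artifact aliased "latest" (plus "best" when is_best=True); commented out
# here because the checkpoint file must already exist on disk.
# wandb_logger.save_checkpoint("./YOLOX_outputs/yolox_s", "best", True)

wandb_logger.finish()  # close the underlying wandb run
```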