|
|
|
|
|
|
|
"""
|
|
This file contains components with some default boilerplate logic user may need
|
|
in training / testing. They will not work for everyone, but many users may find them useful.
|
|
|
|
The behavior of functions/classes in this file is subject to change,
|
|
since they are meant to represent the "common default behavior" people need in their projects.
|
|
"""
|
|
|
|
import argparse
|
|
import logging
|
|
import os
|
|
import sys
|
|
import weakref
|
|
from collections import OrderedDict
|
|
from typing import Optional
|
|
import torch
|
|
from fvcore.nn.precise_bn import get_bn_modules
|
|
from omegaconf import OmegaConf
|
|
from torch.nn.parallel import DistributedDataParallel
|
|
|
|
import detectron2.data.transforms as T
|
|
from detectron2.checkpoint import DetectionCheckpointer
|
|
from detectron2.config import CfgNode, LazyConfig
|
|
from detectron2.data import (
|
|
MetadataCatalog,
|
|
build_detection_test_loader,
|
|
build_detection_train_loader,
|
|
)
|
|
from detectron2.evaluation import (
|
|
DatasetEvaluator,
|
|
inference_on_dataset,
|
|
print_csv_format,
|
|
verify_results,
|
|
)
|
|
from detectron2.modeling import build_model
|
|
from detectron2.solver import build_lr_scheduler, build_optimizer
|
|
from detectron2.utils import comm
|
|
from detectron2.utils.collect_env import collect_env_info
|
|
from detectron2.utils.env import seed_all_rng
|
|
from detectron2.utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter
|
|
from detectron2.utils.file_io import PathManager
|
|
from detectron2.utils.logger import setup_logger
|
|
|
|
from . import hooks
|
|
from .train_loop import AMPTrainer, SimpleTrainer, TrainerBase
|
|
|
|
# Public names exported by a star-import of this module.
__all__ = [
    # helper functions
    "create_ddp_model",
    "default_argument_parser",
    "default_setup",
    "default_writers",
    # classes
    "DefaultPredictor",
    "DefaultTrainer",
]
|
|
|
|
|
|
def create_ddp_model(model, *, fp16_compression=False, **kwargs):
    """
    Wrap ``model`` in :class:`DistributedDataParallel` when more than one process is running.

    Args:
        model: a torch.nn.Module
        fp16_compression: if True, register PyTorch's fp16 compression communication hook on
            the DDP wrapper. See more at
            https://pytorch.org/docs/stable/ddp_comm_hooks.html#torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook
        kwargs: other arguments of :module:`torch.nn.parallel.DistributedDataParallel`.
            If ``device_ids`` is not given, it defaults to ``[comm.get_local_rank()]``.

    Returns:
        ``model`` itself when the world size is 1, otherwise the DDP-wrapped model.
    """
    # A single-process run needs no wrapper at all.
    if comm.get_world_size() == 1:
        return model

    # Default to the device indexed by this process's local rank unless the caller chose one.
    if "device_ids" not in kwargs:
        kwargs["device_ids"] = [comm.get_local_rank()]
    wrapped = DistributedDataParallel(model, **kwargs)

    if not fp16_compression:
        return wrapped

    # Imported lazily: only needed when compression is requested.
    from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks

    wrapped.register_comm_hook(state=None, hook=comm_hooks.fp16_compress_hook)
    return wrapped
|
|
|
|
|
|
def default_argument_parser(epilog=None):
    """
    Build the command-line parser with the options commonly used by detectron2 scripts.

    Args:
        epilog (str): epilog passed to ArgumentParser describing the usage. When falsy,
            a built-in block of usage examples (mentioning ``sys.argv[0]``) is used.

    Returns:
        argparse.ArgumentParser:
    """
    if not epilog:
        # Built only when needed, so sys.argv[0] is read only in this case.
        epilog = f"""
Examples:

Run on single machine:
    $ {sys.argv[0]} --num-gpus 8 --config-file cfg.yaml

Change some config options:
    $ {sys.argv[0]} --config-file cfg.yaml MODEL.WEIGHTS /path/to/weight.pth SOLVER.BASE_LR 0.001

Run on multiple machines:
    (machine0)$ {sys.argv[0]} --machine-rank 0 --num-machines 2 --dist-url <URL> [--other-flags]
    (machine1)$ {sys.argv[0]} --machine-rank 1 --num-machines 2 --dist-url <URL> [--other-flags]
"""
    parser = argparse.ArgumentParser(
        epilog=epilog, formatter_class=argparse.RawDescriptionHelpFormatter
    )

    parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
    parser.add_argument(
        "--resume",
        action="store_true",
        help="Whether to attempt to resume from the checkpoint directory. "
        "See documentation of `DefaultTrainer.resume_or_load()` for what it means.",
    )
    parser.add_argument("--eval-only", action="store_true", help="perform evaluation only")

    # Integer options describing the (multi-)machine layout.
    int_options = (
        ("--num-gpus", 1, "number of gpus *per machine*"),
        ("--num-machines", 1, "total number of machines"),
        ("--machine-rank", 0, "the rank of this machine (unique per machine)"),
    )
    for flag, default_value, description in int_options:
        parser.add_argument(flag, type=int, default=default_value, help=description)

    # Deterministic port in [2**15 + 2**14, 2**16), derived from the user id
    # (a constant on Windows), so the same user always gets the same default port.
    uid_seed = os.getuid() if sys.platform != "win32" else 1
    port = 2**15 + 2**14 + hash(uid_seed) % 2**14
    parser.add_argument(
        "--dist-url",
        default=f"tcp://127.0.0.1:{port}",
        help="initialization URL for pytorch distributed backend. See "
        "https://pytorch.org/docs/stable/distributed.html for details.",
    )

    opts_help = """
Modify config options at the end of the command. For Yacs configs, use
space-separated "PATH.KEY VALUE" pairs.
For python-based LazyConfig, use "path.key=value".
""".strip()
    # REMAINDER: every token after the recognized flags is collected as config overrides.
    parser.add_argument("opts", help=opts_help, default=None, nargs=argparse.REMAINDER)
    return parser
|
|
|
|
|
|
def _try_get_key(cfg, *keys, default=None):
    """
    Return the value of the first key in ``keys`` that exists in ``cfg``.

    Keys are dotted paths resolved with ``OmegaConf.select``. A yacs ``CfgNode`` is first
    converted into an OmegaConf object (via its YAML dump) so both config kinds are handled
    uniformly. Returns ``default`` if none of the keys exist.
    """
    if isinstance(cfg, CfgNode):
        cfg = OmegaConf.create(cfg.dump())
    # Private sentinel, so that a key explicitly set to None still counts as present.
    missing = object()
    for key in keys:
        value = OmegaConf.select(cfg, key, default=missing)
        if value is not missing:
            return value
    return default
|
|
|
|
|
|
def _highlight(code, filename):
|
|
try:
|
|
import pygments
|
|
except ImportError:
|
|
return code
|
|
|
|
from pygments.lexers import Python3Lexer, YamlLexer
|
|
from pygments.formatters import Terminal256Formatter
|
|
|
|
lexer = Python3Lexer() if filename.endswith(".py") else YamlLexer()
|
|
code = pygments.highlight(code, lexer, Terminal256Formatter(style="monokai"))
|
|
return code
|
|
|
|
|
|
def default_setup(cfg, args):
    """
    Perform some basic common setups at the beginning of a job, including:

    1. Set up the detectron2 logger
    2. Log basic information about environment, cmdline arguments, and config
    3. Backup the config to the output directory

    It also seeds the RNGs (``seed + rank`` when a non-negative seed is configured) and,
    unless ``args.eval_only`` is set, applies the configured cudnn benchmark flag.

    Args:
        cfg (CfgNode or omegaconf.DictConfig): the full config to be used
        args (argparse.NameSpace): the command line arguments to be logged
    """
    output_dir = _try_get_key(cfg, "OUTPUT_DIR", "output_dir", "train.output_dir")
    if comm.is_main_process() and output_dir:
        PathManager.mkdirs(output_dir)

    rank = comm.get_rank()
    setup_logger(output_dir, distributed_rank=rank, name="fvcore")
    logger = setup_logger(output_dir, distributed_rank=rank)

    logger.info("Rank of current process: {}. World size: {}".format(rank, comm.get_world_size()))
    logger.info("Environment info:\n" + collect_env_info())

    logger.info("Command line arguments: " + str(args))
    # Treat a missing, empty or None config_file as "no config file given".
    config_file = getattr(args, "config_file", "")
    if config_file:
        # Close the handle deterministically instead of leaking it until garbage collection.
        with PathManager.open(config_file, "r") as f:
            config_text = f.read()
        logger.info(
            "Contents of args.config_file={}:\n{}".format(
                config_file, _highlight(config_text, config_file)
            )
        )

    if comm.is_main_process() and output_dir:
        # Back up the full config next to the job outputs as config.yaml.
        path = os.path.join(output_dir, "config.yaml")
        if isinstance(cfg, CfgNode):
            logger.info("Running with full config:\n{}".format(_highlight(cfg.dump(), ".yaml")))
            with PathManager.open(path, "w") as f:
                f.write(cfg.dump())
        else:
            LazyConfig.save(cfg, path)
        logger.info("Full config saved to {}".format(path))

    # Each rank gets a different but deterministic seed when a seed is configured;
    # a negative (default) seed passes None to seed_all_rng.
    seed = _try_get_key(cfg, "SEED", "train.seed", default=-1)
    seed_all_rng(None if seed < 0 else seed + rank)

    # cudnn benchmark is only configured for training runs, not for --eval-only.
    if not getattr(args, "eval_only", False):
        torch.backends.cudnn.benchmark = _try_get_key(
            cfg, "CUDNN_BENCHMARK", "train.cudnn_benchmark", default=False
        )
|
|
|
|
|
|
def default_writers(output_dir: str, max_iter: Optional[int] = None):
    """
    Create the default list of :class:`EventWriter` objects.

    The list contains, in order, a :class:`CommonMetricPrinter`, a :class:`JSONWriter`
    writing ``<output_dir>/metrics.json``, and a :class:`TensorboardXWriter`.

    Args:
        output_dir: directory to store JSON metrics and tensorboard events; it is
            created if missing.
        max_iter: the total number of iterations

    Returns:
        list[EventWriter]: a list of :class:`EventWriter` objects.
    """
    PathManager.mkdirs(output_dir)
    metrics_file = os.path.join(output_dir, "metrics.json")
    # CommonMetricPrinter only prints "common" metrics, not everything that is logged.
    writers = [CommonMetricPrinter(max_iter)]
    writers.append(JSONWriter(metrics_file))
    writers.append(TensorboardXWriter(output_dir))
    return writers
|
|
|
|
|
|
class DefaultPredictor:
    """
    Create a simple end-to-end predictor with the given config that runs on
    single device for a single input image.

    Compared to using the model directly, this class does the following additions:

    1. Load checkpoint from `cfg.MODEL.WEIGHTS`.
    2. Always take BGR image as the input and apply conversion defined by `cfg.INPUT.FORMAT`.
    3. Apply resizing defined by `cfg.INPUT.{MIN,MAX}_SIZE_TEST`.
    4. Take one input image and produce a single output, instead of a batch.

    This is meant for simple demo purposes, so it does the above steps automatically.
    This is not meant for benchmarks or running complicated inference logic.
    If you'd like to do anything more complicated, please refer to its source code as
    examples to build and use the model manually.

    Attributes:
        metadata (Metadata): the metadata of the underlying dataset, obtained from
            cfg.DATASETS.TEST. Only set when cfg.DATASETS.TEST is non-empty.

    Examples:
    ::
        pred = DefaultPredictor(cfg)
        inputs = cv2.imread("input.jpg")
        outputs = pred(inputs)
    """

    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode): config used to build the model, load ``MODEL.WEIGHTS`` and
                configure test-time resizing. It is cloned, so later changes by the
                caller do not affect this predictor.
        """
        self.cfg = cfg.clone()  # cfg can be modified by model
        self.model = build_model(self.cfg)
        self.model.eval()
        if len(cfg.DATASETS.TEST):
            self.metadata = MetadataCatalog.get(cfg.DATASETS.TEST[0])

        checkpointer = DetectionCheckpointer(self.model)
        checkpointer.load(cfg.MODEL.WEIGHTS)

        # Resize the shortest edge to exactly MIN_SIZE_TEST, capping the longest at MAX_SIZE_TEST.
        self.aug = T.ResizeShortestEdge(
            [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
        )

        self.input_format = cfg.INPUT.FORMAT
        assert self.input_format in ["RGB", "BGR"], self.input_format

    def __call__(self, original_image):
        """
        Args:
            original_image (np.ndarray): an image of shape (H, W, C) (in BGR order).

        Returns:
            predictions (dict):
                the output of the model for one image only.
                See :doc:`/tutorials/models` for details about the format.
        """
        with torch.no_grad():
            # Inputs are always BGR; flip channels if the model expects RGB.
            if self.input_format == "RGB":
                original_image = original_image[:, :, ::-1]
            # Original size is passed so the model can rescale outputs back to it.
            height, width = original_image.shape[:2]
            image = self.aug.get_transform(original_image).apply_image(original_image)
            # HWC -> CHW float tensor.
            image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
            # Tensor.to is not in-place: keep the returned tensor, or the move is lost.
            image = image.to(self.cfg.MODEL.DEVICE)

            inputs = {"image": image, "height": height, "width": width}

            predictions = self.model([inputs])[0]
            return predictions
|
|
|
|
|
|
class DefaultTrainer(TrainerBase):
    """
    A trainer with default training logic. It does the following:

    1. Create a :class:`SimpleTrainer` using model, optimizer, dataloader
       defined by the given config. Create a LR scheduler defined by the config.
    2. Load the last checkpoint or `cfg.MODEL.WEIGHTS`, if exists, when
       `resume_or_load` is called.
    3. Register a few common hooks defined by the config.

    It is created to simplify the **standard model training workflow** and reduce code boilerplate
    for users who only need the standard training workflow, with standard features.
    It means this class makes *many assumptions* about your training logic that
    may easily become invalid in a new research. In fact, any assumptions beyond those made in the
    :class:`SimpleTrainer` are too much for research.

    The code of this class has been annotated about restrictive assumptions it makes.
    When they do not work for you, you're encouraged to:

    1. Overwrite methods of this class, OR:
    2. Use :class:`SimpleTrainer`, which only does minimal SGD training and
       nothing else. You can then add your own hooks if needed. OR:
    3. Write your own training loop similar to `tools/plain_train_net.py`.

    See the :doc:`/tutorials/training` tutorials for more details.

    Note that the behavior of this class, like other functions/classes in
    this file, is not stable, since it is meant to represent the "common default behavior".
    It is only guaranteed to work well with the standard models and training workflow in detectron2.
    To obtain more stable behavior, write your own training logic with other public APIs.

    Examples:
    ::
        trainer = DefaultTrainer(cfg)
        trainer.resume_or_load()  # load last checkpoint or MODEL.WEIGHTS
        trainer.train()

    Attributes:
        scheduler: the LR scheduler returned by :meth:`build_lr_scheduler`.
        checkpointer (DetectionCheckpointer): checkpointer writing to ``cfg.OUTPUT_DIR``.
        cfg (CfgNode): the config, after :meth:`auto_scale_workers` has been applied.
    """

    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode):
        """
        super().__init__()
        logger = logging.getLogger("detectron2")
        if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called for d2
            setup_logger()
        cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size())

        # Assume these objects must be constructed in this order.
        model = self.build_model(cfg)
        optimizer = self.build_optimizer(cfg, model)
        data_loader = self.build_train_loader(cfg)

        model = create_ddp_model(model, broadcast_buffers=False)
        self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
            model, data_loader, optimizer
        )

        self.scheduler = self.build_lr_scheduler(cfg, optimizer)
        self.checkpointer = DetectionCheckpointer(
            # Checkpoints are saved together with logs/statistics in OUTPUT_DIR.
            model,
            cfg.OUTPUT_DIR,
            # A weak proxy avoids a reference cycle between trainer and checkpointer.
            trainer=weakref.proxy(self),
        )
        self.start_iter = 0
        self.max_iter = cfg.SOLVER.MAX_ITER
        self.cfg = cfg

        self.register_hooks(self.build_hooks())

    def resume_or_load(self, resume=True):
        """
        If `resume==True` and `cfg.OUTPUT_DIR` contains the last checkpoint (defined by
        a `last_checkpoint` file), resume from the file. Resuming means loading all
        available states (eg. optimizer and scheduler) and update iteration counter
        from the checkpoint. ``cfg.MODEL.WEIGHTS`` will not be used.

        Otherwise, this is considered as an independent training. The method will load model
        weights from the file `cfg.MODEL.WEIGHTS` (but will not load other states) and start
        from iteration 0.

        Args:
            resume (bool): whether to do resume or not
        """
        self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume)
        if resume and self.checkpointer.has_checkpoint():
            # self.iter is restored from the checkpoint (the iteration that had just
            # finished), so training continues at the next one.
            self.start_iter = self.iter + 1

    def build_hooks(self):
        """
        Build a list of default hooks, including timing, evaluation,
        checkpointing, lr scheduling, precise BN, writing events.

        The list may contain ``None`` in place of the PreciseBN hook when it is disabled
        or the model has no BN modules.

        Returns:
            list[HookBase]:
        """
        cfg = self.cfg.clone()
        cfg.defrost()
        cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN

        ret = [
            hooks.IterationTimer(),
            hooks.LRScheduler(),
            hooks.PreciseBN(
                # Run at the same freq as (but before) evaluation.
                cfg.TEST.EVAL_PERIOD,
                self.model,
                # Build a new data loader to not affect training
                self.build_train_loader(cfg),
                cfg.TEST.PRECISE_BN.NUM_ITER,
            )
            if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
            else None,
        ]

        # PreciseBN runs before the checkpointer because it updates the model, and the
        # updated statistics should be saved. If checkpointing has a different frequency,
        # some checkpoints may have more precise statistics than others.
        if comm.is_main_process():
            ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))

        def test_and_save_results():
            self._last_eval_results = self.test(self.cfg, self.model)
            return self._last_eval_results

        # Evaluation runs after the checkpointer, so if it fails the saved checkpoint
        # can be used to debug.
        ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))

        if comm.is_main_process():
            # Writers run last so that evaluation metrics are written too.
            ret.append(hooks.PeriodicWriter(self.build_writers(), period=20))
        return ret

    def build_writers(self):
        """
        Build a list of writers to be used using :func:`default_writers()`.
        If you'd like a different list of writers, you can overwrite it in
        your trainer.

        Returns:
            list[EventWriter]: a list of :class:`EventWriter` objects.
        """
        return default_writers(self.cfg.OUTPUT_DIR, self.max_iter)

    def train(self):
        """
        Run training from ``self.start_iter`` to ``self.max_iter``.

        Returns:
            The last evaluation results, only when ``cfg.TEST.EXPECTED_RESULTS`` is non-empty
            and this is the main process (the results are then also verified against the
            expectations). Otherwise None.
        """
        super().train(self.start_iter, self.max_iter)
        if len(self.cfg.TEST.EXPECTED_RESULTS) and comm.is_main_process():
            assert hasattr(
                self, "_last_eval_results"
            ), "No evaluation results obtained during training!"
            verify_results(self.cfg, self._last_eval_results)
            return self._last_eval_results

    def run_step(self):
        # Keep the wrapped trainer's iteration counter in sync before delegating.
        self._trainer.iter = self.iter
        self._trainer.run_step()

    def state_dict(self):
        # Include the wrapped trainer's state under the "_trainer" key.
        ret = super().state_dict()
        ret["_trainer"] = self._trainer.state_dict()
        return ret

    def load_state_dict(self, state_dict):
        super().load_state_dict(state_dict)
        self._trainer.load_state_dict(state_dict["_trainer"])

    @classmethod
    def build_model(cls, cfg):
        """
        Returns:
            torch.nn.Module:

        It now calls :func:`detectron2.modeling.build_model`.
        Overwrite it if you'd like a different model.
        """
        model = build_model(cfg)
        logger = logging.getLogger(__name__)
        logger.info("Model:\n{}".format(model))
        return model

    @classmethod
    def build_optimizer(cls, cfg, model):
        """
        Returns:
            torch.optim.Optimizer:

        It now calls :func:`detectron2.solver.build_optimizer`.
        Overwrite it if you'd like a different optimizer.
        """
        return build_optimizer(cfg, model)

    @classmethod
    def build_lr_scheduler(cls, cfg, optimizer):
        """
        It now calls :func:`detectron2.solver.build_lr_scheduler`.
        Overwrite it if you'd like a different scheduler.
        """
        return build_lr_scheduler(cfg, optimizer)

    @classmethod
    def build_train_loader(cls, cfg):
        """
        Returns:
            iterable

        It now calls :func:`detectron2.data.build_detection_train_loader`.
        Overwrite it if you'd like a different data loader.
        """
        return build_detection_train_loader(cfg)

    @classmethod
    def build_test_loader(cls, cfg, dataset_name):
        """
        Returns:
            iterable

        It now calls :func:`detectron2.data.build_detection_test_loader`.
        Overwrite it if you'd like a different data loader.
        """
        return build_detection_test_loader(cfg, dataset_name)

    @classmethod
    def build_evaluator(cls, cfg, dataset_name):
        """
        Returns:
            DatasetEvaluator or None

        It is not implemented by default.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError(
            """
If you want DefaultTrainer to automatically run evaluation,
please implement `build_evaluator()` in subclasses (see train_net.py for example).
Alternatively, you can call evaluation functions yourself (see Colab balloon tutorial for example).
"""
        )

    @classmethod
    def test(cls, cfg, model, evaluators=None):
        """
        Evaluate the given model. The given model is expected to already contain
        weights to evaluate.

        Args:
            cfg (CfgNode):
            model (nn.Module):
            evaluators (DatasetEvaluator, list[DatasetEvaluator] or None): if None, will call
                :meth:`build_evaluator`. A single evaluator is treated as a one-element list.
                Otherwise, must have the same length as ``cfg.DATASETS.TEST``.

        Returns:
            dict: a dict of result metrics, keyed by dataset name. When there is exactly one
            test dataset, that dataset's result dict is returned directly.
        """
        logger = logging.getLogger(__name__)
        if isinstance(evaluators, DatasetEvaluator):
            evaluators = [evaluators]
        if evaluators is not None:
            assert len(cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format(
                len(cfg.DATASETS.TEST), len(evaluators)
            )

        results = OrderedDict()
        for idx, dataset_name in enumerate(cfg.DATASETS.TEST):
            data_loader = cls.build_test_loader(cfg, dataset_name)
            # When evaluators are passed in as arguments,
            # implicitly assume that evaluators can be created before data_loader.
            if evaluators is not None:
                evaluator = evaluators[idx]
            else:
                try:
                    evaluator = cls.build_evaluator(cfg, dataset_name)
                except NotImplementedError:
                    # Logger.warn is a deprecated alias of Logger.warning.
                    logger.warning(
                        "No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
                        "or implement its `build_evaluator` method."
                    )
                    results[dataset_name] = {}
                    continue
            results_i = inference_on_dataset(model, data_loader, evaluator)
            results[dataset_name] = results_i
            if comm.is_main_process():
                assert isinstance(
                    results_i, dict
                ), "Evaluator must return a dict on the main process. Got {} instead.".format(
                    results_i
                )
                logger.info("Evaluation results for {} in csv format:".format(dataset_name))
                print_csv_format(results_i)

        if len(results) == 1:
            results = list(results.values())[0]
        return results

    @staticmethod
    def auto_scale_workers(cfg, num_workers: int):
        """
        When the config is defined for certain number of workers (according to
        ``cfg.SOLVER.REFERENCE_WORLD_SIZE``) that's different from the number of
        workers currently in use, returns a new cfg where the total batch size
        is scaled so that the per-GPU batch size stays the same as the
        original ``IMS_PER_BATCH // REFERENCE_WORLD_SIZE``.

        Other config options are also scaled accordingly:
        * training steps and warmup steps are scaled inverse proportionally.
        * learning rate are scaled proportionally, following :paper:`ImageNet in 1h`.

        For example, with the original config like the following:

        .. code-block:: yaml

            IMS_PER_BATCH: 16
            BASE_LR: 0.1
            REFERENCE_WORLD_SIZE: 8
            MAX_ITER: 5000
            STEPS: (4000,)
            CHECKPOINT_PERIOD: 1000

        When this config is used on 16 GPUs instead of the reference number 8,
        calling this method will return a new config with:

        .. code-block:: yaml

            IMS_PER_BATCH: 32
            BASE_LR: 0.2
            REFERENCE_WORLD_SIZE: 16
            MAX_ITER: 2500
            STEPS: (2000,)
            CHECKPOINT_PERIOD: 500

        Note that both the original config and this new config can be trained on 16 GPUs.
        It's up to user whether to enable this feature (by setting ``REFERENCE_WORLD_SIZE``).

        Returns:
            CfgNode: a new config. Same as original if ``cfg.SOLVER.REFERENCE_WORLD_SIZE==0``.
        """
        old_world_size = cfg.SOLVER.REFERENCE_WORLD_SIZE
        if old_world_size == 0 or old_world_size == num_workers:
            return cfg
        cfg = cfg.clone()
        frozen = cfg.is_frozen()
        cfg.defrost()

        assert (
            cfg.SOLVER.IMS_PER_BATCH % old_world_size == 0
        ), "Invalid REFERENCE_WORLD_SIZE in config!"
        scale = num_workers / old_world_size
        bs = cfg.SOLVER.IMS_PER_BATCH = int(round(cfg.SOLVER.IMS_PER_BATCH * scale))
        lr = cfg.SOLVER.BASE_LR = cfg.SOLVER.BASE_LR * scale
        max_iter = cfg.SOLVER.MAX_ITER = int(round(cfg.SOLVER.MAX_ITER / scale))
        warmup_iter = cfg.SOLVER.WARMUP_ITERS = int(round(cfg.SOLVER.WARMUP_ITERS / scale))
        cfg.SOLVER.STEPS = tuple(int(round(s / scale)) for s in cfg.SOLVER.STEPS)
        cfg.TEST.EVAL_PERIOD = int(round(cfg.TEST.EVAL_PERIOD / scale))
        cfg.SOLVER.CHECKPOINT_PERIOD = int(round(cfg.SOLVER.CHECKPOINT_PERIOD / scale))
        # Record the new reference size so that scaling the result again is a no-op.
        cfg.SOLVER.REFERENCE_WORLD_SIZE = num_workers
        logger = logging.getLogger(__name__)
        logger.info(
            f"Auto-scaling the config to batch_size={bs}, learning_rate={lr}, "
            f"max_iter={max_iter}, warmup={warmup_iter}."
        )

        if frozen:
            cfg.freeze()
        return cfg
|
|
|
|
|
|
|
|
def _forward_to_inner_trainer(attr_name):
    """Build a read/write property that proxies ``self._trainer.<attr_name>``."""

    def _get(self):
        return getattr(self._trainer, attr_name)

    def _set(self, value):
        setattr(self._trainer, attr_name, value)

    return property(_get, _set)


# Access basic attributes from the underlying trainer directly on DefaultTrainer.
for _attr in ["model", "data_loader", "optimizer"]:
    setattr(DefaultTrainer, _attr, _forward_to_inner_trainer(_attr))
|
|
|