Spaces:
Paused
Paused
# Copyright (c) Facebook, Inc. and its affiliates. | |
import copy | |
import itertools | |
import logging | |
from collections import defaultdict | |
from enum import Enum | |
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Type, Union | |
import torch | |
from fvcore.common.param_scheduler import ( | |
CosineParamScheduler, | |
MultiStepParamScheduler, | |
StepWithFixedGammaParamScheduler, | |
) | |
from detectron2.config import CfgNode | |
from detectron2.utils.env import TORCH_VERSION | |
from .lr_scheduler import LRMultiplier, LRScheduler, WarmupParamScheduler | |
# Input handed to a gradient clipper: one tensor, or an iterable of tensors.
_GradientClipperInput = Union[
    torch.Tensor,
    Iterable[torch.Tensor],
]
# A gradient clipper receives such an input and returns nothing; it is called
# purely for its effect on the gradients.
_GradientClipper = Callable[
    [_GradientClipperInput],
    None,
]
class GradientClipType(Enum):
    """
    Supported gradient clipping modes.

    Members are looked up by value, e.g. ``GradientClipType("norm")``.
    """

    # Selects the clipper built on ``torch.nn.utils.clip_grad_value_``.
    VALUE = "value"
    # Selects the clipper built on ``torch.nn.utils.clip_grad_norm_``.
    NORM = "norm"
def _create_gradient_clipper(cfg: CfgNode) -> _GradientClipper:
    """
    Build the gradient clipping callable selected by ``cfg.CLIP_TYPE``.

    Args:
        cfg: config node providing ``CLIP_TYPE`` and ``CLIP_VALUE`` (and
            ``NORM_TYPE``, read when clipping by norm).

    Returns:
        A closure that clips the gradients of the tensor(s) it is given,
        either by value or by norm.

    Raises:
        ValueError: if ``cfg.CLIP_TYPE`` is not a valid ``GradientClipType`` value.
    """
    # Work on a private copy so later edits to the caller's config cannot
    # change what the returned closure does.
    settings = copy.deepcopy(cfg)

    def _clip_by_value(tensors: _GradientClipperInput):
        torch.nn.utils.clip_grad_value_(tensors, settings.CLIP_VALUE)

    def _clip_by_norm(tensors: _GradientClipperInput):
        torch.nn.utils.clip_grad_norm_(tensors, settings.CLIP_VALUE, settings.NORM_TYPE)

    selected = GradientClipType(settings.CLIP_TYPE)
    if selected is GradientClipType.VALUE:
        return _clip_by_value
    return _clip_by_norm
def _generate_optimizer_class_with_gradient_clipping(
    optimizer: Type[torch.optim.Optimizer],
    *,
    per_param_clipper: Optional["_GradientClipper"] = None,
    global_clipper: Optional["_GradientClipper"] = None,
) -> Type[torch.optim.Optimizer]:
    """
    Dynamically create a subclass of ``optimizer`` whose ``step`` method clips
    gradients before delegating to the original ``step``.

    Args:
        optimizer: the optimizer class to derive from.
        per_param_clipper: if given, called once with every parameter of every
            param group before the update.
        global_clipper: if given, called once with an iterator over all
            parameters of all param groups before the update.

    Returns:
        A new class named ``<optimizer.__name__>WithGradientClip``. Its ``step``
        returns whatever the wrapped optimizer's ``step`` returns.

    At most one of the two clippers may be given. If neither is given, the
    generated class steps without clipping.
    """
    assert (
        per_param_clipper is None or global_clipper is None
    ), "Not allowed to use both per-parameter clipping and global clipping"

    def optimizer_wgc_step(self, closure=None):
        if per_param_clipper is not None:
            for group in self.param_groups:
                for p in group["params"]:
                    per_param_clipper(p)
        elif global_clipper is not None:
            # global clipper for future use with detr
            # (https://github.com/facebookresearch/detr/pull/287)
            all_params = itertools.chain.from_iterable(
                g["params"] for g in self.param_groups
            )
            global_clipper(all_params)
        # Name the generated class explicitly: `super(type(self), self)` would
        # resolve back to this very method (infinite recursion) as soon as the
        # generated class is itself subclassed.
        return super(OptimizerWithGradientClip, self).step(closure)

    OptimizerWithGradientClip = type(
        optimizer.__name__ + "WithGradientClip",
        (optimizer,),
        {"step": optimizer_wgc_step},
    )
    return OptimizerWithGradientClip
def maybe_add_gradient_clipping(
    cfg: CfgNode, optimizer: Type[torch.optim.Optimizer]
) -> Type[torch.optim.Optimizer]:
    """
    Add gradient clipping to an optimizer when the config asks for it.

    When ``cfg.SOLVER.CLIP_GRADIENTS.ENABLED`` is set, a new class
    OptimizerWithGradientClip is created dynamically; it inherits the given
    optimizer type and overrides the `step` method to clip gradients first.

    Args:
        cfg: CfgNode, configuration options
        optimizer: a subclass of torch.optim.Optimizer. An optimizer instance
            is accepted as well; it is then converted in place.

    Return:
        The input `optimizer` unchanged if gradient clipping is disabled.
        Otherwise, for a class input, the gradient-clipping subclass of it;
        for an instance input, the same instance re-typed to that subclass.
    """
    clip_cfg = cfg.SOLVER.CLIP_GRADIENTS
    if not clip_cfg.ENABLED:
        return optimizer

    given_instance = isinstance(optimizer, torch.optim.Optimizer)
    if given_instance:
        base_type = type(optimizer)
    else:
        assert issubclass(optimizer, torch.optim.Optimizer), optimizer
        base_type = optimizer

    clipped_type = _generate_optimizer_class_with_gradient_clipping(
        base_type, per_param_clipper=_create_gradient_clipper(clip_cfg)
    )
    if not given_instance:
        return clipped_type

    # a bit hacky, not recommended: swap the class of the live instance
    optimizer.__class__ = clipped_type
    return optimizer
def build_optimizer(cfg: CfgNode, model: torch.nn.Module) -> torch.optim.Optimizer:
    """
    Build an SGD optimizer for ``model`` from the ``cfg.SOLVER`` options,
    wrapped with gradient clipping when the config enables it.
    """
    solver = cfg.SOLVER
    param_groups = get_default_optimizer_params(
        model,
        base_lr=solver.BASE_LR,
        weight_decay_norm=solver.WEIGHT_DECAY_NORM,
        bias_lr_factor=solver.BIAS_LR_FACTOR,
        weight_decay_bias=solver.WEIGHT_DECAY_BIAS,
    )
    # The `foreach` argument is only passed for torch >= 1.12.
    extra_args = {"foreach": True} if TORCH_VERSION >= (1, 12) else {}
    sgd = torch.optim.SGD(
        params=param_groups,
        lr=solver.BASE_LR,
        momentum=solver.MOMENTUM,
        nesterov=solver.NESTEROV,
        weight_decay=solver.WEIGHT_DECAY,
        **extra_args,
    )
    return maybe_add_gradient_clipping(cfg, sgd)
def get_default_optimizer_params(
    model: torch.nn.Module,
    base_lr: Optional[float] = None,
    weight_decay: Optional[float] = None,
    weight_decay_norm: Optional[float] = None,
    bias_lr_factor: Optional[float] = 1.0,
    weight_decay_bias: Optional[float] = None,
    lr_factor_func: Optional[Callable] = None,
    overrides: Optional[Dict[str, Dict[str, float]]] = None,
) -> List[Dict[str, Any]]:
    """
    Get default param list for optimizer, with support for a few types of
    overrides. If no overrides needed, this is equivalent to `model.parameters()`.

    Parameters that do not require grad are skipped, and a parameter shared by
    several modules is only listed once.

    Args:
        model: the module whose parameters are collected.
        base_lr: lr for every group by default. Can be omitted to use the one in optimizer.
        weight_decay: weight decay for every group by default. Can be omitted to use the one
            in optimizer.
        weight_decay_norm: override weight decay for params in normalization layers
        bias_lr_factor: multiplier of lr for bias parameters.
        weight_decay_bias: override weight decay for bias parameters.
        lr_factor_func: function to calculate lr decay rate by mapping the parameter names to
            corresponding lr decay rate. It is called with
            ``"<module_name>.<param_name>"``. Note that setting this option requires
            also setting ``base_lr``.
        overrides: if not `None`, provides values for optimizer hyperparameters
            (LR, weight decay) for module parameters with a given name; e.g.
            ``{"embedding": {"lr": 0.01, "weight_decay": 0.1}}`` will set the LR and
            weight decay values for all module parameters named `embedding`.
            Names are matched against the parameter's own (unqualified) name and
            are applied last, so they take precedence over the other options.
            The dict passed in is not modified.

    Returns:
        A list of parameter-group dicts, merged by ``reduce_param_groups``.

    Raises:
        ValueError: if ``bias_lr_factor`` (other than 1.0) or ``lr_factor_func``
            is given without ``base_lr``, or if bias settings are given both
            through the bias arguments and through ``overrides["bias"]``.

    For common detection models, ``weight_decay_norm`` is the only option
    needed to be set. ``bias_lr_factor,weight_decay_bias`` are legacy settings
    from Detectron1 that are not found useful.

    Example:
    ::
        torch.optim.SGD(get_default_optimizer_params(model, weight_decay_norm=0),
                       lr=0.01, weight_decay=1e-4, momentum=0.9)
    """
    # Work on a shallow copy: the "bias" entry added below must not leak into
    # the caller's dict, otherwise reusing that dict for a second call would
    # hit the "Conflicting overrides" error.
    overrides = {} if overrides is None else dict(overrides)

    defaults = {}
    if base_lr is not None:
        defaults["lr"] = base_lr
    if weight_decay is not None:
        defaults["weight_decay"] = weight_decay

    bias_overrides = {}
    if bias_lr_factor is not None and bias_lr_factor != 1.0:
        # NOTE: unlike Detectron v1, we now by default make bias hyperparameters
        # exactly the same as regular weights.
        if base_lr is None:
            raise ValueError("bias_lr_factor requires base_lr")
        bias_overrides["lr"] = base_lr * bias_lr_factor
    if weight_decay_bias is not None:
        bias_overrides["weight_decay"] = weight_decay_bias
    if bias_overrides:
        if "bias" in overrides:
            raise ValueError("Conflicting overrides for 'bias'")
        overrides["bias"] = bias_overrides

    if lr_factor_func is not None:
        if base_lr is None:
            raise ValueError("lr_factor_func requires base_lr")

    norm_module_types = (
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.SyncBatchNorm,
        # NaiveSyncBatchNorm inherits from BatchNorm2d
        torch.nn.GroupNorm,
        torch.nn.InstanceNorm1d,
        torch.nn.InstanceNorm2d,
        torch.nn.InstanceNorm3d,
        torch.nn.LayerNorm,
        torch.nn.LocalResponseNorm,
    )

    params: List[Dict[str, Any]] = []
    memo: Set[torch.nn.parameter.Parameter] = set()
    for module_name, module in model.named_modules():
        for module_param_name, value in module.named_parameters(recurse=False):
            if not value.requires_grad:
                continue
            # Avoid duplicating parameters
            if value in memo:
                continue
            memo.add(value)

            # Precedence, lowest to highest: defaults, norm-layer weight decay,
            # lr_factor_func scaling, then per-name overrides.
            hyperparams = copy.copy(defaults)
            if isinstance(module, norm_module_types) and weight_decay_norm is not None:
                hyperparams["weight_decay"] = weight_decay_norm
            if lr_factor_func is not None:
                hyperparams["lr"] *= lr_factor_func(f"{module_name}.{module_param_name}")

            hyperparams.update(overrides.get(module_param_name, {}))
            params.append({"params": [value], **hyperparams})
    return reduce_param_groups(params)
def _expand_param_groups(params: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Split parameter groups into one group per parameter.

    Every output group holds a single-element ``"params"`` list (plus a
    single-element ``"param_names"`` list when the source group had names) and
    the hyperparameters of the group(s) the parameter appeared in. When a
    parameter appears in several input groups, later groups overwrite the
    hyperparameters set by earlier ones.
    """
    reserved_keys = ("params", "param_names")
    per_param = defaultdict(dict)
    for group in params:
        assert "params" in group
        shared = {key: val for key, val in group.items() if key not in reserved_keys}
        if "param_names" in group:
            for name, tensor in zip(group["param_names"], group["params"]):
                entry = {"param_names": [name], "params": [tensor], **shared}
                per_param[tensor].update(entry)
        else:
            for tensor in group["params"]:
                entry = {"params": [tensor], **shared}
                per_param[tensor].update(entry)
    return list(per_param.values())
def reduce_param_groups(params: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Reorganize the parameter groups and merge duplicated groups.

    The number of parameter groups needs to be as small as possible in order
    to efficiently use the PyTorch multi-tensor optimizer. Therefore instead
    of using a parameter_group per single parameter, we reorganize the
    parameter groups and merge duplicated groups. This approach speeds
    up multi-tensor optimizer significantly.

    Args:
        params: parameter groups; each is a dict with a ``"params"`` list,
            optionally a matching ``"param_names"`` list, and any number of
            hyperparameter entries (whose values must be hashable).

    Returns:
        One group per distinct set of hyperparameters, in order of first
        appearance, each holding all parameters (and names, if present) that
        share those hyperparameters.
    """
    expanded = _expand_param_groups(params)

    # Re-group all per-parameter groups by their hyperparameters.
    groups = defaultdict(list)
    for item in expanded:
        # Sort by hyperparameter name so that groups holding the same settings
        # are merged even if their dicts were filled in a different order.
        hyper_key = tuple(
            sorted(
                ((k, v) for k, v in item.items() if k != "params" and k != "param_names"),
                key=lambda kv: kv[0],
            )
        )
        member = {"params": item["params"]}
        if "param_names" in item:
            member["param_names"] = item["param_names"]
        groups[hyper_key].append(member)

    ret = []
    for hyper_key, members in groups.items():
        cur = dict(hyper_key)
        cur["params"] = list(itertools.chain.from_iterable(m["params"] for m in members))
        # Every group has at least one member; names are carried over when the
        # first member provides them.
        if "param_names" in members[0]:
            cur["param_names"] = list(
                itertools.chain.from_iterable(m["param_names"] for m in members)
            )
        ret.append(cur)
    return ret
def build_lr_scheduler(cfg: CfgNode, optimizer: torch.optim.Optimizer) -> LRScheduler:
    """
    Build the LR scheduler named by ``cfg.SOLVER.LR_SCHEDULER_NAME``.

    The selected schedule ("WarmupMultiStepLR", "WarmupCosineLR" or
    "WarmupStepWithFixedGammaLR") is wrapped in a ``WarmupParamScheduler`` and
    returned as an ``LRMultiplier`` over ``optimizer``.

    Raises:
        ValueError: if the scheduler name is not one of the names above.
    """
    solver = cfg.SOLVER
    name = solver.LR_SCHEDULER_NAME
    max_iter = solver.MAX_ITER

    if name == "WarmupMultiStepLR":
        # Milestones beyond the end of training can never trigger; drop them.
        milestones = [step for step in solver.STEPS if step <= max_iter]
        if len(milestones) != len(solver.STEPS):
            logging.getLogger(__name__).warning(
                "SOLVER.STEPS contains values larger than SOLVER.MAX_ITER. "
                "These values will be ignored."
            )
        multipliers = [solver.GAMMA**k for k in range(len(milestones) + 1)]
        sched = MultiStepParamScheduler(
            values=multipliers,
            milestones=milestones,
            num_updates=max_iter,
        )
    elif name == "WarmupCosineLR":
        # The schedule is expressed as a multiplier of the base LR.
        end_value = solver.BASE_LR_END / solver.BASE_LR
        assert end_value >= 0.0 and end_value <= 1.0, end_value
        sched = CosineParamScheduler(1, end_value)
    elif name == "WarmupStepWithFixedGammaLR":
        sched = StepWithFixedGammaParamScheduler(
            base_value=1.0,
            gamma=solver.GAMMA,
            num_decays=solver.NUM_DECAYS,
            num_updates=max_iter,
        )
    else:
        raise ValueError("Unknown LR scheduler: {}".format(name))

    # Warmup length is passed as a fraction of the whole run, capped at 1.0.
    warmup_fraction = min(solver.WARMUP_ITERS / max_iter, 1.0)
    sched = WarmupParamScheduler(
        sched,
        solver.WARMUP_FACTOR,
        warmup_fraction,
        solver.WARMUP_METHOD,
        solver.RESCALE_INTERVAL,
    )
    return LRMultiplier(optimizer, multiplier=sched, max_iter=max_iter)