from typing import Any, Dict, List, Set

import torch

from detectron2.solver.build import maybe_add_gradient_clipping


def build_optimizer(cfg, model):
    """
    Build the optimizer selected by cfg.SOLVER.TYPE, assigning per-parameter
    learning rates and weight decay (normalization layers, biases, and the
    'priors_*' parameters get their own settings).
    """

    # Normalization layers whose parameters may use cfg.SOLVER.WEIGHT_DECAY_NORM.
    norm_module_types = (
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.SyncBatchNorm,
        torch.nn.GroupNorm,
        torch.nn.InstanceNorm1d,
        torch.nn.InstanceNorm2d,
        torch.nn.InstanceNorm3d,
        torch.nn.LayerNorm,
        torch.nn.LocalResponseNorm,
    )

    params: List[Dict[str, Any]] = []
    memo: Set[torch.nn.parameter.Parameter] = set()

    for module in model.modules():
        for key, value in module.named_parameters(recurse=False):
            if not value.requires_grad:
                continue

            # Skip parameters that are shared between modules and were already added.
            if value in memo:
                continue
            memo.add(value)

            lr = cfg.SOLVER.BASE_LR
            weight_decay = cfg.SOLVER.WEIGHT_DECAY

            if isinstance(module, norm_module_types) and cfg.SOLVER.WEIGHT_DECAY_NORM is not None:
                weight_decay = cfg.SOLVER.WEIGHT_DECAY_NORM

            elif key == "bias":
                if cfg.SOLVER.BIAS_LR_FACTOR is not None:
                    lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR
                if cfg.SOLVER.WEIGHT_DECAY_BIAS is not None:
                    weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS

            # The prior parameters never receive weight decay.
            if key in ['priors_dims_per_cat', 'priors_z_scales', 'priors_z_stats']:
                weight_decay = 0.0

            params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]

    # Construct the optimizer selected in the config. The Adam variants use
    # eps=1e-02, which is much larger than the PyTorch default of 1e-08.
    if cfg.SOLVER.TYPE == 'sgd':
        optimizer = torch.optim.SGD(
            params,
            cfg.SOLVER.BASE_LR,
            momentum=cfg.SOLVER.MOMENTUM,
            nesterov=cfg.SOLVER.NESTEROV,
            weight_decay=cfg.SOLVER.WEIGHT_DECAY
        )
    elif cfg.SOLVER.TYPE == 'adam':
        optimizer = torch.optim.Adam(params, cfg.SOLVER.BASE_LR, eps=1e-02)
    elif cfg.SOLVER.TYPE == 'adam+amsgrad':
        optimizer = torch.optim.Adam(params, cfg.SOLVER.BASE_LR, amsgrad=True, eps=1e-02)
    elif cfg.SOLVER.TYPE == 'adamw':
        optimizer = torch.optim.AdamW(params, cfg.SOLVER.BASE_LR, eps=1e-02)
    elif cfg.SOLVER.TYPE == 'adamw+amsgrad':
        optimizer = torch.optim.AdamW(params, cfg.SOLVER.BASE_LR, amsgrad=True, eps=1e-02)
    else:
        raise ValueError('{} is not supported as an optimizer.'.format(cfg.SOLVER.TYPE))

    # Wrap the optimizer with gradient clipping if cfg.SOLVER.CLIP_GRADIENTS enables it.
    optimizer = maybe_add_gradient_clipping(cfg, optimizer)
    return optimizer
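
# Typical call site (a sketch; `get_cfg` and `build_model` below are placeholders
# for the project's own config/model builders, not functions defined in this file):
#
#   cfg = get_cfg()            # must expose the cfg.SOLVER.* options referenced above
#   model = build_model(cfg)
#   optimizer = build_optimizer(cfg, model)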


def freeze_bn(network):
    """
    Put every BatchNorm2d layer of `network` into eval mode and stop tracking
    running statistics, so the stored statistics are no longer updated.
    """

    for _, module in network.named_modules():
        if isinstance(module, torch.nn.BatchNorm2d):
            module.eval()
            module.track_running_stats = False
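

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the training pipeline): build a toy
    # model and a detectron2-style CfgNode that carries only the SOLVER keys read
    # above, then exercise build_optimizer and freeze_bn. A real project config
    # defines many more options; the values below are illustrative only.
    from detectron2.config import CfgNode as CN

    cfg = CN()
    cfg.SOLVER = CN()
    cfg.SOLVER.TYPE = 'sgd'
    cfg.SOLVER.BASE_LR = 0.01
    cfg.SOLVER.MOMENTUM = 0.9
    cfg.SOLVER.NESTEROV = False
    cfg.SOLVER.WEIGHT_DECAY = 1e-4
    cfg.SOLVER.WEIGHT_DECAY_NORM = 0.0
    cfg.SOLVER.BIAS_LR_FACTOR = 2.0
    cfg.SOLVER.WEIGHT_DECAY_BIAS = 0.0
    cfg.SOLVER.CLIP_GRADIENTS = CN({"ENABLED": False})

    toy_model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
        torch.nn.BatchNorm2d(8),
        torch.nn.ReLU(),
    )

    optimizer = build_optimizer(cfg, toy_model)
    freeze_bn(toy_model)
    print(type(optimizer).__name__, "with", len(optimizer.param_groups), "param groups")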