# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import Any, Dict, List, Set

import torch
from detectron2.solver.build import maybe_add_gradient_clipping

def build_optimizer(cfg, model):
    """
    Build the optimizer described by cfg.SOLVER for `model`, using one parameter
    group per tensor so that normalization layers, biases and the 3D prior
    parameters can each receive their own learning rate / weight decay.
    """
    # Normalization layers whose parameters may use cfg.SOLVER.WEIGHT_DECAY_NORM.
    norm_module_types = (
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.SyncBatchNorm,
        torch.nn.GroupNorm,
        torch.nn.InstanceNorm1d,
        torch.nn.InstanceNorm2d,
        torch.nn.InstanceNorm3d,
        torch.nn.LayerNorm,
        torch.nn.LocalResponseNorm,
    )

    params: List[Dict[str, Any]] = []
    memo: Set[torch.nn.parameter.Parameter] = set()

    for module in model.modules():
        for key, value in module.named_parameters(recurse=False):
            if not value.requires_grad:
                continue
            # Avoid duplicating parameters shared between modules.
            if value in memo:
                continue
            memo.add(value)

            lr = cfg.SOLVER.BASE_LR
            weight_decay = cfg.SOLVER.WEIGHT_DECAY

            if isinstance(module, norm_module_types) and cfg.SOLVER.WEIGHT_DECAY_NORM is not None:
                weight_decay = cfg.SOLVER.WEIGHT_DECAY_NORM
            elif key == "bias":
                if cfg.SOLVER.BIAS_LR_FACTOR is not None:
                    lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR
                if cfg.SOLVER.WEIGHT_DECAY_BIAS is not None:
                    weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS

            # The 3D prior parameters do not need weight decay at all.
            # TODO: parameterize these in configs instead.
            if key in ['priors_dims_per_cat', 'priors_z_scales', 'priors_z_stats']:
                weight_decay = 0.0

            params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]

    if cfg.SOLVER.TYPE == 'sgd':
        optimizer = torch.optim.SGD(
            params,
            cfg.SOLVER.BASE_LR,
            momentum=cfg.SOLVER.MOMENTUM,
            nesterov=cfg.SOLVER.NESTEROV,
            weight_decay=cfg.SOLVER.WEIGHT_DECAY,
        )
    elif cfg.SOLVER.TYPE == 'adam':
        optimizer = torch.optim.Adam(params, cfg.SOLVER.BASE_LR, eps=1e-02)
    elif cfg.SOLVER.TYPE == 'adam+amsgrad':
        optimizer = torch.optim.Adam(params, cfg.SOLVER.BASE_LR, amsgrad=True, eps=1e-02)
    elif cfg.SOLVER.TYPE == 'adamw':
        optimizer = torch.optim.AdamW(params, cfg.SOLVER.BASE_LR, eps=1e-02)
    elif cfg.SOLVER.TYPE == 'adamw+amsgrad':
        optimizer = torch.optim.AdamW(params, cfg.SOLVER.BASE_LR, amsgrad=True, eps=1e-02)
    else:
        raise ValueError('{} is not supported as an optimizer.'.format(cfg.SOLVER.TYPE))

    optimizer = maybe_add_gradient_clipping(cfg, optimizer)
    return optimizer

def freeze_bn(network):
    """Put every BatchNorm2d layer into eval mode so its running statistics are not updated."""
    for _, module in network.named_modules():
        if isinstance(module, torch.nn.BatchNorm2d):
            module.eval()
            module.track_running_stats = False
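

# Usage sketch, not part of the original module: it builds a toy model and the default
# detectron2 config only to show how freeze_bn and build_optimizer are meant to be called.
# SOLVER.TYPE is assumed to be a project-specific key and is added by hand here.
if __name__ == "__main__":
    from detectron2.config import get_cfg

    cfg = get_cfg()
    cfg.SOLVER.TYPE = 'sgd'  # assumed project-specific key, not part of the detectron2 defaults

    toy_model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
        torch.nn.BatchNorm2d(8),
        torch.nn.ReLU(),
    )
    freeze_bn(toy_model)  # BatchNorm2d statistics stay fixed from here on
    optimizer = build_optimizer(cfg, toy_model)
    print(optimizer)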