# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import Any, Dict, List, Set

import torch
from detectron2.solver.build import maybe_add_gradient_clipping

def build_optimizer(cfg, model):
    """Build an optimizer for `model` with per-parameter LR and weight-decay overrides from `cfg.SOLVER`."""

    norm_module_types = (
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.SyncBatchNorm,
        torch.nn.GroupNorm,
        torch.nn.InstanceNorm1d,
        torch.nn.InstanceNorm2d,
        torch.nn.InstanceNorm3d,
        torch.nn.LayerNorm,
        torch.nn.LocalResponseNorm,
    )

    params: List[Dict[str, Any]] = []
    memo: Set[torch.nn.parameter.Parameter] = set()

    for module in model.modules():
        for key, value in module.named_parameters(recurse=False):
            if not value.requires_grad:
                continue

            # Avoid duplicating parameters shared across modules.
            if value in memo:
                continue
            memo.add(value)

            lr = cfg.SOLVER.BASE_LR
            weight_decay = cfg.SOLVER.WEIGHT_DECAY

            # Norm layers get their own weight decay; biases get their own LR factor and weight decay.
            if isinstance(module, norm_module_types) and cfg.SOLVER.WEIGHT_DECAY_NORM is not None:
                weight_decay = cfg.SOLVER.WEIGHT_DECAY_NORM
            elif key == "bias":
                if cfg.SOLVER.BIAS_LR_FACTOR is not None:
                    lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR
                if cfg.SOLVER.WEIGHT_DECAY_BIAS is not None:
                    weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS

            # These params do not need weight decay at all.
            # TODO: parameterize these in configs instead.
            if key in ['priors_dims_per_cat', 'priors_z_scales', 'priors_z_stats']:
                weight_decay = 0.0

            params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]

    if cfg.SOLVER.TYPE == 'sgd':
        optimizer = torch.optim.SGD(
            params,
            cfg.SOLVER.BASE_LR,
            momentum=cfg.SOLVER.MOMENTUM,
            nesterov=cfg.SOLVER.NESTEROV,
            weight_decay=cfg.SOLVER.WEIGHT_DECAY,
        )
    elif cfg.SOLVER.TYPE == 'adam':
        optimizer = torch.optim.Adam(params, cfg.SOLVER.BASE_LR, eps=1e-02)
    elif cfg.SOLVER.TYPE == 'adam+amsgrad':
        optimizer = torch.optim.Adam(params, cfg.SOLVER.BASE_LR, amsgrad=True, eps=1e-02)
    elif cfg.SOLVER.TYPE == 'adamw':
        optimizer = torch.optim.AdamW(params, cfg.SOLVER.BASE_LR, eps=1e-02)
    elif cfg.SOLVER.TYPE == 'adamw+amsgrad':
        optimizer = torch.optim.AdamW(params, cfg.SOLVER.BASE_LR, amsgrad=True, eps=1e-02)
    else:
        raise ValueError('{} is not supported as an optimizer.'.format(cfg.SOLVER.TYPE))

    optimizer = maybe_add_gradient_clipping(cfg, optimizer)
    return optimizer

def freeze_bn(network):
    """Put all BatchNorm2d layers in eval mode and stop updating their running statistics."""
    for _, module in network.named_modules():
        if isinstance(module, torch.nn.BatchNorm2d):
            module.eval()
            module.track_running_stats = False
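

if __name__ == "__main__":
    # Illustrative usage sketch, not part of the original file. It assumes
    # detectron2's default config provides the SOLVER.* keys referenced above
    # (BASE_LR, WEIGHT_DECAY, WEIGHT_DECAY_NORM, BIAS_LR_FACTOR, WEIGHT_DECAY_BIAS,
    # MOMENTUM, NESTEROV, CLIP_GRADIENTS), and adds the project-specific
    # SOLVER.TYPE key by direct assignment on the unfrozen CfgNode.
    from detectron2.config import get_cfg

    cfg = get_cfg()
    cfg.SOLVER.TYPE = "sgd"  # any of the branches handled in build_optimizer

    # Toy model with a conv (weight + bias) and a norm layer, to exercise the
    # per-parameter LR / weight-decay overrides.
    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
        torch.nn.BatchNorm2d(8),
        torch.nn.ReLU(),
    )

    optimizer = build_optimizer(cfg, model)
    print(type(optimizer).__name__, "with", len(optimizer.param_groups), "parameter groups")

    # Freeze BatchNorm statistics, e.g. when fine-tuning with small batch sizes.
    freeze_bn(model)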