# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch.nn as nn
from mmengine.dist import get_world_size
from mmengine.logging import print_log
from mmengine.model import is_model_wrapper
from mmengine.optim import OptimWrapper

from mmyolo.registry import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS,
                             OPTIMIZERS)


@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class YOLOv5OptimizerConstructor:
"""YOLOv5 constructor for optimizers. | |
It has the following functions: | |
- divides the optimizer parameters into 3 groups: | |
Conv, Bias and BN | |
- support `weight_decay` parameter adaption based on | |
`batch_size_per_gpu` | |
Args: | |
optim_wrapper_cfg (dict): The config dict of the optimizer wrapper. | |
Positional fields are | |
- ``type``: class name of the OptimizerWrapper | |
- ``optimizer``: The configuration of optimizer. | |
Optional fields are | |
- any arguments of the corresponding optimizer wrapper type, | |
e.g., accumulative_counts, clip_grad, etc. | |
The positional fields of ``optimizer`` are | |
- `type`: class name of the optimizer. | |
Optional fields are | |
- any arguments of the corresponding optimizer type, e.g., | |
lr, weight_decay, momentum, etc. | |
paramwise_cfg (dict, optional): Parameter-wise options. Must include | |
`base_total_batch_size` if not None. If the total input batch | |
is smaller than `base_total_batch_size`, the `weight_decay` | |
parameter will be kept unchanged, otherwise linear scaling. | |
Example: | |
>>> model = torch.nn.modules.Conv1d(1, 1, 1) | |
>>> optim_wrapper_cfg = dict( | |
>>> dict(type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01, | |
>>> momentum=0.9, weight_decay=0.0001, batch_size_per_gpu=16)) | |
>>> paramwise_cfg = dict(base_total_batch_size=64) | |
>>> optim_wrapper_builder = YOLOv5OptimizerConstructor( | |
>>> optim_wrapper_cfg, paramwise_cfg) | |
>>> optim_wrapper = optim_wrapper_builder(model) | |
""" | |

    def __init__(self,
                 optim_wrapper_cfg: dict,
                 paramwise_cfg: Optional[dict] = None):
        if paramwise_cfg is None:
            paramwise_cfg = {'base_total_batch_size': 64}
        assert 'base_total_batch_size' in paramwise_cfg

        if not isinstance(optim_wrapper_cfg, dict):
            raise TypeError('optim_wrapper_cfg should be a dict, '
                            f'but got {type(optim_wrapper_cfg)}')
        assert 'optimizer' in optim_wrapper_cfg, (
            '`optim_wrapper_cfg` must contain "optimizer" config')

        self.optim_wrapper_cfg = optim_wrapper_cfg
        self.optimizer_cfg = self.optim_wrapper_cfg.pop('optimizer')
        self.base_total_batch_size = paramwise_cfg['base_total_batch_size']

    def __call__(self, model: nn.Module) -> OptimWrapper:
        if is_model_wrapper(model):
            model = model.module
        optimizer_cfg = self.optimizer_cfg.copy()
        weight_decay = optimizer_cfg.pop('weight_decay', 0)

        if 'batch_size_per_gpu' in optimizer_cfg:
            batch_size_per_gpu = optimizer_cfg.pop('batch_size_per_gpu')
            # No scaling if total_batch_size is less than
            # base_total_batch_size, otherwise linear scaling.
            total_batch_size = get_world_size() * batch_size_per_gpu
            accumulate = max(
                round(self.base_total_batch_size / total_batch_size), 1)
            scale_factor = total_batch_size * \
                accumulate / self.base_total_batch_size

            if scale_factor != 1:
                weight_decay *= scale_factor
                print_log(f'Scaled weight_decay to {weight_decay}', 'current')
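
            # Worked example of the rule above (illustrative batch sizes,
            # assuming the default base_total_batch_size of 64): with 4 GPUs
            # and batch_size_per_gpu=32, total_batch_size = 128, accumulate =
            # max(round(64 / 128), 1) = 1 and scale_factor = 128 * 1 / 64 = 2,
            # so weight_decay doubles. With 1 GPU and batch_size_per_gpu=16,
            # accumulate = 4 and scale_factor = 16 * 4 / 64 = 1, so
            # weight_decay is unchanged.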

        params_groups = [], [], []
        for v in model.modules():
            if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
                params_groups[2].append(v.bias)
            # Includes SyncBatchNorm
            if isinstance(v, nn.modules.batchnorm._NormBase):
                params_groups[1].append(v.weight)
            elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
                params_groups[0].append(v.weight)

        # Note: Make sure bias is in the last parameter group
        optimizer_cfg['params'] = []
        # conv
        optimizer_cfg['params'].append({
            'params': params_groups[0],
            'weight_decay': weight_decay
        })
        # bn
        optimizer_cfg['params'].append({'params': params_groups[1]})
        # bias
        optimizer_cfg['params'].append({'params': params_groups[2]})
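
        # `weight_decay` was popped from `optimizer_cfg` earlier, so the bn
        # and bias groups fall back to the optimizer's own default (0 for
        # torch.optim.SGD): in effect only conv weights are decayed.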

        print_log(
            'Optimizer groups: %g .bias, %g conv.weight, %g other' %
            (len(params_groups[2]), len(params_groups[0]),
             len(params_groups[1])), 'current')
        del params_groups

        optimizer = OPTIMIZERS.build(optimizer_cfg)
        optim_wrapper = OPTIM_WRAPPERS.build(
            self.optim_wrapper_cfg, default_args=dict(optimizer=optimizer))
        return optim_wrapper
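

# A minimal usage sketch (assumed values, not taken from any shipped config):
# with the class registered above, MMEngine selects it through the
# `constructor` field of the optim_wrapper config.
#
# optim_wrapper = dict(
#     type='OptimWrapper',
#     optimizer=dict(
#         type='SGD',
#         lr=0.01,
#         momentum=0.937,
#         weight_decay=0.0005,
#         nesterov=True,
#         batch_size_per_gpu=16),
#     constructor='YOLOv5OptimizerConstructor')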