# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from abc import ABCMeta, abstractmethod
from typing import Dict, List

import torch


class BaseOptimWrapper(metaclass=ABCMeta):
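    """Base class for optimizer wrappers.

    ``BaseOptimWrapper`` wraps a ``torch.optim.Optimizer`` and defines the
    abstract interface (``update_params``, ``backward``, ``zero_grad`` and
    ``step``) that concrete wrappers implement, while providing shared
    helpers such as ``state_dict``, ``load_state_dict``, ``get_lr`` and
    ``get_momentum``.

    Example:
        >>> # Illustrative sketch; ``SomeOptimWrapper`` is a hypothetical
        >>> # concrete subclass, not part of this file.
        >>> import torch
        >>> from torch.optim import SGD
        >>> params = [torch.nn.Parameter(torch.zeros(1))]
        >>> optimizer = SGD(params, lr=0.01, momentum=0.9)
        >>> optim_wrapper = SomeOptimWrapper(optimizer)
        >>> optim_wrapper.get_lr()
        {'lr': [0.01]}
    """
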
    def __init__(self, optimizer):
        self.optimizer = optimizer

        # The following code initializes `base_param_settings`.
        # `base_param_settings` stores a dummy parameter that is not updated
        # by the optimizer; it is used to track the base learning rate of the
        # optimizer. It is only created when the optimizer has multiple
        # parameter groups, and its values will not be scaled by the loss
        # factor.
        if len(optimizer.param_groups) > 1:
            self.base_param_settings = {
                'params': torch.tensor([0.0], dtype=torch.float)
            }
            self.base_param_settings.update(**self.optimizer.defaults)
        else:
            self.base_param_settings = None  # type: ignore
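
        # For illustration (assumed values): if the wrapped optimizer is, say,
        # ``SGD(..., lr=0.01)`` with two parameter groups, then
        # ``self.base_param_settings`` looks roughly like
        # ``{'params': tensor([0.]), 'lr': 0.01, 'momentum': 0, ...}`` and is
        # exposed as the last entry of :attr:`param_groups`.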

    @abstractmethod
    def update_params(self, *args, **kwargs):
        """Update parameters in :attr:`optimizer`."""

    @abstractmethod
    def backward(self, loss: torch.Tensor, **kwargs) -> None:
        """Perform gradient back propagation."""

    @abstractmethod
    def zero_grad(self, **kwargs) -> None:
        """A wrapper of ``Optimizer.zero_grad``."""

    @abstractmethod
    def step(self, **kwargs):
        """Call the step method of optimizer."""

    def state_dict(self) -> dict:
        """A wrapper of ``Optimizer.state_dict``."""
        state_dict = self.optimizer.state_dict()
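        # Also persist the tracked base settings so that they survive
        # checkpointing and can be restored by ``load_state_dict``.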
        if self.base_param_settings is not None:
            state_dict['base_param_settings'] = self.base_param_settings
        return state_dict

    def load_state_dict(self, state_dict: dict) -> None:
        """A wrapper of ``Optimizer.load_state_dict``. Load the state dict of
        :attr:`optimizer`.

        Provides a unified ``load_state_dict`` interface compatible with
        automatic mixed precision training. Subclasses can override this
        method to implement the required logic. For example, the state
        dictionary of ``GradScaler`` should also be loaded when training with
        ``torch.cuda.amp``.

        Args:
            state_dict (dict): The state dictionary of :attr:`optimizer`.
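
        Example:
            >>> # Illustrative sketch; ``optim_wrapper`` is assumed to be an
            >>> # instance of a concrete subclass.
            >>> ckpt = optim_wrapper.state_dict()
            >>> optim_wrapper.load_state_dict(ckpt)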
"""
base_param_settings = state_dict.pop('base_param_settings', None)
if base_param_settings is not None:
self.base_param_settings = base_param_settings
# load state_dict of optimizer
self.optimizer.load_state_dict(state_dict)

    @property
    def param_groups(self) -> List[dict]:
        """A wrapper of ``Optimizer.param_groups``.

        Makes the optimizer wrapper compatible with :class:`_ParamScheduler`.

        Returns:
            List[dict]: The ``param_groups`` of :attr:`optimizer`, with
            :attr:`base_param_settings` appended when it is tracked.
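
        Example:
            >>> # Illustrative sketch; with a single parameter group the list
            >>> # is exactly ``optimizer.param_groups``; with multiple groups
            >>> # the last entry holds the base settings (values assumed).
            >>> optim_wrapper.param_groups[-1]['lr']
            0.01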
"""
if self.base_param_settings is not None:
return self.optimizer.param_groups + [self.base_param_settings]
else:
return self.optimizer.param_groups

    @property
    def defaults(self) -> dict:
        """A wrapper of ``Optimizer.defaults``.

        Makes the optimizer wrapper compatible with :class:`_ParamScheduler`.

        Returns:
            dict: The ``defaults`` of :attr:`optimizer`.
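
        Example:
            >>> # Illustrative sketch; exact keys depend on the wrapped
            >>> # optimizer (values below assume ``SGD(..., lr=0.01)``).
            >>> optim_wrapper.defaults['lr']
            0.01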
"""
return self.optimizer.defaults

    def get_lr(self) -> Dict[str, List[float]]:
        """Get the learning rate of the optimizer.

        Provides a unified interface to get the learning rate of the
        optimizer.

        Returns:
            Dict[str, List[float]]: Learning rate of each ``param_group`` of
            the optimizer under the key ``'lr'``, plus the base learning rate
            under ``'base_lr'`` when :attr:`base_param_settings` is tracked.
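
        Example:
            >>> # Illustrative sketch; assumes a single parameter group with
            >>> # lr=0.01, so no ``base_lr`` entry is present.
            >>> optim_wrapper.get_lr()
            {'lr': [0.01]}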
"""
res = {}
if self.base_param_settings is not None:
res['base_lr'] = [self.base_param_settings['lr']]
res['lr'] = [group['lr'] for group in self.optimizer.param_groups]
return res

    def get_momentum(self) -> Dict[str, List[float]]:
        """Get the momentum of the optimizer.

        Provides a unified interface to get the momentum of the optimizer.
        For SGD-style optimizers the ``momentum`` field is read; for
        Adam-style optimizers the first ``beta`` is used; otherwise 0 is
        reported.

        Returns:
            Dict[str, List[float]]: Momentum of the optimizer.
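
        Example:
            >>> # Illustrative sketch; assumes the wrapped optimizer is
            >>> # ``SGD(..., momentum=0.9)`` with a single parameter group.
            >>> optim_wrapper.get_momentum()
            {'momentum': [0.9]}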
"""
momentum = []
for group in self.optimizer.param_groups:
# Get momentum of SGD.
if 'momentum' in group.keys():
momentum.append(group['momentum'])
# Get momentum of Adam.
elif 'betas' in group.keys():
momentum.append(group['betas'][0])
else:
momentum.append(0)
return dict(momentum=momentum)