import math
import functools
import warnings
from collections import OrderedDict, defaultdict
from copy import deepcopy
from itertools import chain
from typing import (
    Any,
    Callable,
    DefaultDict,
    Dict,
    Hashable,
    Iterable,
    List,
    Optional,
    Set,
    Tuple,
    TypeVar,
    Union,
    cast,
    overload,
)
from typing_extensions import ParamSpec, Self, TypeAlias

import torch
import torch.utils.hooks as hooks
from torch.utils.hooks import RemovableHandle
from torch.utils._foreach_utils import (
    Indices,
    TensorListList,
    _get_foreach_kernels_supported_devices,
    _get_fused_kernels_supported_devices,
)
from torch._utils import is_compiling
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

Args: TypeAlias = Tuple[Any, ...]
Kwargs: TypeAlias = Dict[str, Any]
StateDict: TypeAlias = Dict[str, Any]

GlobalOptimizerPreHook: TypeAlias = Callable[["Optimizer", Args, Kwargs], Optional[Tuple[Args, Kwargs]]]
GlobalOptimizerPostHook: TypeAlias = Callable[["Optimizer", Args, Kwargs], None]

__all__ = ['Optimizer', 'register_optimizer_step_pre_hook', 'register_optimizer_step_post_hook']

_global_optimizer_pre_hooks: Dict[int, GlobalOptimizerPreHook] = OrderedDict()
_global_optimizer_post_hooks: Dict[int, GlobalOptimizerPostHook] = OrderedDict()
_foreach_supported_types = [torch.Tensor, torch.nn.parameter.Parameter]

class _RequiredParameter:
    """Singleton class representing a required parameter for an Optimizer."""
    def __repr__(self) -> str:
        return "<required parameter>"

required = _RequiredParameter()

def _use_grad_for_differentiable(func):
    def _use_grad(self, *args, **kwargs):
        import torch._dynamo
        prev_grad = torch.is_grad_enabled()
        try:
            # Note on graph break below:
            # we need to graph break to ensure that aot respects the no_grad annotation.
            # This is important for perf because without this, functionalization will generate an epilogue
            # which updates the mutated parameters of the optimizer which is *not* visible to inductor; as a result,
            # inductor will allocate for every parameter in the model, which is horrible.
            # With this, aot correctly sees that this is an inference graph, and functionalization will generate
            # an epilogue which is appended to the graph, which *is* visible to inductor; as a result, inductor sees that
            # step is in place and is able to avoid the extra allocation.
            # In the future, we will either 1) continue to graph break on backward, so this graph break does not matter,
            # or 2) have a fully fused forward and backward graph, which will have no_grad by default, and we can remove this
            # graph break to allow the fully fused fwd-bwd-optimizer graph to be compiled.
            # see https://github.com/pytorch/pytorch/issues/104053
            torch.set_grad_enabled(self.defaults['differentiable'])
            torch._dynamo.graph_break()
            ret = func(self, *args, **kwargs)
        finally:
            torch._dynamo.graph_break()
            torch.set_grad_enabled(prev_grad)
        return ret
    functools.update_wrapper(_use_grad, func)
    return _use_grad

def _get_value(x):
    # item() is significantly faster than a cpu tensor in eager mode
    if not torch.jit.is_scripting() and is_compiling():
        return x
    else:
        return x.item()

def _stack_if_compiling(x):
    if not torch.jit.is_scripting() and is_compiling():
        return torch.stack(x)
    else:
        return x

def _dispatch_sqrt(x: float):  # float annotation is needed because of torchscript type inference
    if not torch.jit.is_scripting() and isinstance(x, torch.Tensor):
        return x.sqrt()
    else:
        return math.sqrt(x)

# For any optimizer with a faster implementation, we attempt to default to the
# fastest + most stable option whenever possible. For foreach, the requirement is
# that the native params are all on CUDA. For fused, there is currently the
# additional requirement that the tensors' dtypes must be floating point. Neither
# alternative supports torch.jit.script or differentiable, so we fall back to the
# single tensor implementation in those cases.
def _default_to_fused_or_foreach(params: List[torch.Tensor],
                                 differentiable: bool,
                                 use_fused: bool = False) -> Tuple[bool, bool]:
    if torch.jit.is_scripting() or differentiable:
        return False, False

    fused_supported_devices = _get_fused_kernels_supported_devices()
    foreach_supported_devices = _get_foreach_kernels_supported_devices()
    fused = use_fused and all(
        p is None or (type(p) in _foreach_supported_types and
                      p.device.type in fused_supported_devices and
                      torch.is_floating_point(p)) for p in params
    )
    foreach = not fused and all(
        p is None or (type(p) in _foreach_supported_types and
                      p.device.type in foreach_supported_devices) for p in params
    )
    return fused, foreach
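
# A hedged illustration of the decision above (the tensors are made up for this
# sketch and it assumes a build where the foreach/fused kernels are CUDA-only):
#
#   cuda_params = [torch.rand(2, device="cuda") for _ in range(3)]
#   _default_to_fused_or_foreach(cuda_params, differentiable=False, use_fused=True)
#   # -> (True, False): fused wins when requested and all params qualify
#   _default_to_fused_or_foreach(cuda_params, differentiable=False)
#   # -> (False, True): foreach is the default fast path on CUDA
#   _default_to_fused_or_foreach(cuda_params + [torch.rand(2)], differentiable=False)
#   # -> (False, False): one CPU tensor forces the for-loop fallback
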
def _view_as_real(params, *state_and_grads):
    for i, p in enumerate(params):
        if torch.is_complex(p):
            params[i] = torch.view_as_real(params[i])
            for s in state_and_grads:
                s[i] = torch.view_as_real(s[i])
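
# A minimal sketch of the in-place rewrite above (values illustrative): complex
# entries are replaced by real views with a trailing dimension of 2, so the
# foreach kernels can operate on real dtypes. The lists are mutated; nothing is
# returned.
#
#   params = [torch.randn(4, dtype=torch.complex64)]
#   grads = [torch.randn(4, dtype=torch.complex64)]
#   _view_as_real(params, grads)
#   # params[0].shape == grads[0].shape == (4, 2); dtype is torch.float32
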
def _get_scalar_dtype(is_fused=None):
    if is_fused:
        return torch.float32
    return torch.float64 if torch.get_default_dtype() == torch.float64 else torch.float32

# Common doc strings among optimizers
_foreach_doc = r"""foreach (bool, optional): whether foreach implementation of optimizer
            is used. If unspecified by the user (so foreach is None), we will try to use
            foreach over the for-loop implementation on CUDA, since it is usually
            significantly more performant. Note that the foreach implementation uses
            ~ sizeof(params) more peak memory than the for-loop version due to the intermediates
            being a tensorlist vs just one tensor. If memory is prohibitive, batch fewer
            parameters through the optimizer at a time or switch this flag to False (default: None)"""

_fused_doc = r"""fused (bool, optional): whether the fused implementation (CUDA only) is used.
            Currently, `torch.float64`, `torch.float32`, `torch.float16`, and `torch.bfloat16`
            are supported. (default: None)

    .. note:: The foreach and fused implementations are typically faster than the for-loop,
              single-tensor implementation. Thus, if the user has not specified BOTH flags
              (i.e., when foreach = fused = None), we will attempt defaulting to the foreach
              implementation when the tensors are all on CUDA. For example, if the user specifies
              True for fused but nothing for foreach, we will run the fused implementation. If
              the user specifies False for foreach but nothing for fused (or False for fused but
              nothing for foreach), we will run the for-loop implementation. If the user specifies
              True for both foreach and fused, we will prioritize fused over foreach, as it is
              typically faster. We attempt to use the fastest, so the hierarchy goes fused ->
              foreach -> for-loop. HOWEVER, since the fused implementation is relatively new,
              we want to give it sufficient bake-in time, so we default to foreach and NOT
              fused when the user has not specified either flag."""

_capturable_doc = r"""capturable (bool, optional): whether this instance is safe to
            capture in a CUDA graph. Passing True can impair ungraphed performance,
            so if you don't intend to graph capture this instance, leave it False
            (default: False)"""

_differentiable_doc = r"""differentiable (bool, optional): whether autograd should
            occur through the optimizer step in training. Otherwise, the step()
            function runs in a torch.no_grad() context. Setting to True can impair
            performance, so leave it False if you don't intend to run autograd
            through this instance (default: False)"""

_maximize_doc = r"""maximize (bool, optional): maximize the objective with respect to the
            params, instead of minimizing (default: False)"""

def register_optimizer_step_pre_hook(hook: GlobalOptimizerPreHook) -> RemovableHandle:
    r"""Register a pre hook common to all optimizers. The hook should have the
    following signature::

        hook(optimizer, args, kwargs) -> None or modified args and kwargs

    Args:
        hook (Callable): A user defined hook which is registered on all optimizers.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``
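
    Example::

        >>> # A hedged sketch (the hook below is illustrative, not part of this
        >>> # module): count every step() taken by any optimizer in the process.
        >>> counts = {"steps": 0}
        >>> def count_steps(optimizer, args, kwargs):
        ...     # returning None keeps args/kwargs unchanged
        ...     counts["steps"] += 1
        >>> handle = register_optimizer_step_pre_hook(count_steps)
        >>> # ... training loop ...
        >>> handle.remove()  # stop counting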
""" | |
handle = hooks.RemovableHandle(_global_optimizer_pre_hooks) | |
_global_optimizer_pre_hooks[handle.id] = hook | |
return handle | |
def register_optimizer_step_post_hook(hook: GlobalOptimizerPostHook) -> RemovableHandle:
    r"""Register a post hook common to all optimizers. The hook should have the
    following signature::

        hook(optimizer, args, kwargs) -> None

    Args:
        hook (Callable): A user defined hook which is registered on all optimizers.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``
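
    Example::

        >>> # A hedged sketch: log after every step() of every optimizer.
        >>> def log_step(optimizer, args, kwargs):
        ...     print(f"{optimizer.__class__.__name__} stepped")
        >>> handle = register_optimizer_step_post_hook(log_step)
        >>> # ... later ...
        >>> handle.remove()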
""" | |
handle = hooks.RemovableHandle(_global_optimizer_post_hooks) | |
_global_optimizer_post_hooks[handle.id] = hook | |
return handle | |
ParamsT: TypeAlias = Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]] | |
_P = ParamSpec("_P") | |
R = TypeVar("R") | |
T = TypeVar("T") | |
class Optimizer:
    r"""Base class for all optimizers.

    .. warning::
        Parameters need to be specified as collections that have a deterministic
        ordering that is consistent between runs. Examples of objects that don't
        satisfy those properties are sets and iterators over values of dictionaries.

    Args:
        params (iterable): an iterable of :class:`torch.Tensor` s or
            :class:`dict` s. Specifies what Tensors should be optimized.
        defaults (dict): a dict containing default values of optimization
            options (used when a parameter group doesn't specify them).
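
    Example::

        >>> # A hedged sketch of the two accepted ``params`` forms (the model
        >>> # and hyperparameters are illustrative):
        >>> model = torch.nn.Linear(2, 2)
        >>> # 1) a flat iterable of Tensors
        >>> opt = torch.optim.SGD(model.parameters(), lr=0.1)
        >>> # 2) an iterable of dicts, each defining a param group with overrides
        >>> opt = torch.optim.SGD(
        ...     [{'params': model.weight, 'lr': 0.01}, {'params': model.bias}],
        ...     lr=0.1,
        ... )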
""" | |
OptimizerPreHook: TypeAlias = Callable[[Self, Args, Kwargs], Optional[Tuple[Args, Kwargs]]] # type: ignore[misc] | |
OptimizerPostHook: TypeAlias = Callable[[Self, Args, Kwargs], None] # type: ignore[misc] | |
_optimizer_step_pre_hooks: Dict[int, OptimizerPreHook] | |
_optimizer_step_post_hooks: Dict[int, OptimizerPostHook] | |
_optimizer_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]' | |
_optimizer_state_dict_post_hooks: 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]' | |
_optimizer_load_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]' | |
_optimizer_load_state_dict_post_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]' | |
    def __init__(self, params: ParamsT, defaults: Dict[str, Any]) -> None:
        torch._C._log_api_usage_once("python.optimizer")
        self.defaults = defaults
        self._optimizer_step_pre_hooks = OrderedDict()
        self._optimizer_step_post_hooks = OrderedDict()
        self._optimizer_state_dict_pre_hooks = OrderedDict()
        self._optimizer_state_dict_post_hooks = OrderedDict()
        self._optimizer_load_state_dict_pre_hooks = OrderedDict()
        self._optimizer_load_state_dict_post_hooks = OrderedDict()

        self._patch_step_function()

        if isinstance(params, torch.Tensor):
            if self.__class__.__name__ == 'SparseAdam':
                warnings.warn(("Passing in a raw Tensor as ``params`` to SparseAdam "
                               "is deprecated. In the future, this will raise an error. "
                               "Please wrap your Tensor in an iterable instead."),
                              FutureWarning)
                params = [params]
            else:
                raise TypeError("params argument given to the optimizer should be "
                                "an iterable of Tensors or dicts, but got " +
                                torch.typename(params))

        self.state: DefaultDict[torch.Tensor, Any] = defaultdict(dict)
        self.param_groups: List[Dict[str, Any]] = []

        param_groups = list(params)
        if len(param_groups) == 0:
            raise ValueError("optimizer got an empty parameter list")
        if not isinstance(param_groups[0], dict):
            param_groups = [{'params': param_groups}]

        for param_group in param_groups:
            self.add_param_group(cast(dict, param_group))

        # Allows _cuda_graph_capture_health_check to rig a poor man's TORCH_WARN_ONCE in python,
        # which I don't think exists
        # https://github.com/pytorch/pytorch/issues/72948
        self._warned_capturable_if_run_uncaptured = True

    def __getstate__(self) -> Dict[str, Any]:
        return {
            'defaults': self.defaults,
            'state': self.state,
            'param_groups': self.param_groups,
        }

    def __setstate__(self, state: Dict[str, Any]) -> None:
        self.__dict__.update(state)
        if '_optimizer_step_pre_hooks' not in self.__dict__:
            self._optimizer_step_pre_hooks = OrderedDict()
        if '_optimizer_step_post_hooks' not in self.__dict__:
            self._optimizer_step_post_hooks = OrderedDict()
        if '_optimizer_state_dict_pre_hooks' not in self.__dict__:
            self._optimizer_state_dict_pre_hooks = OrderedDict()
        if '_optimizer_state_dict_post_hooks' not in self.__dict__:
            self._optimizer_state_dict_post_hooks = OrderedDict()
        if '_optimizer_load_state_dict_pre_hooks' not in self.__dict__:
            self._optimizer_load_state_dict_pre_hooks = OrderedDict()
        if '_optimizer_load_state_dict_post_hooks' not in self.__dict__:
            self._optimizer_load_state_dict_post_hooks = OrderedDict()
        self._patch_step_function()  # To support multiprocessing pickle/unpickle
        self.defaults.setdefault('differentiable', False)

    def __repr__(self) -> str:
        format_string = self.__class__.__name__ + ' ('
        for i, group in enumerate(self.param_groups):
            format_string += '\n'
            format_string += f'Parameter Group {i}\n'
            for key in sorted(group.keys()):
                if key != 'params':
                    format_string += f'    {key}: {group[key]}\n'
        format_string += ')'
        return format_string

    # Currently needed by Adam and AdamW
    def _cuda_graph_capture_health_check(self) -> None:
        # Note [torch.compile x capturable]
        # If we are compiling, we try to take the capturable path automatically by
        # setting the flag to True during tracing. Due to this, we skip all the checks
        # normally required for determining whether we can use CUDA graphs and
        # shunt the responsibility to torch.inductor. This saves time during tracing
        # (since the checks are slow) without sacrificing UX, since inductor will warn
        # later if CUDA graphs cannot be enabled, e.g.,
        # https://github.com/pytorch/pytorch/blob/d3ba8901d8640eb16f88b2bfef9df7fa383d4b47/torch/_inductor/compile_fx.py#L390.
        # Thus, when compiling, inductor will determine if cudagraphs
        # can be enabled based on whether there is input mutation or CPU tensors.
        if not is_compiling() and torch.backends.cuda.is_built() and torch.cuda.is_available():
            capturing = torch.cuda.is_current_stream_capturing()

            if capturing and not all(group['capturable'] for group in self.param_groups):
                raise RuntimeError("Attempting CUDA graph capture of step() for an instance of " +
                                   self.__class__.__name__ +
                                   " but param_groups' capturable is False.")

            if (
                (not getattr(self, "_warned_capturable_if_run_uncaptured", False))
                and all(group['capturable'] for group in self.param_groups)
                and (not capturing)
            ):
                warnings.warn(
                    "This instance was constructed with capturable=True or some of the param_groups came with capturable=True, "
                    "but step() is running without CUDA graph capture. If you never intend to graph-capture this "
                    "instance, capturable=True can impair performance, and you should set capturable=False."
                )
                self._warned_capturable_if_run_uncaptured = True

    def _optimizer_step_code(self) -> None:
        """Entry point for `torch.profiler`.

        When python tracing is enabled the profiler will hook into this
        function at the CPython level to inspect the optimizer's parameters and
        param groups. It is called after `step()` since many optimizers
        lazily initialize state.

        This is a workaround due to the lack of a proper step hook on the optimizer,
        and will be removed once such a hook exists.
        """
        pass

    @staticmethod
    def profile_hook_step(func: Callable[_P, R]) -> Callable[_P, R]:
        @functools.wraps(func)  # preserve the wrapped step()'s name and docstring
        def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> R:
            self, *_ = args
            self = cast(Optimizer, self)
            profile_name = f"Optimizer.step#{self.__class__.__name__}.step"
            with torch.autograd.profiler.record_function(profile_name):
                # call optimizer step pre hooks
                for pre_hook in chain(_global_optimizer_pre_hooks.values(), self._optimizer_step_pre_hooks.values()):
                    result = pre_hook(self, args, kwargs)
                    if result is not None:
                        if isinstance(result, tuple) and len(result) == 2:
                            args, kwargs = result  # type: ignore[assignment]
                        else:
                            raise RuntimeError(
                                f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}."
                            )

                out = func(*args, **kwargs)
                self._optimizer_step_code()

                # call optimizer step post hooks
                for post_hook in chain(self._optimizer_step_post_hooks.values(), _global_optimizer_post_hooks.values()):
                    post_hook(self, args, kwargs)

            return out

        return wrapper

    @staticmethod
    def _group_tensors_by_device_and_dtype(
        tensorlistlist: TensorListList,
        with_indices: bool = False,
    ) -> Union[
        Dict[Tuple[None, None], Tuple[TensorListList, Indices]],
        Dict[Tuple[torch.device, torch.dtype], Tuple[TensorListList, Indices]],
    ]:
        """Groups a list of lists of tensors by device and dtype.

        Skips this step if we are compiling since this will occur during inductor lowering."""
        if is_compiling():
            return {(None, None): (tensorlistlist, list(range(len(tensorlistlist[0]))))}
        else:
            # Inside this staticmethod, the bare name resolves to the
            # module-level helper imported from torch.utils._foreach_utils,
            # not to this method.
            return _group_tensors_by_device_and_dtype(tensorlistlist, with_indices)

    def _patch_step_function(self) -> None:
        self._zero_grad_profile_name = f"Optimizer.zero_grad#{self.__class__.__name__}.zero_grad"
        hooked = getattr(self.__class__.step, "hooked", None)
        if not hooked:
            self.__class__.step = self.profile_hook_step(self.__class__.step)  # type: ignore[assignment]
            self.__class__.step.hooked = True  # type: ignore[attr-defined]

    def register_step_pre_hook(self, hook: OptimizerPreHook) -> RemovableHandle:
        r"""Register an optimizer step pre hook which will be called before
        optimizer step. It should have the following signature::

            hook(optimizer, args, kwargs) -> None or modified args and kwargs

        The ``optimizer`` argument is the optimizer instance being used. If
        args and kwargs are modified by the pre-hook, then the transformed
        values are returned as a tuple containing the new_args and new_kwargs.

        Args:
            hook (Callable): The user defined hook to be registered.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
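
        Example::

            >>> # A hedged sketch (``opt`` is any existing optimizer; the hook
            >>> # name is illustrative): inspect lrs right before each step.
            >>> def report_lrs(optimizer, args, kwargs):
            ...     # returning None leaves args/kwargs unchanged
            ...     print([g['lr'] for g in optimizer.param_groups])
            >>> handle = opt.register_step_pre_hook(report_lrs)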
""" | |
handle = hooks.RemovableHandle(self._optimizer_step_pre_hooks) | |
self._optimizer_step_pre_hooks[handle.id] = hook | |
return handle | |
    def register_step_post_hook(self, hook: OptimizerPostHook) -> RemovableHandle:
        r"""Register an optimizer step post hook which will be called after optimizer step.
        It should have the following signature::

            hook(optimizer, args, kwargs) -> None

        The ``optimizer`` argument is the optimizer instance being used.

        Args:
            hook (Callable): The user defined hook to be registered.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
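
        Example::

            >>> # A hedged sketch (``opt`` is any existing optimizer; the clamp
            >>> # range is illustrative): clamp all parameters after every step.
            >>> def clamp_params(optimizer, args, kwargs):
            ...     with torch.no_grad():
            ...         for group in optimizer.param_groups:
            ...             for p in group['params']:
            ...                 p.clamp_(-1.0, 1.0)
            >>> handle = opt.register_step_post_hook(clamp_params)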
""" | |
handle = hooks.RemovableHandle(self._optimizer_step_post_hooks) | |
self._optimizer_step_post_hooks[handle.id] = hook | |
return handle | |
    def register_state_dict_pre_hook(
        self, hook: Callable[["Optimizer"], None], prepend: bool = False
    ) -> RemovableHandle:
        r"""Register a state dict pre-hook which will be called before
        :meth:`~torch.optim.Optimizer.state_dict` is called. It should have the
        following signature::

            hook(optimizer) -> None

        The ``optimizer`` argument is the optimizer instance being used.
        The hook will be called with argument ``self`` before calling ``state_dict`` on ``self``.
        The registered hook can be used to perform pre-processing before the ``state_dict``
        call is made.

        Args:
            hook (Callable): The user defined hook to be registered.
            prepend (bool): If True, the provided pre ``hook`` will be fired before
                all the already registered pre-hooks on ``state_dict``. Otherwise,
                the provided ``hook`` will be fired after all the already registered
                pre-hooks. (default: False)

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
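
        Example::

            >>> # A hedged sketch (the ``tag`` field is illustrative): make sure
            >>> # derived metadata is present in each group before serialization.
            >>> def ensure_tag(optimizer):
            ...     for group in optimizer.param_groups:
            ...         group.setdefault('tag', 'default')
            >>> handle = opt.register_state_dict_pre_hook(ensure_tag)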
""" | |
handle = hooks.RemovableHandle(self._optimizer_state_dict_pre_hooks) | |
self._optimizer_state_dict_pre_hooks[handle.id] = hook | |
if prepend: | |
self._optimizer_state_dict_pre_hooks.move_to_end(handle.id, last=False) | |
return handle | |
    def register_state_dict_post_hook(
        self,
        hook: Callable[["Optimizer", StateDict], Optional[StateDict]],
        prepend: bool = False,
    ) -> RemovableHandle:
        r"""Register a state dict post-hook which will be called after
        :meth:`~torch.optim.Optimizer.state_dict` is called. It should have the
        following signature::

            hook(optimizer, state_dict) -> state_dict or None

        The hook will be called with arguments ``self`` and ``state_dict`` after generating
        a ``state_dict`` on ``self``. The hook may modify the state_dict in place or optionally
        return a new one. The registered hook can be used to perform post-processing
        on the ``state_dict`` before it is returned.

        Args:
            hook (Callable): The user defined hook to be registered.
            prepend (bool): If True, the provided post ``hook`` will be fired before
                all the already registered post-hooks on ``state_dict``. Otherwise,
                the provided ``hook`` will be fired after all the already registered
                post-hooks. (default: False)

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
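
        Example::

            >>> # A hedged sketch (the key name is illustrative): stamp the
            >>> # returned state dict with a schema version.
            >>> def stamp_version(optimizer, state_dict):
            ...     state_dict['version'] = 2
            ...     return state_dict
            >>> handle = opt.register_state_dict_post_hook(stamp_version)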
""" | |
handle = hooks.RemovableHandle(self._optimizer_state_dict_post_hooks) | |
self._optimizer_state_dict_post_hooks[handle.id] = hook | |
if prepend: | |
self._optimizer_state_dict_post_hooks.move_to_end(handle.id, last=False) | |
return handle | |
    def state_dict(self) -> StateDict:
        r"""Returns the state of the optimizer as a :class:`dict`.

        It contains two entries:

        * ``state``: a Dict holding current optimization state. Its content
            differs between optimizer classes, but some common characteristics
            hold. For example, state is saved per parameter, and the parameter
            itself is NOT saved. ``state`` is a Dictionary mapping parameter ids
            to a Dict with state corresponding to each parameter.
        * ``param_groups``: a List containing all parameter groups where each
            parameter group is a Dict. Each parameter group contains metadata
            specific to the optimizer, such as learning rate and weight decay,
            as well as a List of parameter IDs of the parameters in the group.

        NOTE: The parameter IDs may look like indices but they are just IDs
        associating state with param_group. When loading from a state_dict,
        the optimizer will zip the param_group ``params`` (int IDs) and the
        optimizer ``param_groups`` (actual ``nn.Parameter`` s) in order to
        match state WITHOUT additional verification.

        A returned state dict might look something like:

        .. code-block:: text

            {
                'state': {
                    0: {'momentum_buffer': tensor(...), ...},
                    1: {'momentum_buffer': tensor(...), ...},
                    2: {'momentum_buffer': tensor(...), ...},
                    3: {'momentum_buffer': tensor(...), ...}
                },
                'param_groups': [
                    {
                        'lr': 0.01,
                        'weight_decay': 0,
                        ...
                        'params': [0]
                    },
                    {
                        'lr': 0.001,
                        'weight_decay': 0.5,
                        ...
                        'params': [1, 2, 3]
                    }
                ]
            }

        """
        for pre_hook in self._optimizer_state_dict_pre_hooks.values():
            pre_hook(self)

        # Save order indices instead of Tensors
        param_mappings: Dict[int, int] = {}
        start_index = 0

        def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:
            nonlocal start_index
            packed = {k: v for k, v in group.items() if k != 'params'}
            param_mappings.update({id(p): i for i, p in enumerate(group['params'], start_index)
                                   if id(p) not in param_mappings})
            packed['params'] = [param_mappings[id(p)] for p in group['params']]
            start_index += len(packed['params'])
            return packed

        param_groups = [pack_group(g) for g in self.param_groups]
        # Remap state to use order indices as keys
        packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v
                        for k, v in self.state.items()}

        state_dict = {
            'state': packed_state,
            'param_groups': param_groups,
        }

        for post_hook in self._optimizer_state_dict_post_hooks.values():
            hook_result = post_hook(self, state_dict)
            if hook_result is not None:
                state_dict = hook_result
        return state_dict

    @staticmethod
    def _process_value_according_to_param_policy(
        param: torch.Tensor,
        value: torch.Tensor,
        param_id: int,
        param_groups: List[Dict[Any, Any]],
        key: Hashable = None,
    ) -> torch.Tensor:
        # Floating-point types are a bit special here. They are the only ones
        # that are assumed to always match the type of params.
        # Make sure state['step'] is not cast, https://github.com/pytorch/pytorch/issues/74424
        # UNLESS fused or capturable, see note [special device hosting for step]
        fused = False
        capturable = False
        assert param_groups is not None
        for pg in param_groups:
            if param_id in pg["params"]:
                fused = pg["fused"] if "fused" in pg else False
                capturable = pg["capturable"] if "capturable" in pg else False
                break

        if key == "step":
            if capturable or fused:
                return value.to(dtype=torch.float32, device=param.device)
            else:
                return value
        else:
            if param.is_floating_point():
                return value.to(dtype=param.dtype, device=param.device)
            else:
                return value.to(device=param.device)

    def register_load_state_dict_pre_hook(
        self,
        hook: Callable[["Optimizer", StateDict], Optional[StateDict]],
        prepend: bool = False,
    ) -> RemovableHandle:
        r"""Register a load_state_dict pre-hook which will be called before
        :meth:`~torch.optim.Optimizer.load_state_dict` is called. It should have the
        following signature::

            hook(optimizer, state_dict) -> state_dict or None

        The ``optimizer`` argument is the optimizer instance being used and the
        ``state_dict`` argument is a shallow copy of the ``state_dict`` the user
        passed in to ``load_state_dict``. The hook may modify the state_dict in place
        or optionally return a new one. If a state_dict is returned, it will be
        loaded into the optimizer instead.

        The hook will be called with argument ``self`` and ``state_dict`` before
        calling ``load_state_dict`` on ``self``. The registered hook can be used to
        perform pre-processing before the ``load_state_dict`` call is made.

        Args:
            hook (Callable): The user defined hook to be registered.
            prepend (bool): If True, the provided pre ``hook`` will be fired before
                all the already registered pre-hooks on ``load_state_dict``. Otherwise,
                the provided ``hook`` will be fired after all the already registered
                pre-hooks. (default: False)

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
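
        Example::

            >>> # A hedged sketch (the old/new key names are illustrative):
            >>> # migrate a renamed hyperparameter before the load happens.
            >>> def migrate_keys(optimizer, state_dict):
            ...     for group in state_dict['param_groups']:
            ...         if 'old_lr' in group:
            ...             group['lr'] = group.pop('old_lr')
            ...     return state_dict
            >>> handle = opt.register_load_state_dict_pre_hook(migrate_keys)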
""" | |
handle = hooks.RemovableHandle(self._optimizer_load_state_dict_pre_hooks) | |
self._optimizer_load_state_dict_pre_hooks[handle.id] = hook | |
if prepend: | |
self._optimizer_load_state_dict_pre_hooks.move_to_end(handle.id, last=False) | |
return handle | |
    def register_load_state_dict_post_hook(
        self, hook: Callable[["Optimizer"], None], prepend: bool = False
    ) -> RemovableHandle:
        r"""Register a load_state_dict post-hook which will be called after
        :meth:`~torch.optim.Optimizer.load_state_dict` is called. It should have the
        following signature::

            hook(optimizer) -> None

        The ``optimizer`` argument is the optimizer instance being used.

        The hook will be called with argument ``self`` after calling
        ``load_state_dict`` on ``self``. The registered hook can be used to
        perform post-processing after ``load_state_dict`` has loaded the
        ``state_dict``.

        Args:
            hook (Callable): The user defined hook to be registered.
            prepend (bool): If True, the provided post ``hook`` will be fired before
                all the already registered post-hooks on ``load_state_dict``. Otherwise,
                the provided ``hook`` will be fired after all the already registered
                post-hooks. (default: False)

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
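
        Example::

            >>> # A hedged sketch: sanity-check the optimizer once the load
            >>> # completes.
            >>> def check_loaded(optimizer):
            ...     assert all('lr' in g for g in optimizer.param_groups)
            >>> handle = opt.register_load_state_dict_post_hook(check_loaded)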
""" | |
handle = hooks.RemovableHandle(self._optimizer_load_state_dict_post_hooks) | |
self._optimizer_load_state_dict_post_hooks[handle.id] = hook | |
if prepend: | |
self._optimizer_load_state_dict_post_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined] | |
return handle | |
    def load_state_dict(self, state_dict: StateDict) -> None:
        r"""Loads the optimizer state.

        Args:
            state_dict (dict): optimizer state. Should be an object returned
                from a call to :meth:`state_dict`.
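
        Example::

            >>> # A hedged round-trip sketch: hyperparameters come back from the
            >>> # saved dict, while the Tensors being optimized stay the new ones.
            >>> opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.1)
            >>> saved = opt.state_dict()
            >>> opt2 = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.5)
            >>> opt2.load_state_dict(saved)
            >>> opt2.param_groups[0]['lr']
            0.1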
""" | |
# shallow copy, to be consistent with module API | |
state_dict = state_dict.copy() | |
for pre_hook in self._optimizer_load_state_dict_pre_hooks.values(): | |
hook_result = pre_hook(self, state_dict) | |
if hook_result is not None: | |
state_dict = hook_result | |
# Validate the state_dict | |
groups = self.param_groups | |
# Deepcopy as we write into saved_groups later to update state | |
saved_groups = deepcopy(state_dict['param_groups']) | |
if len(groups) != len(saved_groups): | |
raise ValueError("loaded state dict has a different number of " | |
"parameter groups") | |
param_lens = (len(g['params']) for g in groups) | |
saved_lens = (len(g['params']) for g in saved_groups) | |
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): | |
raise ValueError("loaded state dict contains a parameter group " | |
"that doesn't match the size of optimizer's group") | |
# Update the state | |
id_map = dict(zip(chain.from_iterable(g['params'] for g in saved_groups), | |
chain.from_iterable(g['params'] for g in groups))) | |
def _cast(param, value, param_id=None, param_groups=None, key=None): | |
r"""Make a deep copy of value, casting all tensors to device of param.""" | |
if isinstance(value, torch.Tensor): | |
return Optimizer._process_value_according_to_param_policy(param, value, param_id, param_groups, key) | |
elif isinstance(value, dict): | |
return {k: _cast(param, v, param_id=param_id, param_groups=param_groups, key=k) for k, v in value.items()} | |
elif isinstance(value, Iterable): | |
return type(value)(_cast(param, v, param_id=param_id, param_groups=param_groups) for v in value) # type: ignore[call-arg] | |
else: | |
return value | |
# Copy state assigned to params (and cast tensors to appropriate types). | |
# State that is not assigned to params is copied as is (needed for | |
# backward compatibility). | |
state: DefaultDict[torch.Tensor, Dict[Any, Any]] = defaultdict(dict) | |
for k, v in state_dict['state'].items(): | |
if k in id_map: | |
param = id_map[k] | |
state[param] = _cast(param, v, param_id=k, param_groups=state_dict['param_groups']) | |
else: | |
state[k] = v | |
# Update parameter groups, setting their 'params' value | |
def update_group(group: Dict[str, Any], new_group: Dict[str, Any]) -> Dict[str, Any]: | |
new_group['params'] = group['params'] | |
return new_group | |
param_groups = [ | |
update_group(g, ng) for g, ng in zip(groups, saved_groups)] | |
self.__setstate__({'state': state, 'param_groups': param_groups}) | |
for post_hook in self._optimizer_load_state_dict_post_hooks.values(): | |
post_hook(self) | |
    def zero_grad(self, set_to_none: bool = True) -> None:
        r"""Resets the gradients of all optimized :class:`torch.Tensor` s.

        Args:
            set_to_none (bool): instead of setting to zero, set the grads to None.
                This will in general have lower memory footprint, and can modestly improve performance.
                However, it changes certain behaviors. For example:
                1. When the user tries to access a gradient and perform manual ops on it,
                a None attribute or a Tensor full of 0s will behave differently.
                2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ s
                are guaranteed to be None for params that did not receive a gradient.
                3. ``torch.optim`` optimizers have a different behavior if the gradient is 0 or None
                (in one case it does the step with a gradient of 0 and in the other it skips
                the step altogether).
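
        Example::

            >>> # A hedged sketch of the two modes: after backward, the grad is
            >>> # either zeroed in place or dropped entirely.
            >>> p = torch.zeros(1, requires_grad=True)
            >>> opt = torch.optim.SGD([p], lr=0.1)
            >>> p.sum().backward()
            >>> opt.zero_grad(set_to_none=False)
            >>> p.grad
            tensor([0.])
            >>> opt.zero_grad()  # default: set_to_none=True
            >>> p.grad is None
            True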
""" | |
foreach = self.defaults.get('foreach', False) or self.defaults.get('fused', False) | |
if not hasattr(self, "_zero_grad_profile_name"): | |
self._patch_step_function() | |
per_device_and_dtype_grads: Optional[DefaultDict[torch.device, DefaultDict[torch.dtype, List[torch.Tensor]]]] | |
if foreach: | |
per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) | |
else: | |
per_device_and_dtype_grads = None | |
with torch.autograd.profiler.record_function(self._zero_grad_profile_name): | |
for group in self.param_groups: | |
for p in group['params']: | |
if p.grad is not None: | |
if set_to_none: | |
p.grad = None | |
else: | |
if p.grad.grad_fn is not None: | |
p.grad.detach_() | |
else: | |
p.grad.requires_grad_(False) | |
if (not foreach or p.grad.is_sparse): | |
p.grad.zero_() | |
else: | |
assert per_device_and_dtype_grads is not None | |
per_device_and_dtype_grads[p.grad.device][p.grad.dtype].append(p.grad) | |
if foreach: | |
assert per_device_and_dtype_grads is not None | |
for per_dtype_grads in per_device_and_dtype_grads.values(): | |
for grads in per_dtype_grads.values(): | |
torch._foreach_zero_(grads) | |
    @overload
    def step(self, closure: None = ...) -> None:
        ...

    @overload
    def step(self, closure: Callable[[], float]) -> float:
        ...

    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        r"""Performs a single optimization step (parameter update).

        Args:
            closure (Callable): A closure that reevaluates the model and
                returns the loss. Optional for most optimizers.

        .. note::
            Unless otherwise specified, this function should not modify the
            ``.grad`` field of the parameters.
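
        Example::

            >>> # A hedged sketch of the closure form (``opt`` is a concrete
            >>> # optimizer subclass instance and ``p`` one of its parameters;
            >>> # the loss is illustrative):
            >>> def closure():
            ...     opt.zero_grad()
            ...     loss = (p ** 2).sum()
            ...     loss.backward()
            ...     return loss
            >>> loss = opt.step(closure)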
""" | |
raise NotImplementedError | |
    def add_param_group(self, param_group: Dict[str, Any]) -> None:
        r"""Add a param group to the :class:`Optimizer` s ``param_groups``.

        This can be useful when fine-tuning a pre-trained network, as frozen layers can be made
        trainable and added to the :class:`Optimizer` as training progresses.

        Args:
            param_group (dict): Specifies what Tensors should be optimized along with group
                specific optimization options.
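
        Example::

            >>> # A hedged fine-tuning sketch (a ``model`` with ``head`` and
            >>> # ``backbone`` submodules is illustrative): unfreeze a layer
            >>> # mid-training and give it its own lr.
            >>> opt = torch.optim.SGD(model.head.parameters(), lr=0.1)
            >>> # ... later, once the backbone is unfrozen ...
            >>> opt.add_param_group({'params': model.backbone.parameters(), 'lr': 0.01})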
""" | |
if not isinstance(param_group, dict): | |
raise TypeError(f"param_group must be a dict, but got {type(param_group)}") | |
params = param_group['params'] | |
if isinstance(params, torch.Tensor): | |
param_group['params'] = [params] | |
elif isinstance(params, set): | |
raise TypeError('optimizer parameters need to be organized in ordered collections, but ' | |
'the ordering of tensors in sets will change between runs. Please use a list instead.') | |
else: | |
param_group['params'] = list(params) | |
for param in param_group['params']: | |
if not isinstance(param, torch.Tensor): | |
raise TypeError("optimizer can only optimize Tensors, " | |
"but one of the params is " + torch.typename(param)) | |
if not self.defaults.get('differentiable', None) and not (param.is_leaf or param.retains_grad): | |
raise ValueError("can't optimize a non-leaf Tensor") | |
for name, default in self.defaults.items(): | |
if default is required and name not in param_group: | |
raise ValueError(f"parameter group didn't specify a value of required optimization parameter {name}") | |
else: | |
param_group.setdefault(name, default) | |
params = param_group['params'] | |
if len(params) != len(set(params)): | |
warnings.warn("optimizer contains a parameter group with duplicate parameters; " | |
"in future, this will cause an error; " | |
"see github.com/pytorch/pytorch/issues/40967 for more information", stacklevel=3) | |
param_set: Set[torch.Tensor] = set() | |
for group in self.param_groups: | |
param_set.update(set(group['params'])) | |
if not param_set.isdisjoint(set(param_group['params'])): | |
raise ValueError("some parameters appear in more than one parameter group") | |
self.param_groups.append(param_group) | |