from __future__ import annotations

import inspect
import warnings
from collections import abc, defaultdict
from enum import Enum
from typing import Any, cast, Dict, Iterable, List, Optional, overload, Tuple, Union

import torch

__all__ = ["OptState", "GradScaler"]

class _MultiDeviceReplicator:
    """Lazily serves copies of a tensor to requested devices.

    Copies are cached per-device.
    """

    def __init__(self, master_tensor: torch.Tensor) -> None:
        self.master = master_tensor
        self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}

    def get(self, device: torch.device) -> torch.Tensor:
        retval = self._per_device_tensors.get(device, None)
        if retval is None:
            retval = self.master.to(device=device, non_blocking=True, copy=True)
            self._per_device_tensors[device] = retval
        return retval
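
# Illustrative use of _MultiDeviceReplicator (a sketch, not part of the public API;
# the device below is an assumption): the first .get() for a device copies the master
# tensor there, and subsequent calls return the cached copy.
#
#     scale = torch.full((), 2.0**16)
#     rep = _MultiDeviceReplicator(scale)
#     a = rep.get(torch.device("cpu"))  # copies the master tensor
#     b = rep.get(torch.device("cpu"))  # cached: `a is b` holds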

# Defines default_factory for GradScaler's _per_optimizer_states defaultdict,
# as well as associated "enum" values. Prefers defining these at top level because
# - Lambdas can't be pickled, so we don't want to supply a lambda as the factory.
# - Defining READY, UNSCALED, STEPPED and _refresh_per_optimizer_state within GradScaler
#   causes a circular reference, which we'd rather avoid.
class OptState(Enum):
    READY = 0
    UNSCALED = 1
    STEPPED = 2


def _refresh_per_optimizer_state() -> Dict[str, Any]:
    return {"stage": OptState.READY, "found_inf_per_device": {}}

class GradScaler:
    """An instance ``scaler`` of :class:`GradScaler` helps perform the steps of
    gradient scaling conveniently.

    * ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor.
    * ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``.
    * ``scaler.update()`` updates ``scaler``'s scale factor.

    Example::

        # Creates a GradScaler once at the beginning of training.
        scaler = GradScaler()

        for epoch in epochs:
            for input, target in data:
                optimizer.zero_grad()
                output = model(input)
                loss = loss_fn(output, target)

                # Scales loss. Calls backward() on scaled loss to create scaled gradients.
                scaler.scale(loss).backward()

                # scaler.step() first unscales gradients of the optimizer's params.
                # If gradients don't contain infs/NaNs, optimizer.step() is then called,
                # otherwise, optimizer.step() is skipped.
                scaler.step(optimizer)

                # Updates the scale for next iteration.
                scaler.update()

    See the :ref:`Automatic Mixed Precision examples<amp-examples>` for usage
    (along with autocasting) in more complex cases like gradient clipping, gradient accumulation,
    gradient penalty, and multiple losses/optimizers.

    ``scaler`` dynamically estimates the scale factor each iteration. To minimize gradient underflow,
    a large scale factor should be used. However, ``float16`` values can "overflow" (become inf or NaN) if
    the scale factor is too large. Therefore, the optimal scale factor is the largest factor that can be used
    without incurring inf or NaN gradient values.
    ``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every
    ``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`).

    * If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params
      themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``.

    * If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual.
      If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by
      ``growth_factor``.

    The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its
    value calibrates. ``scaler.step`` will skip the underlying ``optimizer.step()`` for these
    iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations).
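
    As a rough sketch (illustrative only; the actual update happens in a fused kernel,
    ``torch._amp_update_scale_``), each call to :meth:`update` behaves like::

        if found_inf:
            scale *= backoff_factor
            growth_tracker = 0
        else:
            growth_tracker += 1
            if growth_tracker == growth_interval:
                scale *= growth_factor
                growth_tracker = 0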

    Args:
        device (str, optional, default="cuda"): Device type to use. Possible values are: 'cuda' and 'cpu'.
            The type is the same as the `type` attribute of a :class:`torch.device`.
            Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
        init_scale (float, optional, default=2.**16): Initial scale factor.
        growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during
            :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations.
        backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during
            :meth:`update` if inf/NaN gradients occur in an iteration.
        growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients
            that must occur for the scale to be multiplied by ``growth_factor``.
        enabled (bool, optional): If ``False``, disables gradient scaling. :meth:`step` simply
            invokes the underlying ``optimizer.step()``, and other methods become no-ops.
            Default: ``True``
    """

    def __init__(
        self,
        device: str = "cuda",
        init_scale: float = 2.0**16,
        growth_factor: float = 2.0,
        backoff_factor: float = 0.5,
        growth_interval: int = 2000,
        enabled: bool = True,
    ) -> None:
        self._device = device
        self._enabled = enabled
        if self._device == "cuda":
            if enabled and torch.cuda.amp.common.amp_definitely_not_available():
                warnings.warn(
                    "torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling."
                )
                self._enabled = False

        if self._enabled:
            assert growth_factor > 1.0, "The growth factor must be > 1.0."
            assert backoff_factor < 1.0, "The backoff factor must be < 1.0."

            self._init_scale = init_scale
            # self._scale will be lazily initialized during the first call to scale()
            self._scale: Optional[torch.Tensor] = None
            self._growth_factor = growth_factor
            self._backoff_factor = backoff_factor
            self._growth_interval = growth_interval
            self._init_growth_tracker = 0
            # self._growth_tracker will be lazily initialized during the first call to scale()
            self._growth_tracker: Optional[torch.Tensor] = None
            self._per_optimizer_states: Dict[int, Dict[str, Any]] = defaultdict(
                _refresh_per_optimizer_state
            )

    def _check_scale_growth_tracker(
        self, funcname: str
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration."
        assert self._scale is not None, (
            f"Attempted {funcname} but _scale is None. " + fix
        )
        assert self._growth_tracker is not None, (
            f"Attempted {funcname} but _growth_tracker is None. " + fix
        )
        return (self._scale, self._growth_tracker)

    def _lazy_init_scale_growth_tracker(self, dev: torch.device) -> None:
        assert self._growth_tracker is None, "_growth_tracker initialized before _scale"
        self._scale = torch.full((), self._init_scale, dtype=torch.float32, device=dev)
        self._growth_tracker = torch.full(
            (), self._init_growth_tracker, dtype=torch.int32, device=dev
        )

    @overload
    def scale(self, outputs: torch.Tensor) -> torch.Tensor:
        ...

    @overload
    def scale(self, outputs: List[torch.Tensor]) -> List[torch.Tensor]:
        ...

    @overload
    def scale(self, outputs: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
        ...

    @overload
    def scale(self, outputs: Iterable[torch.Tensor]) -> Iterable[torch.Tensor]:
        ...

    def scale(
        self,
        outputs: Union[torch.Tensor, Iterable[torch.Tensor]],
    ) -> Union[torch.Tensor, Iterable[torch.Tensor]]:
        """
        Multiplies ('scales') a tensor or list of tensors by the scale factor.

        Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned
        unmodified.

        Args:
            outputs (Tensor or iterable of Tensors): Outputs to scale.
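
        Example scaling multiple losses at once (a sketch; ``loss0`` and ``loss1`` stand in
        for user-computed losses)::

            scaled0, scaled1 = scaler.scale((loss0, loss1))
            scaled0.backward(retain_graph=True)
            scaled1.backward()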
""" | |
if not self._enabled: | |
return outputs | |
# Short-circuit for the common case. | |
if isinstance(outputs, torch.Tensor): | |
if self._scale is None: | |
self._lazy_init_scale_growth_tracker(outputs.device) | |
assert self._scale is not None | |
return outputs * self._scale.to(device=outputs.device, non_blocking=True) | |
# Invoke the more complex machinery only if we're treating multiple outputs. | |
stash: List[ | |
_MultiDeviceReplicator | |
] = [] # holds a reference that can be overwritten by apply_scale | |
def apply_scale(val: Union[torch.Tensor, Iterable[torch.Tensor]]): | |
if isinstance(val, torch.Tensor): | |
if len(stash) == 0: | |
if self._scale is None: | |
self._lazy_init_scale_growth_tracker(val.device) | |
assert self._scale is not None | |
stash.append(_MultiDeviceReplicator(self._scale)) | |
return val * stash[0].get(val.device) | |
if isinstance(val, abc.Iterable): | |
iterable = map(apply_scale, val) | |
if isinstance(val, (list, tuple)): | |
return type(val)(iterable) | |
return iterable | |
raise ValueError("outputs must be a Tensor or an iterable of Tensors") | |
return apply_scale(outputs) | |

    def _unscale_grads_(
        self,
        optimizer: torch.optim.Optimizer,
        inv_scale: torch.Tensor,
        found_inf: torch.Tensor,
        allow_fp16: bool,
    ) -> Dict[torch.device, torch.Tensor]:
        per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
        per_device_found_inf = _MultiDeviceReplicator(found_inf)

        # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
        # There could be hundreds of grads, so we'd like to iterate through them just once.
        # However, we don't know their devices or dtypes in advance.
        # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
        # mypy struggles with defaultdict type annotations.
        per_device_and_dtype_grads: Dict[
            torch.device, Dict[torch.dtype, List[torch.Tensor]]
        ] = defaultdict(lambda: defaultdict(list))
        with torch.no_grad():
            for group in optimizer.param_groups:
                for param in group["params"]:
                    assert isinstance(param, torch.Tensor)
                    if param.grad is None:
                        continue
                    if (not allow_fp16) and param.grad.dtype == torch.float16:
                        raise ValueError("Attempting to unscale FP16 gradients.")
                    if param.grad.is_sparse:
                        # is_coalesced() == False means the sparse grad has values with duplicate indices.
                        # coalesce() deduplicates indices and adds all values that have the same index.
                        # For scaled fp16 values, there's a good chance coalescing will cause overflow,
                        # so we should check the coalesced _values().
                        if param.grad.dtype is torch.float16:
                            param.grad = param.grad.coalesce()
                        to_unscale = param.grad._values()
                    else:
                        to_unscale = param.grad

                    # TODO: is there a way to split by device and dtype without appending in the inner loop?
                    per_device_and_dtype_grads[to_unscale.device][
                        to_unscale.dtype
                    ].append(to_unscale)

            for device, per_dtype_grads in per_device_and_dtype_grads.items():
                for grads in per_dtype_grads.values():
                    torch._amp_foreach_non_finite_check_and_unscale_(
                        grads,
                        per_device_found_inf.get(device),
                        per_device_inv_scale.get(device),
                    )

        return per_device_found_inf._per_device_tensors

    def unscale_(self, optimizer: torch.optim.Optimizer) -> None:
        """
        Divides ("unscales") the optimizer's gradient tensors by the scale factor.

        :meth:`unscale_` is optional, serving cases where you need to
        :ref:`modify or inspect gradients<working-with-unscaled-gradients>`
        between the backward pass(es) and :meth:`step`.
        If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.

        Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::

            ...
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            scaler.step(optimizer)
            scaler.update()

        Args:
            optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.

        .. note::
            :meth:`unscale_` does not incur a CPU-GPU sync.

        .. warning::
            :meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
            and only after all gradients for that optimizer's assigned parameters have been accumulated.
            Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.

        .. warning::
            :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
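
        Conceptually, unscaling multiplies every gradient by ``1 / scale``. A simplified
        sketch of the effect (ignoring the fused inf/NaN check, sparse gradients, and
        multi-device handling that the real implementation performs)::

            inv_scale = 1.0 / scaler.get_scale()
            for group in optimizer.param_groups:
                for p in group["params"]:
                    if p.grad is not None:
                        p.grad.mul_(inv_scale)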
""" | |
if not self._enabled: | |
return | |
self._check_scale_growth_tracker("unscale_") | |
optimizer_state = self._per_optimizer_states[id(optimizer)] | |
if optimizer_state["stage"] is OptState.UNSCALED: | |
raise RuntimeError( | |
"unscale_() has already been called on this optimizer since the last update()." | |
) | |
elif optimizer_state["stage"] is OptState.STEPPED: | |
raise RuntimeError("unscale_() is being called after step().") | |
# FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64. | |
assert self._scale is not None | |
inv_scale = self._scale.double().reciprocal().float() | |
found_inf = torch.full((), 0.0, dtype=torch.float32, device=self._scale.device) | |
optimizer_state["found_inf_per_device"] = self._unscale_grads_( | |
optimizer, inv_scale, found_inf, False | |
) | |
optimizer_state["stage"] = OptState.UNSCALED | |

    def _maybe_opt_step(
        self,
        optimizer: torch.optim.Optimizer,
        optimizer_state: Dict[str, Any],
        *args: Any,
        **kwargs: Any,
    ) -> Optional[float]:
        retval: Optional[float] = None
        if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()):
            retval = optimizer.step(*args, **kwargs)
        return retval

    def step(
        self, optimizer: torch.optim.Optimizer, *args: Any, **kwargs: Any
    ) -> Optional[float]:
        """Invoke ``unscale_(optimizer)`` followed by parameter update, if gradients are not infs/NaN.

        :meth:`step` carries out the following two operations:

        1. Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer``
           earlier in the iteration). As part of the :meth:`unscale_`, gradients are checked for infs/NaNs.
        2. If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled
           gradients. Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params.

        ``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``.

        Returns the return value of ``optimizer.step(*args, **kwargs)``.

        Args:
            optimizer (torch.optim.Optimizer): Optimizer that applies the gradients.
            args: Any arguments.
            kwargs: Any keyword arguments.

        .. warning::
            Closure use is not currently supported.
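
        A common pattern for detecting whether a step was skipped is to compare the scale
        before and after :meth:`update` (a sketch, not part of the API contract; it relies
        on the scale only shrinking when a step was skipped)::

            scale_before = scaler.get_scale()
            scaler.step(optimizer)
            scaler.update()
            step_was_skipped = scaler.get_scale() < scale_before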
""" | |
if not self._enabled: | |
return optimizer.step(*args, **kwargs) | |
if "closure" in kwargs: | |
raise RuntimeError( | |
"Closure use is not currently supported if GradScaler is enabled." | |
) | |
self._check_scale_growth_tracker("step") | |
optimizer_state = self._per_optimizer_states[id(optimizer)] | |
if optimizer_state["stage"] is OptState.STEPPED: | |
raise RuntimeError( | |
"step() has already been called since the last update()." | |
) | |
retval: Optional[float] = None | |
if getattr(optimizer, "_step_supports_amp_scaling", False): | |
# This optimizer has customized scale-handling logic, so we can call optimizer.step() directly. | |
# The contract with custom optimizers is that their step() should accept an additional, | |
# optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information: | |
# it can query its own state, invoke unscale_ on itself, etc | |
# The contract above is being deprecated to avoid introducing `grad_scaler: GradScaler` argument | |
# to `Optimizer.step`. The new behavior is going to add two Tensor attributes of `grad_scale` | |
# and `found_inf` to the passed optimizer so that the optimizer can utilize those | |
# to skip the parameter updates or unscale gradients before updating parameters in | |
# the fused kernel, e.g. `FusedAdamMathFunctor`. | |
# In this behavior, `GradScaler._check_inf_per_device` is called if `OptState.READY`, | |
# while the method is expected to be called by users side, i.e. their optimizers. | |
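            # For illustration only (a hypothetical optimizer, not defined in this file),
            # an optimizer opting into this protocol looks roughly like:
            #
            #     class MyFusedOptimizer(torch.optim.Optimizer):
            #         _step_supports_amp_scaling = True
            #
            #         def step(self, closure=None):
            #             # self.grad_scale and self.found_inf are set by GradScaler.step();
            #             # the fused kernel unscales grads and skips the update when
            #             # found_inf is nonzero.
            #             ...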
            kwargs_ = kwargs
            has_grad_scaler_kwarg = (
                "grad_scaler" in inspect.signature(optimizer.step).parameters
            )
            if has_grad_scaler_kwarg:
                warnings.warn(
                    "GradScaler is going to stop passing itself as a keyword argument to the passed "
                    "optimizer. In the near future GradScaler will register `grad_scale: Tensor` and "
                    "`found_inf: Tensor` on the passed optimizer and let the optimizer use them directly.",
                    FutureWarning,
                )
                kwargs_.update({"grad_scaler": self})
            else:
                if optimizer_state["stage"] is OptState.READY:
                    self._check_inf_per_device(optimizer)
                scaler = self._get_scale_async()
                assert scaler is not None
                found_inf = cast(
                    torch.Tensor,
                    sum(
                        [
                            t.to(scaler.device, non_blocking=True)
                            for t in optimizer_state["found_inf_per_device"].values()
                        ]
                    ),
                )
                optimizer.grad_scale = (  # type: ignore[attr-defined]
                    None if optimizer_state["stage"] == OptState.UNSCALED else scaler
                )
                optimizer.found_inf = found_inf  # type: ignore[attr-defined]
            retval = optimizer.step(*args, **kwargs_)
            optimizer_state["stage"] = OptState.STEPPED
            if not has_grad_scaler_kwarg:
                del optimizer.grad_scale  # type: ignore[attr-defined]
                del optimizer.found_inf  # type: ignore[attr-defined]
            return retval

        if optimizer_state["stage"] is OptState.READY:
            self.unscale_(optimizer)

        assert (
            len(optimizer_state["found_inf_per_device"]) > 0
        ), "No inf checks were recorded for this optimizer."

        retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs)

        optimizer_state["stage"] = OptState.STEPPED

        return retval

    def update(self, new_scale: Optional[Union[float, torch.Tensor]] = None) -> None:
        """Update the scale factor.

        If any optimizer steps were skipped, the scale is multiplied by ``backoff_factor``
        to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
        the scale is multiplied by ``growth_factor`` to increase it.

        Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
        used directly; it's used to fill GradScaler's internal scale tensor. So if
        ``new_scale`` was a tensor, later in-place changes to that tensor will not further
        affect the scale GradScaler uses internally.)

        Args:
            new_scale (float or :class:`torch.Tensor`, optional, default=None): New scale factor.

        .. warning::
            :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
            been invoked for all optimizers used this iteration.

        .. warning::
            For performance reasons, we do not check the scale factor value to avoid synchronizations,
            so the scale factor is not guaranteed to be above 1. If the scale falls below 1 and/or
            you are seeing NaNs in your gradients or loss, something is likely wrong. For example,
            bf16-pretrained models are often incompatible with AMP/fp16 due to differing dynamic ranges.
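
        With multiple optimizers, call :meth:`step` on each of them and :meth:`update` once,
        at the end of the iteration (a sketch; ``optimizer0`` and ``optimizer1`` are assumed
        to exist in the surrounding training script)::

            scaler.scale(loss).backward()
            scaler.step(optimizer0)
            scaler.step(optimizer1)
            scaler.update()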
""" | |
if not self._enabled: | |
return | |
_scale, _growth_tracker = self._check_scale_growth_tracker("update") | |
if new_scale is not None: | |
assert self._scale is not None | |
# Accept a new user-defined scale. | |
if isinstance(new_scale, float): | |
self._scale.fill_(new_scale) | |
else: | |
reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor or \ | |
torch.FloatTensor with requires_grad=False." | |
assert new_scale.device.type == self._device, reason | |
assert new_scale.numel() == 1, reason | |
assert new_scale.requires_grad is False, reason | |
self._scale.copy_(new_scale) | |
else: | |
# Consume shared inf/nan data collected from optimizers to update the scale. | |
# If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. | |
found_infs = [ | |
found_inf.to(device=_scale.device, non_blocking=True) | |
for state in self._per_optimizer_states.values() | |
for found_inf in state["found_inf_per_device"].values() | |
] | |
assert len(found_infs) > 0, "No inf checks were recorded prior to update." | |
found_inf_combined = found_infs[0] | |
if len(found_infs) > 1: | |
for i in range(1, len(found_infs)): | |
found_inf_combined += found_infs[i] | |
torch._amp_update_scale_( | |
_scale, | |
_growth_tracker, | |
found_inf_combined, | |
self._growth_factor, | |
self._backoff_factor, | |
self._growth_interval, | |
) | |
# To prepare for next iteration, clear the data collected from optimizers this iteration. | |
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) | |

    def _get_scale_async(self) -> Optional[torch.Tensor]:
        return self._scale

    def get_scale(self) -> float:
        """Return a Python float containing the current scale, or 1.0 if scaling is disabled.

        .. warning::
            :meth:`get_scale` incurs a CPU-GPU sync.
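
        Because of the sync, a typical pattern is to read the scale only occasionally, e.g.
        for logging (a sketch; ``step_count`` and ``log_interval`` are assumed to come from
        the surrounding training loop)::

            if step_count % log_interval == 0:
                print(f"current loss scale: {scaler.get_scale()}")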
""" | |
if self._enabled: | |
return ( | |
self._init_scale | |
if (scale := self._get_scale_async()) is None | |
else cast(float, scale.item()) | |
) | |
return 1.0 | |

    def get_growth_factor(self) -> float:
        r"""Return a Python float containing the scale growth factor."""
        return self._growth_factor

    def set_growth_factor(self, new_factor: float) -> None:
        r"""Set a new scale growth factor.

        Args:
            new_factor (float): Value to use as the new scale growth factor.
        """
        self._growth_factor = new_factor

    def get_backoff_factor(self) -> float:
        r"""Return a Python float containing the scale backoff factor."""
        return self._backoff_factor

    def set_backoff_factor(self, new_factor: float) -> None:
        r"""Set a new scale backoff factor.

        Args:
            new_factor (float): Value to use as the new scale backoff factor.
        """
        self._backoff_factor = new_factor

    def get_growth_interval(self) -> int:
        r"""Return a Python int containing the growth interval."""
        return self._growth_interval

    def set_growth_interval(self, new_interval: int) -> None:
        r"""Set a new growth interval.

        Args:
            new_interval (int): Value to use as the new growth interval.
        """
        self._growth_interval = new_interval

    def _get_growth_tracker(self) -> int:
        if self._enabled:
            return (
                self._init_growth_tracker
                if self._growth_tracker is None
                else cast(int, self._growth_tracker.item())
            )
        return 0

    def is_enabled(self) -> bool:
        r"""Return a bool indicating whether this instance is enabled."""
        return self._enabled

    def state_dict(self) -> Dict[str, Any]:
        r"""Return the state of the scaler as a :class:`dict`.

        It contains five entries:

        * ``"scale"`` - a Python float containing the current scale
        * ``"growth_factor"`` - a Python float containing the current growth factor
        * ``"backoff_factor"`` - a Python float containing the current backoff factor
        * ``"growth_interval"`` - a Python int containing the current growth interval
        * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps

        If this instance is not enabled, returns an empty dict.

        .. note::
           If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
           should be called after :meth:`update`.
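
        A typical checkpointing pattern (a sketch; ``PATH``, ``model``, and ``optimizer``
        are assumed to exist in the surrounding script)::

            checkpoint = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scaler": scaler.state_dict(),
            }
            torch.save(checkpoint, PATH)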
""" | |
if self._enabled: | |
return { | |
"scale": self.get_scale(), | |
"growth_factor": self._growth_factor, | |
"backoff_factor": self._backoff_factor, | |
"growth_interval": self._growth_interval, | |
"_growth_tracker": self._get_growth_tracker(), | |
} | |
return {} | |

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        r"""Load the scaler state.

        If this instance is disabled, :meth:`load_state_dict` is a no-op.

        Args:
            state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`.
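
        To resume from the checkpoint sketched in :meth:`state_dict` (``PATH`` is an
        assumed filename)::

            checkpoint = torch.load(PATH)
            scaler.load_state_dict(checkpoint["scaler"])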
""" | |
if not self._enabled: | |
return | |
if len(state_dict) == 0: | |
raise RuntimeError( | |
"The source state dict is empty, possibly because it was saved " | |
"from a disabled instance of GradScaler." | |
) | |
self._init_scale = cast(float, state_dict["scale"]) | |
if self._scale is not None: | |
self._scale.fill_(state_dict["scale"]) | |
self._growth_factor = cast(float, state_dict["growth_factor"]) | |
self._backoff_factor = cast(float, state_dict["backoff_factor"]) | |
self._growth_interval = cast(int, state_dict["growth_interval"]) | |
self._init_growth_tracker = cast(int, state_dict["_growth_tracker"]) | |
if self._growth_tracker is not None: | |
self._growth_tracker.fill_(state_dict["_growth_tracker"]) | |

    def __getstate__(self) -> Dict[str, Any]:
        state = self.__dict__.copy()
        if self._enabled:
            assert len(self._per_optimizer_states) == 0, (
                "A GradScaler instance may only be pickled at the beginning "
                "of an iteration, or at the end after scaler.update()."
            )
            # Pickling _scale and _growth_tracker Tensors directly triggers
            # "warnings.warn("pickle support for Storage will be removed in 1.5..."
            # so instead, we set the unpickled instance up to reinitialize them lazily.
            state["_init_scale"] = self.get_scale()
            state["_init_growth_tracker"] = self._get_growth_tracker()
            state["_scale"] = None
            state["_growth_tracker"] = None
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        self.__dict__.update(state)

    def _check_inf_per_device(self, optimizer: torch.optim.Optimizer) -> Dict[str, Any]:
        _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device")

        dummy_inv_scale = torch.full((), 1.0, dtype=torch.float32, device=_scale.device)
        found_inf = torch.full((), 0.0, dtype=torch.float32, device=_scale.device)

        self._per_optimizer_states[id(optimizer)][
            "found_inf_per_device"
        ] = self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)

        return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]

    def _found_inf_per_device(self, optimizer: torch.optim.Optimizer) -> Dict[str, Any]:
        return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]