# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import warnings

import torch

from .state import AcceleratorState, GradientState
from .utils import DistributedType, honor_type, is_tpu_available


if is_tpu_available(check_device=False):
    import torch_xla.core.xla_model as xm


def move_to_device(state, device):
    """Recursively send the tensors in `state` (nested in lists, tuples or dicts) to `device`."""
    if isinstance(state, (list, tuple)):
        return honor_type(state, (move_to_device(t, device) for t in state))
    elif isinstance(state, dict):
        return type(state)({k: move_to_device(v, device) for k, v in state.items()})
    elif isinstance(state, torch.Tensor):
        return state.to(device)
    return state


class AcceleratedOptimizer(torch.optim.Optimizer):
    """
    Internal wrapper around a torch optimizer.

    Conditionally will perform `step` and `zero_grad` if gradients should be synchronized when performing gradient
    accumulation.

    Args:
        optimizer (`torch.optim.optimizer.Optimizer`):
            The optimizer to wrap.
        device_placement (`bool`, *optional*, defaults to `True`):
            Whether or not the optimizer should handle device placement. If so, it will place the state dictionary of
            `optimizer` on the right device.
        scaler (`torch.cuda.amp.grad_scaler.GradScaler`, *optional*):
            The scaler to use in the step function if training with mixed precision.
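
    Example:

        A minimal usage sketch. This wrapper is normally created for you by `Accelerator.prepare` rather than
        instantiated directly; the `model`, `optimizer` and `dataloader` objects below are placeholders.

            >>> from accelerate import Accelerator
            >>> accelerator = Accelerator()
            >>> model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
            >>> for batch in dataloader:
            ...     loss = model(**batch).loss
            ...     accelerator.backward(loss)
            ...     # `step` and `zero_grad` only act when gradients are being synchronized,
            ...     # i.e. not on intermediate gradient accumulation steps.
            ...     optimizer.step()
            ...     optimizer.zero_grad()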
""" | |
def __init__(self, optimizer, device_placement=True, scaler=None): | |
self.optimizer = optimizer | |
self.scaler = scaler | |
self.accelerator_state = AcceleratorState() | |
self.gradient_state = GradientState() | |
self.device_placement = device_placement | |
self._is_overflow = False | |
self._last_scale = None | |
# Handle device placement | |
if device_placement: | |
state_dict = self.optimizer.state_dict() | |
if self.accelerator_state.distributed_type == DistributedType.TPU: | |
xm.send_cpu_data_to_device(state_dict, self.accelerator_state.device) | |
else: | |
state_dict = move_to_device(state_dict, self.accelerator_state.device) | |
self.optimizer.load_state_dict(state_dict) | |
    @property
    def state(self):
        return self.optimizer.state

    @state.setter
    def state(self, state):
        self.optimizer.state = state

    @property
    def param_groups(self):
        return self.optimizer.param_groups

    @param_groups.setter
    def param_groups(self, param_groups):
        self.optimizer.param_groups = param_groups

    @property
    def defaults(self):
        return self.optimizer.defaults

    @defaults.setter
    def defaults(self, defaults):
        self.optimizer.defaults = defaults

    def add_param_group(self, param_group):
        self.optimizer.add_param_group(param_group)

    def load_state_dict(self, state_dict):
        if self.accelerator_state.distributed_type == DistributedType.TPU and self.device_placement:
            xm.send_cpu_data_to_device(state_dict, self.accelerator_state.device)
        self.optimizer.load_state_dict(state_dict)

    def state_dict(self):
        return self.optimizer.state_dict()

    def zero_grad(self, set_to_none=None):
        # Only zero the gradients when they are being synchronized, i.e. not on intermediate accumulation steps.
        if self.gradient_state.sync_gradients:
            accept_arg = "set_to_none" in inspect.signature(self.optimizer.zero_grad).parameters
            if accept_arg:
                if set_to_none is None:
                    set_to_none = False
                self.optimizer.zero_grad(set_to_none=set_to_none)
            else:
                if set_to_none is not None:
                    raise ValueError("`set_to_none` for `Optimizer.zero_grad` is not supported by this optimizer.")
                self.optimizer.zero_grad()

    def step(self, closure=None):
        # Only perform the actual step when gradients are being synchronized.
        if self.gradient_state.sync_gradients:
            if self.accelerator_state.distributed_type == DistributedType.TPU:
                optimizer_args = {"closure": closure} if closure is not None else {}
                xm.optimizer_step(self.optimizer, optimizer_args=optimizer_args)
            elif self.scaler is not None:
                new_scale = False
                if self._last_scale is None:
                    # `get_scale` is an async operation requiring full synchronization on CPU and GPUs before
                    # finishing. As a result, we store away the prior scale to reduce the call overhead.
                    self._last_scale = self.scaler.get_scale()
                    new_scale = True
                self.scaler.step(self.optimizer, closure)
                self.scaler.update()

                scale_after = self.scaler.get_scale()
                if not new_scale:
                    # If we reduced the loss scale, it means the optimizer step was skipped because of gradient
                    # overflow.
                    self._is_overflow = scale_after < self._last_scale
                self._last_scale = scale_after
            else:
                self.optimizer.step(closure)

    def _switch_parameters(self, parameters_map):
        for param_group in self.optimizer.param_groups:
            param_group["params"] = [parameters_map.get(p, p) for p in param_group["params"]]

    @property
    def is_overflow(self):
        """Whether or not the optimizer step was done, or skipped because of gradient overflow."""
        warnings.warn(
            "The `is_overflow` property is deprecated and will be removed in version 1.0 of Accelerate. Use "
            "`optimizer.step_was_skipped` instead.",
            FutureWarning,
        )
        return self._is_overflow

    @property
    def step_was_skipped(self):
        """Whether or not the optimizer step was skipped."""
        return self._is_overflow

    def __getstate__(self):
        return self.__dict__.copy()

    def __setstate__(self, state):
        self.__dict__.update(state)