# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# We ignore warnings about stepping the scheduler since we step it ourselves during gradient accumulation
import warnings

from .state import AcceleratorState, GradientState


warnings.filterwarnings("ignore", category=UserWarning, module="torch.optim.lr_scheduler")

class AcceleratedScheduler:
    """
    A wrapper around a learning rate scheduler that will only step when the optimizer(s) have a training step. Useful
    to avoid making a scheduler step too fast when gradients overflowed and there was no training step (in mixed
    precision training).

    When performing gradient accumulation, the scheduler lengths should not be changed accordingly; Accelerate will
    always step the scheduler to account for it.

    Args:
        scheduler (`torch.optim.lr_scheduler._LRScheduler`):
            The scheduler to wrap.
        optimizers (one or a list of `torch.optim.Optimizer`):
            The optimizers used.
        step_with_optimizer (`bool`, *optional*, defaults to `True`):
            Whether or not the scheduler should be stepped at each optimizer step.
        split_batches (`bool`, *optional*, defaults to `False`):
            Whether or not the dataloaders split one batch across the different processes (so batch size is the same
            regardless of the number of processes) or create batches on each process (so batch size is the original
            batch size multiplied by the number of processes).
    """

    def __init__(self, scheduler, optimizers, step_with_optimizer: bool = True, split_batches: bool = False):
        self.scheduler = scheduler
        self.optimizers = optimizers if isinstance(optimizers, (list, tuple)) else [optimizers]
        self.split_batches = split_batches
        self.step_with_optimizer = step_with_optimizer
        self.gradient_state = GradientState()

    def step(self, *args, **kwargs):
        if not self.step_with_optimizer:
            # No link between scheduler and optimizer -> just step
            self.scheduler.step(*args, **kwargs)
            return

        # Otherwise, first make sure the optimizer was stepped.
        if not self.gradient_state.sync_gradients:
            if self.gradient_state.adjust_scheduler:
                self.scheduler._step_count += 1
            return

        for opt in self.optimizers:
            if opt.step_was_skipped:
                return
        if self.split_batches:
            # Split batches -> the training dataloader batch size is not changed so one step per training step
            self.scheduler.step(*args, **kwargs)
        else:
            # Otherwise the training dataloader batch size was multiplied by `num_processes`, so we need to do
            # `num_processes` steps per training step
            num_processes = AcceleratorState().num_processes
            for _ in range(num_processes):
                # Special case when using OneCycle and `drop_last` was not used
                if hasattr(self.scheduler, "total_steps"):
                    if self.scheduler._step_count <= self.scheduler.total_steps:
                        self.scheduler.step(*args, **kwargs)
                else:
                    self.scheduler.step(*args, **kwargs)
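
    # For example (illustrative numbers): with `split_batches=False` and 4 processes, every training
    # step consumes 4x the per-process batch size, so only 1/4 of the optimizer steps the scheduler
    # was configured for actually happen; stepping the wrapped scheduler 4 times per training step
    # keeps the learning-rate schedule aligned with the number of samples seen.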

    # Passthroughs
    def get_last_lr(self):
        return self.scheduler.get_last_lr()

    def state_dict(self):
        return self.scheduler.state_dict()

    def load_state_dict(self, state_dict):
        self.scheduler.load_state_dict(state_dict)

    def get_lr(self):
        return self.scheduler.get_lr()

    def print_lr(self, *args, **kwargs):
        return self.scheduler.print_lr(*args, **kwargs)
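

# --------------------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the library): `Accelerator.prepare` builds this
# wrapper automatically when a scheduler is passed alongside its optimizer, so user code
# rarely instantiates `AcceleratedScheduler` directly. The loop below assumes a `model`,
# `optimizer`, `scheduler`, and `dataloader` have already been created; those names are
# placeholders, not part of this module.
#
#     from accelerate import Accelerator
#
#     accelerator = Accelerator()
#     model, optimizer, dataloader, scheduler = accelerator.prepare(
#         model, optimizer, dataloader, scheduler
#     )
#     for batch in dataloader:
#         optimizer.zero_grad()
#         loss = model(**batch).loss
#         accelerator.backward(loss)
#         optimizer.step()
#         scheduler.step()  # only advances when the optimizer actually stepped
# --------------------------------------------------------------------------------------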