import itertools
import math
from copy import deepcopy
import warnings

import torch
from torch.nn import Module
from torch.optim.lr_scheduler import LRScheduler
from torch.utils._foreach_utils import (
    _get_foreach_kernels_supported_devices,
    _group_tensors_by_device_and_dtype,
)

__all__ = [
    'AveragedModel',
    'update_bn',
    'SWALR',
    'get_ema_multi_avg_fn',
    'get_swa_multi_avg_fn',
    'get_ema_avg_fn',
    'get_swa_avg_fn',
]


def get_ema_multi_avg_fn(decay=0.999):
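    r"""Return a closure that updates a list of EMA parameters in place.

    For each pair of tensors the returned callable computes
    ``p_ema <- decay * p_ema + (1 - decay) * p_model``, using the foreach
    ``lerp_`` kernel when the dtype supports it. A minimal, self-contained
    sketch of calling it directly (normally it is passed to
    :class:`AveragedModel` via ``multi_avg_fn``):

    Example:
        >>> # xdoctest: +SKIP("illustrative sketch")
        >>> ema_params = [torch.zeros(3)]
        >>> model_params = [torch.ones(3)]
        >>> update = get_ema_multi_avg_fn(decay=0.9)
        >>> update(ema_params, model_params, None)  # in place: ema becomes 0.1
    """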
    def ema_update(ema_param_list, current_param_list, _):
        # foreach lerp only handles float and complex
        if torch.is_floating_point(ema_param_list[0]) or torch.is_complex(ema_param_list[0]):
            torch._foreach_lerp_(ema_param_list, current_param_list, 1 - decay)
        else:
            for p_ema, p_model in zip(ema_param_list, current_param_list):
                p_ema.copy_(p_ema * decay + p_model * (1 - decay))

    return ema_update


def get_swa_multi_avg_fn():
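    r"""Return a closure that updates a list of running-average (SWA) parameters in place.

    Given ``num_averaged`` models already averaged, the returned callable computes
    the equally weighted running mean
    ``p_avg <- p_avg + (p_model - p_avg) / (num_averaged + 1)`` for each pair of
    tensors. A minimal sketch of calling it directly (``num_averaged`` is normally
    the ``n_averaged`` buffer of :class:`AveragedModel`):

    Example:
        >>> # xdoctest: +SKIP("illustrative sketch")
        >>> avg_params = [torch.zeros(3)]
        >>> model_params = [torch.ones(3)]
        >>> update = get_swa_multi_avg_fn()
        >>> update(avg_params, model_params, 1)  # in place: avg becomes 0.5
    """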
    def swa_update(averaged_param_list, current_param_list, num_averaged):
        # foreach lerp only handles float and complex
        if torch.is_floating_point(averaged_param_list[0]) or torch.is_complex(averaged_param_list[0]):
            torch._foreach_lerp_(averaged_param_list, current_param_list, 1 / (num_averaged + 1))
        else:
            diffs = torch._foreach_sub(current_param_list, averaged_param_list)
            torch._foreach_addcdiv_(averaged_param_list, diffs, [num_averaged + 1] * len(averaged_param_list))

    return swa_update


def get_ema_avg_fn(decay=0.999):
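    r"""Return a per-parameter EMA averaging function for ``AveragedModel(avg_fn=...)``.

    Unlike :func:`get_ema_multi_avg_fn`, the returned callable operates on a single
    pair of tensors and returns the new average instead of updating in place:
    ``decay * ema_param + (1 - decay) * current_param``. A minimal sketch:

    Example:
        >>> # xdoctest: +SKIP("illustrative sketch")
        >>> avg_fn = get_ema_avg_fn(decay=0.9)
        >>> avg_fn(torch.zeros(3), torch.ones(3), 5)  # returns a tensor of 0.1s
    """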
    def ema_update(ema_param, current_param, num_averaged):
        return decay * ema_param + (1 - decay) * current_param

    return ema_update


def get_swa_avg_fn():
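    r"""Return a per-parameter SWA averaging function for ``AveragedModel(avg_fn=...)``.

    The returned callable computes the equally weighted running mean
    ``averaged_param + (current_param - averaged_param) / (num_averaged + 1)``
    and returns the result rather than updating in place. A minimal sketch:

    Example:
        >>> # xdoctest: +SKIP("illustrative sketch")
        >>> avg_fn = get_swa_avg_fn()
        >>> avg_fn(torch.zeros(3), torch.ones(3), 1)  # returns a tensor of 0.5s
    """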
    def swa_update(averaged_param, current_param, num_averaged):
        return averaged_param + (current_param - averaged_param) / (num_averaged + 1)

    return swa_update


class AveragedModel(Module):
r"""Implements averaged model for Stochastic Weight Averaging (SWA) and | |
Exponential Moving Average (EMA). | |
Stochastic Weight Averaging was proposed in `Averaging Weights Leads to | |
Wider Optima and Better Generalization`_ by Pavel Izmailov, Dmitrii | |
Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson | |
(UAI 2018). | |
Exponential Moving Average is a variation of `Polyak averaging`_, | |
but using exponential weights instead of equal weights across iterations. | |
AveragedModel class creates a copy of the provided module :attr:`model` | |
on the device :attr:`device` and allows to compute running averages of the | |
parameters of the :attr:`model`. | |
Args: | |
model (torch.nn.Module): model to use with SWA/EMA | |
device (torch.device, optional): if provided, the averaged model will be | |
stored on the :attr:`device` | |
avg_fn (function, optional): the averaging function used to update | |
parameters; the function must take in the current value of the | |
:class:`AveragedModel` parameter, the current value of :attr:`model` | |
parameter, and the number of models already averaged; if None, | |
an equally weighted average is used (default: None) | |
multi_avg_fn (function, optional): the averaging function used to update | |
parameters inplace; the function must take in the current values of the | |
:class:`AveragedModel` parameters as a list, the current values of :attr:`model` | |
parameters as a list, and the number of models already averaged; if None, | |
an equally weighted average is used (default: None) | |
use_buffers (bool): if ``True``, it will compute running averages for | |
both the parameters and the buffers of the model. (default: ``False``) | |
Example: | |
>>> # xdoctest: +SKIP("undefined variables") | |
>>> loader, optimizer, model, loss_fn = ... | |
>>> swa_model = torch.optim.swa_utils.AveragedModel(model) | |
>>> scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, | |
>>> T_max=300) | |
>>> swa_start = 160 | |
>>> swa_scheduler = SWALR(optimizer, swa_lr=0.05) | |
>>> for i in range(300): | |
>>> for input, target in loader: | |
>>> optimizer.zero_grad() | |
>>> loss_fn(model(input), target).backward() | |
>>> optimizer.step() | |
>>> if i > swa_start: | |
>>> swa_model.update_parameters(model) | |
>>> swa_scheduler.step() | |
>>> else: | |
>>> scheduler.step() | |
>>> | |
>>> # Update bn statistics for the swa_model at the end | |
>>> torch.optim.swa_utils.update_bn(loader, swa_model) | |
You can also use custom averaging functions with the `avg_fn` or `multi_avg_fn` parameters. | |
If no averaging function is provided, the default is to compute | |
equally-weighted average of the weights (SWA). | |
Example: | |
>>> # xdoctest: +SKIP("undefined variables") | |
>>> # Compute exponential moving averages of the weights and buffers | |
>>> ema_model = torch.optim.swa_utils.AveragedModel(model, | |
>>> torch.optim.swa_utils.get_ema_multi_avg_fn(0.9), use_buffers=True) | |
.. note:: | |
When using SWA/EMA with models containing Batch Normalization you may | |
need to update the activation statistics for Batch Normalization. | |
This can be done either by using the :meth:`torch.optim.swa_utils.update_bn` | |
or by setting :attr:`use_buffers` to `True`. The first approach updates the | |
statistics in a post-training step by passing data through the model. The | |
second does it during the parameter update phase by averaging all buffers. | |
Empirical evidence has shown that updating the statistics in normalization | |
layers increases accuracy, but you may wish to empirically test which | |
approach yields the best results in your problem. | |
.. note:: | |
:attr:`avg_fn` and `multi_avg_fn` are not saved in the :meth:`state_dict` of the model. | |
.. note:: | |
When :meth:`update_parameters` is called for the first time (i.e. | |
:attr:`n_averaged` is `0`) the parameters of `model` are copied | |
to the parameters of :class:`AveragedModel`. For every subsequent | |
call of :meth:`update_parameters` the function `avg_fn` is used | |
to update the parameters. | |
.. _Averaging Weights Leads to Wider Optima and Better Generalization: | |
https://arxiv.org/abs/1803.05407 | |
.. _There Are Many Consistent Explanations of Unlabeled Data: Why You Should | |
Average: | |
https://arxiv.org/abs/1806.05594 | |
.. _SWALP: Stochastic Weight Averaging in Low-Precision Training: | |
https://arxiv.org/abs/1904.11943 | |
.. _Stochastic Weight Averaging in Parallel: Large-Batch Training That | |
Generalizes Well: | |
https://arxiv.org/abs/2001.02312 | |
.. _Polyak averaging: | |
https://paperswithcode.com/method/polyak-averaging | |
""" | |

    def __init__(self, model, device=None, avg_fn=None, multi_avg_fn=None, use_buffers=False):
        super().__init__()
        assert avg_fn is None or multi_avg_fn is None, 'Only one of avg_fn and multi_avg_fn should be provided'
        self.module = deepcopy(model)
        if device is not None:
            self.module = self.module.to(device)
        self.register_buffer('n_averaged',
                             torch.tensor(0, dtype=torch.long, device=device))
        self.avg_fn = avg_fn
        self.multi_avg_fn = multi_avg_fn
        self.use_buffers = use_buffers

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

    def update_parameters(self, model):
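        r"""Update the averaged parameters (and optionally buffers) from ``model``.

        On the first call the parameters of ``model`` are copied; on subsequent
        calls they are folded into the running average using ``multi_avg_fn``,
        ``avg_fn``, or the default equally weighted SWA update. A minimal sketch,
        assuming ``model`` already exists:

        Example:
            >>> # xdoctest: +SKIP("undefined variables")
            >>> swa_model = torch.optim.swa_utils.AveragedModel(model)
            >>> for _ in range(10):
            >>>     # ... optimizer steps on model ...
            >>>     swa_model.update_parameters(model)
        """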
        self_param = (
            itertools.chain(self.module.parameters(), self.module.buffers())
            if self.use_buffers else self.parameters()
        )
        model_param = (
            itertools.chain(model.parameters(), model.buffers())
            if self.use_buffers else model.parameters()
        )
        self_param_detached = []
        model_param_detached = []
        for p_averaged, p_model in zip(self_param, model_param):
            p_model_ = p_model.detach().to(p_averaged.device)
            self_param_detached.append(p_averaged.detach())
            model_param_detached.append(p_model_)
            if self.n_averaged == 0:
                p_averaged.detach().copy_(p_model_)

        if self.n_averaged > 0:
            if self.multi_avg_fn is not None or self.avg_fn is None:
                grouped_tensors = _group_tensors_by_device_and_dtype([self_param_detached, model_param_detached])
                for ((device, _), ([self_params, model_params], _)) in grouped_tensors.items():
                    if self.multi_avg_fn:
                        self.multi_avg_fn(self_params, model_params, self.n_averaged.to(device))
                    elif device.type in _get_foreach_kernels_supported_devices():
                        multi_avg_fn = get_swa_multi_avg_fn()
                        multi_avg_fn(self_params, model_params, self.n_averaged.to(device))
                    else:
                        avg_fn = get_swa_avg_fn()
                        n_averaged = self.n_averaged.to(device)
                        for p_averaged, p_model in zip(self_params, model_params):
                            p_averaged.copy_(avg_fn(p_averaged, p_model, n_averaged))
            else:
                for p_averaged, p_model in zip(self_param_detached, model_param_detached):
                    n_averaged = self.n_averaged.to(p_averaged.device)
                    p_averaged.detach().copy_(self.avg_fn(p_averaged.detach(), p_model, n_averaged))

        if not self.use_buffers:
            # If not applying running averages to the buffers, keep the
            # buffers in sync with the source model.
            for b_swa, b_model in zip(self.module.buffers(), model.buffers()):
                b_swa.detach().copy_(b_model.detach().to(b_swa.device))
        self.n_averaged += 1


def update_bn(loader, model, device=None):
r"""Updates BatchNorm running_mean, running_var buffers in the model. | |
It performs one pass over data in `loader` to estimate the activation | |
statistics for BatchNorm layers in the model. | |
Args: | |
loader (torch.utils.data.DataLoader): dataset loader to compute the | |
activation statistics on. Each data batch should be either a | |
tensor, or a list/tuple whose first element is a tensor | |
containing data. | |
model (torch.nn.Module): model for which we seek to update BatchNorm | |
statistics. | |
device (torch.device, optional): If set, data will be transferred to | |
:attr:`device` before being passed into :attr:`model`. | |
Example: | |
>>> # xdoctest: +SKIP("Undefined variables") | |
>>> loader, model = ... | |
>>> torch.optim.swa_utils.update_bn(loader, model) | |
.. note:: | |
The `update_bn` utility assumes that each data batch in :attr:`loader` | |
is either a tensor or a list or tuple of tensors; in the latter case it | |
is assumed that :meth:`model.forward()` should be called on the first | |
element of the list or tuple corresponding to the data batch. | |
""" | |
    momenta = {}
    for module in model.modules():
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module.reset_running_stats()
            momenta[module] = module.momentum

    if not momenta:
        return

    was_training = model.training
    model.train()
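    # Setting momentum to None makes _BatchNorm accumulate a simple cumulative
    # average of the batch statistics over the pass instead of an exponential one.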
    for module in momenta.keys():
        module.momentum = None

    for input in loader:
        if isinstance(input, (list, tuple)):
            input = input[0]
        if device is not None:
            input = input.to(device)

        model(input)

    for bn_module in momenta.keys():
        bn_module.momentum = momenta[bn_module]
    model.train(was_training)


class SWALR(LRScheduler):
r"""Anneals the learning rate in each parameter group to a fixed value. | |
This learning rate scheduler is meant to be used with Stochastic Weight | |
Averaging (SWA) method (see `torch.optim.swa_utils.AveragedModel`). | |
Args: | |
optimizer (torch.optim.Optimizer): wrapped optimizer | |
swa_lrs (float or list): the learning rate value for all param groups | |
together or separately for each group. | |
annealing_epochs (int): number of epochs in the annealing phase | |
(default: 10) | |
annealing_strategy (str): "cos" or "linear"; specifies the annealing | |
strategy: "cos" for cosine annealing, "linear" for linear annealing | |
(default: "cos") | |
last_epoch (int): the index of the last epoch (default: -1) | |
The :class:`SWALR` scheduler can be used together with other | |
schedulers to switch to a constant learning rate late in the training | |
as in the example below. | |
Example: | |
>>> # xdoctest: +SKIP("Undefined variables") | |
>>> loader, optimizer, model = ... | |
>>> lr_lambda = lambda epoch: 0.9 | |
>>> scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, | |
>>> lr_lambda=lr_lambda) | |
>>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, | |
>>> anneal_strategy="linear", anneal_epochs=20, swa_lr=0.05) | |
>>> swa_start = 160 | |
>>> for i in range(300): | |
>>> for input, target in loader: | |
>>> optimizer.zero_grad() | |
>>> loss_fn(model(input), target).backward() | |
>>> optimizer.step() | |
>>> if i > swa_start: | |
>>> swa_scheduler.step() | |
>>> else: | |
>>> scheduler.step() | |
.. _Averaging Weights Leads to Wider Optima and Better Generalization: | |
https://arxiv.org/abs/1803.05407 | |
""" | |

    def __init__(self, optimizer, swa_lr, anneal_epochs=10, anneal_strategy='cos', last_epoch=-1):
        swa_lrs = self._format_param(optimizer, swa_lr)
        for swa_lr, group in zip(swa_lrs, optimizer.param_groups):
            group['swa_lr'] = swa_lr
        if anneal_strategy not in ['cos', 'linear']:
            raise ValueError("anneal_strategy must be one of 'cos' or 'linear', "
                             f"instead got {anneal_strategy}")
        elif anneal_strategy == 'cos':
            self.anneal_func = self._cosine_anneal
        elif anneal_strategy == 'linear':
            self.anneal_func = self._linear_anneal
        if not isinstance(anneal_epochs, int) or anneal_epochs < 0:
            raise ValueError(f"anneal_epochs must be equal to or greater than 0, got {anneal_epochs}")
        self.anneal_epochs = anneal_epochs
        super().__init__(optimizer, last_epoch)

    @staticmethod
    def _format_param(optimizer, swa_lrs):
        if isinstance(swa_lrs, (list, tuple)):
            if len(swa_lrs) != len(optimizer.param_groups):
                raise ValueError("swa_lr must have the same length as "
                                 f"optimizer.param_groups: swa_lr has {len(swa_lrs)}, "
                                 f"optimizer.param_groups has {len(optimizer.param_groups)}")
            return swa_lrs
        else:
            return [swa_lrs] * len(optimizer.param_groups)

    @staticmethod
    def _linear_anneal(t):
        return t

    @staticmethod
    def _cosine_anneal(t):
        return (1 - math.cos(math.pi * t)) / 2

    @staticmethod
    def _get_initial_lr(lr, swa_lr, alpha):
        if alpha == 1:
            return swa_lr
        return (lr - alpha * swa_lr) / (1 - alpha)

    def get_lr(self):
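        r"""Compute the learning rate for the current annealing step.

        Each group's learning rate is interpolated from its value at the start
        of annealing towards ``swa_lr``: with ``t = min(1, step / anneal_epochs)``
        and ``alpha = anneal_func(t)``, the result is
        ``alpha * swa_lr + (1 - alpha) * lr_init``, where the initial value is
        recovered from the current ``lr`` via :meth:`_get_initial_lr`. A minimal,
        self-contained sketch of the resulting schedule:

        Example:
            >>> # xdoctest: +SKIP("illustrative sketch")
            >>> param = torch.nn.Parameter(torch.zeros(1))
            >>> opt = torch.optim.SGD([param], lr=0.5)
            >>> sched = SWALR(opt, swa_lr=0.05, anneal_epochs=5, anneal_strategy="linear")
            >>> for _ in range(5):
            >>>     opt.step()
            >>>     sched.step()
            >>> opt.param_groups[0]["lr"]  # fully annealed to swa_lr
            0.05
        """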
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)
        step = self._step_count - 1
        if self.anneal_epochs == 0:
            step = max(1, step)
        prev_t = max(0, min(1, (step - 1) / max(1, self.anneal_epochs)))
        prev_alpha = self.anneal_func(prev_t)
        prev_lrs = [self._get_initial_lr(group['lr'], group['swa_lr'], prev_alpha)
                    for group in self.optimizer.param_groups]
        t = max(0, min(1, step / max(1, self.anneal_epochs)))
        alpha = self.anneal_func(t)
        return [group['swa_lr'] * alpha + lr * (1 - alpha)
                for group, lr in zip(self.optimizer.param_groups, prev_lrs)]