from functools import update_wrapper, wraps
import torch
from torch import Tensor
from torch.optim.optimizer import Optimizer
try:
    from torch.optim.optimizer import _use_grad_for_differentiable, _default_to_fused_or_foreach
    has_recent_pt = True
except ImportError:
    has_recent_pt = False

from typing import List, Optional

__all__ = ['SGDW', 'sgdw']


class SGDW(Optimizer):
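    r"""Implements SGD with decoupled weight decay (SGDW).

    Unlike :class:`torch.optim.SGD`, which adds ``weight_decay * param`` to the
    gradient, SGDW multiplies the parameters by ``1 - lr * weight_decay`` before
    the gradient step, following Loshchilov & Hutter, "Decoupled Weight Decay
    Regularization" (https://arxiv.org/abs/1711.05101).

    Example (a minimal sketch; ``model`` and ``loss_fn`` stand in for any module
    and loss of your choosing)::

        optimizer = SGDW(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
        optimizer.zero_grad()
        loss_fn(model(input), target).backward()
        optimizer.step()
    """
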
    def __init__(
            self,
            params,
            lr=1e-3,
            momentum=0,
            dampening=0,
            weight_decay=0,
            nesterov=False,
            *,
            maximize: bool = False,
            foreach: Optional[bool] = None,
            differentiable: bool = False,
    ):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr, momentum=momentum, dampening=dampening,
            weight_decay=weight_decay, nesterov=nesterov,
            maximize=maximize, foreach=foreach,
            differentiable=differentiable)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)
            group.setdefault('maximize', False)
            group.setdefault('foreach', None)
            group.setdefault('differentiable', False)

    def _init_group(self, group, params_with_grad, d_p_list, momentum_buffer_list):
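        # Gather the parameters that actually have gradients, their grads, and any
        # momentum buffers already stored in optimizer state; flag sparse grads so
        # the foreach path can be avoided for them.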
        has_sparse_grad = False

        for p in group['params']:
            if p.grad is not None:
                params_with_grad.append(p)
                d_p_list.append(p.grad)
                if p.grad.is_sparse:
                    has_sparse_grad = True

                state = self.state[p]
                if 'momentum_buffer' not in state:
                    momentum_buffer_list.append(None)
                else:
                    momentum_buffer_list.append(state['momentum_buffer'])

        return has_sparse_grad

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            d_p_list = []
            momentum_buffer_list = []

            has_sparse_grad = self._init_group(group, params_with_grad, d_p_list, momentum_buffer_list)

            sgdw(
                params_with_grad,
                d_p_list,
                momentum_buffer_list,
                weight_decay=group['weight_decay'],
                momentum=group['momentum'],
                lr=group['lr'],
                dampening=group['dampening'],
                nesterov=group['nesterov'],
                maximize=group['maximize'],
                has_sparse_grad=has_sparse_grad,
                foreach=group['foreach'],
            )

            for p, momentum_buffer in zip(params_with_grad, momentum_buffer_list):
                state = self.state[p]
                state['momentum_buffer'] = momentum_buffer

        return loss


def sgdw(
        params: List[Tensor],
        d_p_list: List[Tensor],
        momentum_buffer_list: List[Optional[Tensor]],
        has_sparse_grad: Optional[bool] = None,
        foreach: Optional[bool] = None,
        *,
        weight_decay: float,
        momentum: float,
        lr: float,
        dampening: float,
        nesterov: bool,
        maximize: bool
):
    r"""Functional API that performs SGDW algorithm computation.

    See :class:`SGDW` for details.
    """
    if has_recent_pt and hasattr(Optimizer, '_group_tensors_by_device_and_dtype'):
        if foreach is None:
            if not torch.jit.is_scripting():
                _, foreach = _default_to_fused_or_foreach(params, differentiable=False, use_fused=False)
            else:
                foreach = False

        if foreach and torch.jit.is_scripting():
            raise RuntimeError('torch.jit.script not supported with foreach optimizers')
    else:
        foreach = False

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_sgdw
    else:
        func = _single_tensor_sgdw

    func(
        params,
        d_p_list,
        momentum_buffer_list,
        weight_decay=weight_decay,
        momentum=momentum,
        lr=lr,
        dampening=dampening,
        nesterov=nesterov,
        has_sparse_grad=has_sparse_grad,
        maximize=maximize,
    )


def _single_tensor_sgdw(
        params: List[Tensor],
        d_p_list: List[Tensor],
        momentum_buffer_list: List[Optional[Tensor]],
        *,
        weight_decay: float,
        momentum: float,
        lr: float,
        dampening: float,
        nesterov: bool,
        maximize: bool,
        has_sparse_grad: bool
):
    for i, param in enumerate(params):
        d_p = d_p_list[i] if not maximize else -d_p_list[i]
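
        # Decoupled weight decay (the W in SGDW): shrink the parameter directly by
        # lr * weight_decay instead of adding weight_decay * param to the gradient
        # as torch.optim.SGD does.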
        param.mul_(1. - lr * weight_decay)

        if momentum != 0:
            buf = momentum_buffer_list[i]

            if buf is None:
                buf = torch.clone(d_p).detach()
                momentum_buffer_list[i] = buf
            else:
                buf.mul_(momentum).add_(d_p, alpha=1 - dampening)

            if nesterov:
                d_p = d_p.add(buf, alpha=momentum)
            else:
                d_p = buf

        param.add_(d_p, alpha=-lr)


def _multi_tensor_sgdw(
        params: List[Tensor],
        grads: List[Tensor],
        momentum_buffer_list: List[Optional[Tensor]],
        *,
        weight_decay: float,
        momentum: float,
        lr: float,
        dampening: float,
        nesterov: bool,
        maximize: bool,
        has_sparse_grad: bool
):
    if len(params) == 0:
        return
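
    # Group tensors by device and dtype so the _foreach_* ops below can be applied
    # per group; with_indices=True provides the indices used to write newly created
    # momentum buffers back into the original momentum_buffer_list.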
    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, momentum_buffer_list], with_indices=True)
    for ((device_params, device_grads, device_momentum_buffer_list), indices) in grouped_tensors.values():
        device_has_sparse_grad = has_sparse_grad and any(grad.is_sparse for grad in device_grads)

        if maximize:
            device_grads = torch._foreach_neg(device_grads)

        # Decoupled weight decay, applied to this device/dtype group only.
        torch._foreach_mul_(device_params, 1. - lr * weight_decay)

        if momentum != 0:
            bufs = []

            all_states_with_momentum_buffer = True
            for i in range(len(device_momentum_buffer_list)):
                if device_momentum_buffer_list[i] is None:
                    all_states_with_momentum_buffer = False
                    break
                else:
                    bufs.append(device_momentum_buffer_list[i])

            if all_states_with_momentum_buffer:
                torch._foreach_mul_(bufs, momentum)
                torch._foreach_add_(bufs, device_grads, alpha=1 - dampening)
            else:
                bufs = []
                for i in range(len(device_momentum_buffer_list)):
                    if device_momentum_buffer_list[i] is None:
                        buf = device_momentum_buffer_list[i] = momentum_buffer_list[indices[i]] = \
                            torch.clone(device_grads[i]).detach()
                    else:
                        buf = device_momentum_buffer_list[i]
                        buf.mul_(momentum).add_(device_grads[i], alpha=1 - dampening)

                    bufs.append(buf)

            if nesterov:
                torch._foreach_add_(device_grads, bufs, alpha=momentum)
            else:
                device_grads = bufs

        if not device_has_sparse_grad:
            torch._foreach_add_(device_params, device_grads, alpha=-lr)
        else:
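            # _foreach_add_ can't handle sparse gradients; update those params one by one.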
            for i in range(len(device_params)):
                device_params[i].add_(device_grads[i], alpha=-lr)