File size: 5,256 Bytes
786f6a6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
""" PyTorch LARS / LARC Optimizer
An implementation of LARS (SGD) + LARC in PyTorch
Based on:
* PyTorch SGD: https://github.com/pytorch/pytorch/blob/1.7/torch/optim/sgd.py#L100
* NVIDIA APEX LARC: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py
Additional cleanup and modifications to properly support PyTorch XLA.
Copyright 2021 Ross Wightman
"""
import torch
from torch.optim.optimizer import Optimizer
class Lars(Optimizer):
    """LARS (SGD with layer-wise adaptive rate scaling) with optional LARC clipping.

    Paper: `Large batch training of Convolutional Networks` - https://arxiv.org/pdf/1708.03888.pdf

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
        lr (float, optional): learning rate (default: 1.0).
        momentum (float, optional): momentum factor (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)
        trust_coeff (float): trust coefficient for computing adaptive lr / trust_ratio (default: 0.001)
        eps (float): eps for division denominator (default: 1e-8)
        trust_clip (bool): enable LARC trust ratio clipping (default: False)
        always_adapt (bool): always apply LARS LR adapt, otherwise only when group weight_decay != 0 (default: False)

    Raises:
        ValueError: if ``lr``, ``momentum`` or ``weight_decay`` is negative, or if ``nesterov``
            is enabled without a positive momentum and zero dampening.
    """

    def __init__(
            self,
            params,
            lr=1.0,
            momentum=0,
            dampening=0,
            weight_decay=0,
            nesterov=False,
            trust_coeff=0.001,
            eps=1e-8,
            trust_clip=False,
            always_adapt=False,
    ):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")

        defaults = dict(
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
            trust_coeff=trust_coeff,
            eps=eps,
            trust_clip=trust_clip,
            always_adapt=always_adapt,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        """Restore optimizer state, filling in ``nesterov=False`` for groups that lack the key."""
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("nesterov", False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        NOTE: when LARS adaptation is active for a group (``weight_decay != 0`` or
        ``always_adapt``), ``p.grad`` is modified in place (weight decay added, then
        scaled by the trust ratio).

        Args:
            closure (callable, optional): A closure that reevaluates the model and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']
            trust_coeff = group['trust_coeff']
            eps = group['eps']
            trust_clip = group['trust_clip']
            adapt = weight_decay != 0 or group['always_adapt']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad

                # apply LARS LR adaptation, LARC clipping, weight decay
                # ref: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py
                if adapt:
                    w_norm = p.norm(2.0)
                    g_norm = grad.norm(2.0)
                    trust_ratio = trust_coeff * w_norm / (g_norm + w_norm * weight_decay + eps)
                    # Build the constant per parameter so it always matches this parameter's
                    # device and dtype; deriving it once from the first parameter of the first
                    # group fails on an empty first group and mismatches on mixed devices/dtypes.
                    one_tensor = torch.ones_like(trust_ratio)
                    # FIXME nested where required since logical and/or not working in PT XLA
                    # Fall back to a ratio of 1 (no adaptation) if either norm is zero.
                    trust_ratio = torch.where(
                        w_norm > 0,
                        torch.where(g_norm > 0, trust_ratio, one_tensor),
                        one_tensor,
                    )
                    if trust_clip:
                        # LARC clipping: effective lr (lr * ratio) never exceeds trust_ratio's own value
                        trust_ratio = torch.minimum(trust_ratio / lr, one_tensor)
                    grad.add_(p, alpha=weight_decay)
                    grad.mul_(trust_ratio)

                # apply SGD update https://github.com/pytorch/pytorch/blob/1.7/torch/optim/sgd.py#L100
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        # first step for this parameter: seed the buffer with the current gradient
                        buf = param_state['momentum_buffer'] = torch.clone(grad).detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(grad, alpha=1. - dampening)
                    if nesterov:
                        grad = grad.add(buf, alpha=momentum)
                    else:
                        grad = buf

                p.add_(grad, alpha=-lr)

        return loss