|
""" Adaptive Gradient Clipping |
|
|
|
An impl of AGC, as per (https://arxiv.org/abs/2102.06171): |
|
|
|
@article{brock2021high, |
|
author={Andrew Brock and Soham De and Samuel L. Smith and Karen Simonyan}, |
|
title={High-Performance Large-Scale Image Recognition Without Normalization}, |
|
journal={arXiv preprint arXiv:2102.06171}, |
|
year={2021} |
|
} |
|
|
|
Code references: |
|
* Official JAX impl (paper authors): https://github.com/deepmind/deepmind-research/tree/master/nfnets |
|
* Phil Wang's PyTorch gist: https://gist.github.com/lucidrains/0d6560077edac419ab5d3aa29e674d5c |
|
|
|
Hacked together by / Copyright 2021 Ross Wightman |
|
""" |
|
import torch |
|
|
|
|
|
def unitwise_norm(x, norm_type=2.0):
    """Compute the norm of ``x`` per leading-dimension slice.

    Args:
        x: Tensor to measure.
        norm_type: Order of the norm, forwarded to ``Tensor.norm``.

    Returns:
        For tensors with at most one dimension, the norm over the whole
        tensor. For higher-rank tensors, the norm reduced over every
        dimension except dim 0, with the reduced dimensions kept as size 1
        so the result broadcasts against ``x``.
    """
    if x.ndim > 1:
        # NOTE(review): this treats dim 0 as the "unit" (output) axis, which
        # presumably matches conv / linear weight layouts -- confirm for any
        # parameter whose output axis is not first.
        reduce_dims = tuple(range(1, x.ndim))
        return x.norm(norm_type, dim=reduce_dims, keepdim=True)
    # Scalars and vectors are measured as a single unit.
    return x.norm(norm_type)
|
|
|
|
|
def adaptive_clip_grad(parameters, clip_factor=0.01, eps=1e-3, norm_type=2.0):
    """Clip gradients in place relative to the norm of their parameters.

    For each parameter that has a gradient, the unit-wise gradient norm is
    compared with ``clip_factor`` times the unit-wise parameter norm. Where
    the gradient norm is not below that limit, the gradient is rescaled to
    the limit. The result is written back into ``p.grad``.

    Args:
        parameters: A single tensor or an iterable of tensors whose ``.grad``
            should be clipped. Entries with ``grad is None`` are skipped.
        clip_factor: Multiplier applied to the parameter norm to obtain the
            maximum allowed gradient norm.
        eps: Lower bound applied to the parameter norm before scaling.
        norm_type: Order of the norm used for parameters and gradients.

    Returns:
        None. Gradients are modified in place.
    """
    params = [parameters] if isinstance(parameters, torch.Tensor) else parameters
    for param in params:
        if param.grad is None:
            continue

        weights = param.detach()
        grads = param.grad.detach()

        # Largest gradient norm tolerated for each unit of this parameter.
        limit = unitwise_norm(weights, norm_type=norm_type)
        limit.clamp_(min=eps)
        limit.mul_(clip_factor)

        grads_norm = unitwise_norm(grads, norm_type=norm_type)

        # Floor the denominator so an all-zero gradient cannot divide by zero.
        scale = limit / grads_norm.clamp(min=1e-6)
        rescaled = grads * scale

        # Keep gradients already under the limit; replace the rest.
        result = torch.where(grads_norm < limit, grads, rescaled)
        param.grad.detach().copy_(result)
|
|