|
"""
|
|
版本: 8月26日 17:00
|
|
SAM范化性训练,避免过拟合的优化器优化工具 ICLR 2021 spotlight paper by Google
|
|
|
|
介绍:https://mp.weixin.qq.com/s/04VT-ldd0-XEkhEW6Txl_A
|
|
第三方实现来自:https://pub.towardsai.net/we-dont-need-to-worry-about-overfitting-anymore-9fb31a154c81
|
|
|
|
论文:Sharpness-aware Minimization for Efficiently Improving Generalization
|
|
链接:https://arxiv.org/abs/2010.01412
|
|
|
|
计算原理:
|
|
在训练过程中,优化器更新模型参数w时,整体上可以分为四个步骤:
|
|
|
|
(1)基于参数 w 对 batch data S 计算 gradient G 。
|
|
|
|
(2)求解 G 的 dual norm,依照 dual vector 方向更新参数,得到 w+ε体系下的辅助模型。
|
|
|
|
(3)基于参数 w+ε 下的辅助模型,对 S 计算 gradient G’ 。
|
|
|
|
(4)用 G’ 更新原本的模型的参数 w 。
|
|
|
|
|
|
使用例子:
|
|
from sam import SAM
|
|
...
|
|
model = YourModel()
|
|
base_optimizer = torch.optim.SGD # 传入一个优化器模板
|
|
optimizer = SAM(model.parameters(), base_optimizer, lr=0.1, momentum=0.9) # 优化器参数
|
|
...
|
|
for input, output in data:
|
|
|
|
# first forward-backward pass,计算第一轮loss,这个和普通的一样
|
|
# 第一轮的loss是真实模型跑出来的,我们统计中需要的loss就是它,第二轮loss不是真实的模型的loss(是辅助模型的),所以不需要用在传统统计loss中
|
|
output = model(input)
|
|
loss = loss_function(output, labels) # use this loss for any training statistics!!!!
|
|
loss.backward() # 模型反向传播,记录原梯度。
|
|
|
|
# step1 的SAM类计算了“SAM梯度”
|
|
optimizer.first_step(zero_grad=True) # 第一轮opt用“SAM梯度”对原模型参数体系进行了更新,现在模型变成了辅助模型,
|
|
# step1记录保存了回到原模型参数体系的变化方法
|
|
|
|
# second forward-backward pass 第二轮先对辅助模型(step1更新后的模型)正向、反向传播
|
|
output2 = model(input) # 用output2 确保计算图是辅助模型(即step1更新后的模型),不然有一堆bug。
|
|
|
|
# 由于新增了计算图,因此计算时间增加显存占用也增加?
|
|
|
|
loss_function(output2, labels).backward() # make sure to do a full forward pass 辅助模型反向传播,记录更新梯度
|
|
optimizer.second_step(zero_grad=True) # 第二轮,先原模型参数替换回去,之后base opt以辅助模型的更新方向对原模型参数体系进行更新
|
|
...
|
|
|
|
|
|
"""
|
|
import torch
|
|
|
|
|
|
class SAM(torch.optim.Optimizer):
|
|
def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
|
|
assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
|
|
|
|
defaults = dict(rho=rho, **kwargs)
|
|
super(SAM, self).__init__(params, defaults)
|
|
|
|
self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
|
|
self.param_groups = self.base_optimizer.param_groups
|
|
|
|
@torch.no_grad()
|
|
def first_step(self, zero_grad=False):
|
|
grad_norm = self._grad_norm()
|
|
|
|
for group in self.param_groups:
|
|
scale = group["rho"] / (grad_norm + 1e-12)
|
|
|
|
for p in group["params"]:
|
|
if p.grad is None:
|
|
continue
|
|
e_w = p.grad * scale.to(p)
|
|
p.add_(e_w)
|
|
self.state[p]["e_w"] = e_w
|
|
|
|
if zero_grad:
|
|
self.zero_grad()
|
|
|
|
@torch.no_grad()
|
|
def second_step(self, zero_grad=False):
|
|
|
|
|
|
for group in self.param_groups:
|
|
|
|
for p in group["params"]:
|
|
if p.grad is None:
|
|
continue
|
|
p.sub_(self.state[p]["e_w"])
|
|
|
|
|
|
self.base_optimizer.step()
|
|
|
|
if zero_grad:
|
|
self.zero_grad()
|
|
|
|
def _grad_norm(self):
|
|
shared_device = self.param_groups[0]["params"][0].device
|
|
|
|
norm = torch.norm(
|
|
torch.stack([
|
|
p.grad.norm(p=2).to(shared_device)
|
|
for group in self.param_groups for p in group["params"]
|
|
if p.grad is not None]), p=2)
|
|
return norm
|
|
|