File size: 4,964 Bytes
edcf5ee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
"""
版本: 8月26日 17:00
SAM范化性训练,避免过拟合的优化器优化工具 ICLR 2021 spotlight paper by Google
介绍:https://mp.weixin.qq.com/s/04VT-ldd0-XEkhEW6Txl_A
第三方实现来自:https://pub.towardsai.net/we-dont-need-to-worry-about-overfitting-anymore-9fb31a154c81
论文:Sharpness-aware Minimization for Efficiently Improving Generalization
链接:https://arxiv.org/abs/2010.01412
计算原理:
在训练过程中,优化器更新模型参数w时,整体上可以分为四个步骤:
(1)基于参数 w 对 batch data S 计算 gradient G 。
(2)求解 G 的 dual norm,依照 dual vector 方向更新参数,得到 w+ε体系下的辅助模型。
(3)基于参数 w+ε 下的辅助模型,对 S 计算 gradient G’ 。
(4)用 G’ 更新原本的模型的参数 w 。
使用例子:
from sam import SAM
...
model = YourModel()
base_optimizer = torch.optim.SGD # 传入一个优化器模板
optimizer = SAM(model.parameters(), base_optimizer, lr=0.1, momentum=0.9) # 优化器参数
...
for input, output in data:
# first forward-backward pass,计算第一轮loss,这个和普通的一样
# 第一轮的loss是真实模型跑出来的,我们统计中需要的loss就是它,第二轮loss不是真实的模型的loss(是辅助模型的),所以不需要用在传统统计loss中
output = model(input)
loss = loss_function(output, labels) # use this loss for any training statistics!!!!
loss.backward() # 模型反向传播,记录原梯度。
# step1 的SAM类计算了“SAM梯度”
optimizer.first_step(zero_grad=True) # 第一轮opt用“SAM梯度”对原模型参数体系进行了更新,现在模型变成了辅助模型,
# step1记录保存了回到原模型参数体系的变化方法
# second forward-backward pass 第二轮先对辅助模型(step1更新后的模型)正向、反向传播
output2 = model(input) # 用output2 确保计算图是辅助模型(即step1更新后的模型),不然有一堆bug。
# 由于新增了计算图,因此计算时间增加显存占用也增加?
loss_function(output2, labels).backward() # make sure to do a full forward pass 辅助模型反向传播,记录更新梯度
optimizer.second_step(zero_grad=True) # 第二轮,先原模型参数替换回去,之后base opt以辅助模型的更新方向对原模型参数体系进行更新
...
"""
import torch
class SAM(torch.optim.Optimizer):
def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
defaults = dict(rho=rho, **kwargs)
super(SAM, self).__init__(params, defaults)
self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
self.param_groups = self.base_optimizer.param_groups
@torch.no_grad()
def first_step(self, zero_grad=False): # step1 生成辅助模型,对原模型参数进行修改把他变成辅助模型,同时记录怎么变的,以便还原
grad_norm = self._grad_norm()
for group in self.param_groups:
scale = group["rho"] / (grad_norm + 1e-12) # 附近的梯度影响
for p in group["params"]:
if p.grad is None:
continue
e_w = p.grad * scale.to(p) # 考虑附近的梯度影响之后,确定辅助模型的参数变化需要的“SAM梯度”
p.add_(e_w) # climb to the local maximum "w + e(w)" inplace参数更新! 因此是 原模型 变成了 辅助模型
self.state[p]["e_w"] = e_w
if zero_grad:
self.zero_grad()
@torch.no_grad()
def second_step(self, zero_grad=False): # step2 先对辅助模型参数进行修改把他变回原模型,
# 之后对原模型基于辅助模型的梯度用base_optimizer进行参数更新
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
p.sub_(self.state[p]["e_w"]) # get back to "w" from "w + e(w)"
# 辅助模型参数还原,回到原模型 get back to "w" from "w + e(w)",注意这个也是inplace的!!
self.base_optimizer.step() # 用base_optimizer对原模型进行参数更新 do the actual "sharpness-aware" update
if zero_grad:
self.zero_grad()
def _grad_norm(self):
shared_device = self.param_groups[0]["params"][0].device
# put everything on the same device, in case of model parallelism
norm = torch.norm(
torch.stack([
p.grad.norm(p=2).to(shared_device)
for group in self.param_groups for p in group["params"]
if p.grad is not None]), p=2)
return norm
|