File size: 4,964 Bytes
edcf5ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""

版本: 8月26日 17:00

SAM范化性训练,避免过拟合的优化器优化工具  ICLR 2021 spotlight paper by Google



介绍:https://mp.weixin.qq.com/s/04VT-ldd0-XEkhEW6Txl_A

第三方实现来自:https://pub.towardsai.net/we-dont-need-to-worry-about-overfitting-anymore-9fb31a154c81



论文:Sharpness-aware Minimization for Efficiently Improving Generalization

链接:https://arxiv.org/abs/2010.01412



计算原理:

在训练过程中,优化器更新模型参数w时,整体上可以分为四个步骤:



(1)基于参数 w 对 batch data S 计算 gradient G 。



(2)求解 G 的 dual norm,依照 dual vector 方向更新参数,得到 w+ε体系下的辅助模型。



(3)基于参数 w+ε 下的辅助模型,对 S 计算 gradient G’ 。



(4)用 G’ 更新原本的模型的参数 w 。





使用例子:

from sam import SAM

...

model = YourModel()

base_optimizer = torch.optim.SGD  # 传入一个优化器模板

optimizer = SAM(model.parameters(), base_optimizer, lr=0.1, momentum=0.9)  # 优化器参数

...

for input, output in data:



  # first forward-backward pass,计算第一轮loss,这个和普通的一样

  # 第一轮的loss是真实模型跑出来的,我们统计中需要的loss就是它,第二轮loss不是真实的模型的loss(是辅助模型的),所以不需要用在传统统计loss中

  output = model(input)

  loss = loss_function(output, labels)  # use this loss for any training statistics!!!!

  loss.backward()  # 模型反向传播,记录原梯度。



  # step1 的SAM类计算了“SAM梯度”

  optimizer.first_step(zero_grad=True)  # 第一轮opt用“SAM梯度”对原模型参数体系进行了更新,现在模型变成了辅助模型,

  # step1记录保存了回到原模型参数体系的变化方法



  # second forward-backward pass  第二轮先对辅助模型(step1更新后的模型)正向、反向传播

  output2 = model(input)  # 用output2 确保计算图是辅助模型(即step1更新后的模型),不然有一堆bug。



  # 由于新增了计算图,因此计算时间增加显存占用也增加?



  loss_function(output2, labels).backward()  # make sure to do a full forward pass 辅助模型反向传播,记录更新梯度

  optimizer.second_step(zero_grad=True)  # 第二轮,先原模型参数替换回去,之后base opt以辅助模型的更新方向对原模型参数体系进行更新

...





"""
import torch


class SAM(torch.optim.Optimizer):
    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, **kwargs)
        super(SAM, self).__init__(params, defaults)

        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad=False):  # step1 生成辅助模型,对原模型参数进行修改把他变成辅助模型,同时记录怎么变的,以便还原
        grad_norm = self._grad_norm()

        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)  # 附近的梯度影响

            for p in group["params"]:
                if p.grad is None:
                    continue
                e_w = p.grad * scale.to(p)  # 考虑附近的梯度影响之后,确定辅助模型的参数变化需要的“SAM梯度”
                p.add_(e_w)  # climb to the local maximum "w + e(w)"  inplace参数更新! 因此是 原模型 变成了 辅助模型
                self.state[p]["e_w"] = e_w

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):  # step2 先对辅助模型参数进行修改把他变回原模型,
        # 之后对原模型基于辅助模型的梯度用base_optimizer进行参数更新

        for group in self.param_groups:

            for p in group["params"]:
                if p.grad is None:
                    continue
                p.sub_(self.state[p]["e_w"])  # get back to "w" from "w + e(w)"
                # 辅助模型参数还原,回到原模型 get back to "w" from "w + e(w)",注意这个也是inplace的!!

        self.base_optimizer.step()  # 用base_optimizer对原模型进行参数更新 do the actual "sharpness-aware" update

        if zero_grad:
            self.zero_grad()

    def _grad_norm(self):
        shared_device = self.param_groups[0]["params"][0].device
        # put everything on the same device, in case of model parallelism
        norm = torch.norm(
            torch.stack([
                p.grad.norm(p=2).to(shared_device)
                for group in self.param_groups for p in group["params"]
                if p.grad is not None]), p=2)
        return norm