Sin2pi committed (verified)
Commit 6067697 · 1 Parent(s): ebb8a2a

Upload opimizer.py

Files changed (1)
  1. opimizer.py +229 -0
opimizer.py ADDED
@@ -0,0 +1,229 @@
+ import torch
+
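+ # MaxFactor: optimizer that keeps factored (row/column) second-moment estimates for
+ # parameters with two or more dimensions and a dense EMA of squared gradients for vectors
+ # and scalars; updates are normalized by their infinity norm and RMS-clipped via d.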
+ class MaxFactor(torch.optim.Optimizer):
+     def __init__(self, params, lr=0.01, beta2_decay=-0.8, eps=(1e-10, 1e-3), d=1.0,
+                  weight_decay=0.01, gamma=0.99, max=False):
+
+         defaults = dict(lr=lr, beta2_decay=beta2_decay, eps=eps, d=d, weight_decay=weight_decay,
+                         gamma=gamma, max=max)
+         super().__init__(params=params, defaults=defaults)
+
+     @staticmethod
+     def _rms(tensor):
+         return tensor.norm() / (tensor.numel() ** 0.5)
+
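+     # step(): one optimization step. The first pass over group["params"] lazily initializes
+     # per-parameter state (step counter, row/col variances for matrices, the EMA buffer "v",
+     # and an RMS snapshot); the second pass computes and applies the update per parameter.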
+     @torch.no_grad()
+     def step(self, closure=None):
+         loss = None
+         if closure is not None:
+             with torch.enable_grad():
+                 loss = closure()
+
+         for group in self.param_groups:
+             params_with_grad, grads, row_vars, col_vars, v, state_steps = [], [], [], [], [], []
+             eps1, eps2 = group["eps"]
+             for p in group["params"]:
+                 if p.grad is None:
+                     continue
+                 grad = p.grad
+                 if grad.dtype in {torch.float16, torch.bfloat16}:
+                     grad = grad.float()
+
+                 state = self.state[p]
+                 if len(state) == 0:
+                     state["step"] = torch.tensor(0.0, dtype=torch.float32)
+                     if p.grad.dim() > 1:
+                         row_shape, col_shape = list(p.grad.shape), list(p.grad.shape)
+                         row_shape[-1], col_shape[-2] = 1, 1
+                         state["row_var"], state["col_var"] = p.grad.new_zeros(row_shape), p.grad.new_zeros(col_shape)
+                     state["v"] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                     state["RMS"] = self._rms(p).item()
+
+                 row_vars.append(state.get("row_var", None))
+                 col_vars.append(state.get("col_var", None))
+                 v.append(state["v"])
+                 state_steps.append(state["step"])
+                 params_with_grad.append(p)
+                 grads.append(grad)
+
+             for i, param in enumerate(params_with_grad):
+                 grad = grads[i]
+
+                 if group["max"]:
+                     grad = -grad
+                 step_t, row_var, col_var, vi = state_steps[i], row_vars[i], col_vars[i], v[i]
+
+                 if eps1 is None:
+                     eps1 = torch.finfo(param.dtype).eps
+
+                 step_t += 1
+                 step_float = step_t.item()
+
+                 one_minus_beta2_t = step_float ** group["beta2_decay"]
+                 self.state[param]["RMS"] = self._rms(param).item()
+
+                 rho_t = min(group["lr"], 1 / (step_float ** 0.5))
+                 alpha = max(eps2, param.norm(2).item() / (param.numel() ** 0.5)) * rho_t
+
+                 if group["weight_decay"] != 0:
+                     param.mul_(1 - group["lr"] * group["weight_decay"])
+
+                 if grad.dim() > 1:
+                     # Factored second moment: EMAs of row/column mean squared gradients,
+                     # combined by outer product.
+                     row_mean = torch.norm(grad, dim=-1, keepdim=True).square_().div_(grad.size(-1) + 1e-8)
+                     row_var.lerp_(row_mean, one_minus_beta2_t)
+                     col_mean = torch.norm(grad, dim=-2, keepdim=True).square_().div_(grad.size(-2) + 1e-8)
+                     col_var.lerp_(col_mean, one_minus_beta2_t)
+                     var_estimate = row_var @ col_var
+                     max_row_var = row_var.max(dim=-2, keepdim=True)[0]
+                     var_estimate.div_(max_row_var.clamp_(min=eps1))
+                 else:
+                     vi.mul_(group["gamma"]).add_(grad ** 2, alpha=1 - group["gamma"])
+                     var_estimate = vi
+
+                 # clamp out of place so the stored EMA buffer state["v"] is not overwritten by rsqrt_
+                 update = var_estimate.clamp(min=eps1 * eps1).rsqrt_().mul_(grad)
+                 update = update.div_(torch.norm(update, float('inf')).clamp_(min=eps1))
+                 denom = max(1.0, update.norm(2).item() / ((update.numel() ** 0.5) * group["d"]))
+
+                 param.add_(-alpha / denom * update.sign() * update.abs().max(dim=-1, keepdim=True)[0])
+         return loss
+
+ # class MaxFactor(torch.optim.Optimizer):
+ #     __version__ = "1.0"
+
+ #     def __init__(self, params, lr=0.025, beta2_decay=-0.8, eps=(1e-10, 1e-4), d=1.0,
+ #                  weight_decay=0.025, gamma=0.99, max=False, min_lr=1e-7):
+
+ #         print(f"Using MaxFactor optimizer v{self.__version__}")
+
+ #         defaults = dict(lr=lr, beta2_decay=beta2_decay, eps=eps, d=d, weight_decay=weight_decay,
+ #                         gamma=gamma, max=max, min_lr=min_lr)
+ #         super().__init__(params=params, defaults=defaults)
+
+ #     def get_lr(self):
+ #         """Return current learning rates for all parameter groups."""
+ #         param_specific_lrs = []
+
+ #         for group in self.param_groups:
+ #             group_lrs = []
+ #             min_lr = group.get("min_lr", 1e-7)
+ #             eps1, eps2 = group["eps"]
+ #             for p in group["params"]:
+ #                 if p.grad is None:
+ #                     continue
+ #                 state = self.state[p]
+ #                 if "step" not in state:
+ #                     continue
+ #                 step_float = state["step"].item()
+ #                 # Calculate base learning rate (same as in step method)
+ #                 rho_t = max(min_lr, min(group["lr"], 1.0 / (step_float ** 0.5)))
+
+ #                 # Calculate parameter-specific scaling
+ #                 param_norm = (p.norm() / (p.numel() ** 0.5 + 1e-12)).item()
+ #                 alpha = max(eps2, param_norm) * rho_t
+ #                 group_lrs.append(alpha)
+ #             if group_lrs:
+ #                 param_specific_lrs.append(sum(group_lrs) / len(group_lrs))
+ #             else:
+ #                 param_specific_lrs.append(group["lr"])
+ #         return param_specific_lrs
+
+ #     def get_last_lr(self):
+ #         return self.get_lr()
+
+ #     @torch.no_grad()
+ #     def step(self, closure=None):
+ #         loss = None
+ #         if closure is not None:
+ #             with torch.enable_grad():
+ #                 loss = closure()
+
+ #         for group in self.param_groups:
+ #             params_with_grad, grads, row_vars, col_vars, v, state_steps = [], [], [], [], [], []
+ #             eps1, eps2 = group["eps"]
+ #             min_lr = group.get("min_lr", 1e-7)
+
+ #             for p in group["params"]:
+ #                 if p.grad is None:
+ #                     continue
+
+ #                 grad = p.grad
+ #                 if grad.dtype in {torch.float16, torch.bfloat16}:
+ #                     grad = grad.float()
+
+ #                 state = self.state[p]
+ #                 if len(state) == 0:
+ #                     state["step"] = torch.tensor(0.0, dtype=torch.float32)
+ #                     if p.dim() > 1:
+ #                         row_shape, col_shape = list(p.shape), list(p.shape)
+ #                         row_shape[-1], col_shape[-2] = 1, 1
+ #                         state["row_var"] = p.new_zeros(row_shape)
+ #                         state["col_var"] = p.new_zeros(col_shape)
+ #                     state["v"] = torch.zeros_like(p, memory_format=torch.preserve_format)
+
+ #                 row_vars.append(state.get("row_var", None))
+ #                 col_vars.append(state.get("col_var", None))
+ #                 v.append(state["v"])
+ #                 state_steps.append(state["step"])
+ #                 params_with_grad.append(p)
+ #                 grads.append(grad)
+
+ #             for i, param in enumerate(params_with_grad):
+ #                 grad = grads[i]
+ #                 state = self.state[param]
+
+ #                 if group["max"]:
+ #                     grad = -grad
+
+ #                 step_t = state_steps[i]
+ #                 row_var, col_var, vi = row_vars[i], col_vars[i], v[i]
+
+ #                 if eps1 is None:
+ #                     eps1 = torch.finfo(param.dtype).eps
+
+ #                 step_t += 1
+ #                 step_float = step_t.item()
+
+ #                 one_minus_beta2_t = min(0.999, max(0.001, step_float ** group["beta2_decay"]))
+
+ #                 rho_t = max(min_lr, min(group["lr"], 1.0 / (step_float ** 0.5)))
+ #                 alpha = max(eps2, (param.norm() / (param.numel() ** 0.5 + 1e-12)).item()) * rho_t
+
+ #                 if group["weight_decay"] > 0:
+ #                     param.mul_(1 - group["lr"] * group["weight_decay"])
+
+ #                 if grad.dim() > 1:
+ #                     row_mean = torch.norm(grad, dim=-1, keepdim=True).square_()
+ #                     row_mean.div_(grad.size(-1) + eps1)
+
+ #                     row_var.lerp_(row_mean, one_minus_beta2_t)
+
+ #                     col_mean = torch.norm(grad, dim=-2, keepdim=True).square_()
+ #                     col_mean.div_(grad.size(-2) + eps1)
+
+ #                     col_var.lerp_(col_mean, one_minus_beta2_t)
+
+ #                     var_estimate = row_var @ col_var
+ #                     max_row_var = row_var.max(dim=-2, keepdim=True)[0]
+ #                     var_estimate.div_(max_row_var.clamp_(min=eps1))
+ #                 else:
+ #                     vi.mul_(group["gamma"]).add_(grad.square_(), alpha=1 - group["gamma"])
+ #                     var_estimate = vi
+
+ #                 update = var_estimate.clamp_(min=eps1 * eps1).rsqrt_().mul_(grad)
+
+ #                 inf_norm = torch.norm(update, float('inf'))
+ #                 if inf_norm > 0:
+ #                     update.div_(inf_norm.clamp_(min=eps1))
+
+ #                 denom = max(1.0, update.norm(2).item() / ((update.numel() ** 0.5) * group["d"]))
+
+ #                 if param.dim() > 1:
+ #                     max_vals = update.abs().max(dim=-1, keepdim=True)[0]
+ #                     param.add_(-alpha / denom * update.sign() * max_vals)
+ #                 else:
+ #                     param.add_(-alpha / denom * update)
+
+ #                 state["step"] = step_t
+
+ #         return loss
+
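A minimal usage sketch, assuming opimizer.py is importable as a module: MaxFactor subclasses torch.optim.Optimizer, so it drops into a standard training loop. The model and batch below are illustrative placeholders; the constructor arguments shown (lr, weight_decay) follow the defaults in the file.

import torch
from opimizer import MaxFactor

model = torch.nn.Linear(16, 4)   # placeholder model: the 2-D weight exercises the factored path, the 1-D bias the EMA path
optimizer = MaxFactor(model.parameters(), lr=0.01, weight_decay=0.01)

x, y = torch.randn(8, 16), torch.randn(8, 4)   # placeholder batch
for _ in range(10):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()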