Upload opimizer.py
opimizer.py +229 -0
opimizer.py
ADDED
@@ -0,0 +1,229 @@
import torch


class MaxFactor(torch.optim.Optimizer):
    """Adafactor-inspired optimizer: factored row/column second-moment estimates for
    matrix-shaped parameters, an EMA of squared gradients for 1-D parameters, and
    relative step sizes scaled by each parameter's RMS norm."""

    def __init__(self, params, lr=0.01, beta2_decay=-0.8, eps=(1e-10, 1e-3), d=1.0,
                 weight_decay=0.01, gamma=0.99, max=False):
        defaults = dict(lr=lr, beta2_decay=beta2_decay, eps=eps, d=d,
                        weight_decay=weight_decay, gamma=gamma, max=max)
        super().__init__(params=params, defaults=defaults)

    @staticmethod
    def _rms(tensor):
        return tensor.norm() / (tensor.numel() ** 0.5)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad, grads, row_vars, col_vars, v, state_steps = [], [], [], [], [], []
            eps1, eps2 = group["eps"]  # eps1: numerical floor, eps2: minimum relative step scale

            # Collect gradients and lazily initialize per-parameter state.
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()

                state = self.state[p]
                if len(state) == 0:
                    state["step"] = torch.tensor(0.0, dtype=torch.float32)
                    if p.grad.dim() > 1:
                        # Factored second-moment state for matrix (and higher-rank) parameters.
                        row_shape, col_shape = list(p.grad.shape), list(p.grad.shape)
                        row_shape[-1], col_shape[-2] = 1, 1
                        state["row_var"], state["col_var"] = p.grad.new_zeros(row_shape), p.grad.new_zeros(col_shape)
                    state["v"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state["RMS"] = self._rms(p).item()

                row_vars.append(state.get("row_var", None))
                col_vars.append(state.get("col_var", None))
                v.append(state["v"])
                state_steps.append(state["step"])
                params_with_grad.append(p)
                grads.append(grad)

            for i, param in enumerate(params_with_grad):
                grad = grads[i]
                state = self.state[param]  # re-fetch this parameter's state so RMS is recorded for the right entry

                if group["max"]:
                    grad = -grad  # maximize instead of minimize
                step_t, row_var, col_var, vi = state_steps[i], row_vars[i], col_vars[i], v[i]

                if eps1 is None:
                    eps1 = torch.finfo(param.dtype).eps

                step_t += 1
                step_float = step_t.item()

                one_minus_beta2_t = step_float ** group["beta2_decay"]  # lerp weight 1 - beta2_t, Adafactor-style schedule
                state["RMS"] = self._rms(param).item()

                rho_t = min(group["lr"], 1 / (step_float ** 0.5))  # decaying step size, capped by lr
                alpha = max(eps2, param.norm(2).item() / (param.numel() ** 0.5)) * rho_t

                if group["weight_decay"] != 0:
                    # Decoupled weight decay.
                    param.mul_(1 - group["lr"] * group["weight_decay"])

                if grad.dim() > 1:
                    # Factored variance estimate from row and column mean-squared gradients.
                    row_mean = torch.norm(grad, dim=-1, keepdim=True).square_().div_(grad.size(-1) + 1e-8)
                    row_var.lerp_(row_mean, one_minus_beta2_t)
                    col_mean = torch.norm(grad, dim=-2, keepdim=True).square_().div_(grad.size(-2) + 1e-8)
                    col_var.lerp_(col_mean, one_minus_beta2_t)
                    var_estimate = row_var @ col_var
                    max_row_var = row_var.max(dim=-2, keepdim=True)[0]
                    var_estimate.div_(max_row_var.clamp_(min=eps1))
                else:
                    # Unfactored EMA of squared gradients for 1-D parameters.
                    vi.mul_(group["gamma"]).add_(grad ** 2, alpha=1 - group["gamma"])
                    var_estimate = vi

                # Precondition the gradient, normalize by its inf-norm, and clip the update RMS to d.
                update = var_estimate.clamp_(min=eps1 * eps1).rsqrt_().mul_(grad)
                update = update.div_(torch.norm(update, float('inf')).clamp_(min=eps1))
                denom = max(1.0, update.norm(2).item() / ((update.numel() ** 0.5) * group["d"]))

                # Sign-style update scaled by the per-row max magnitude and the relative step alpha.
                param.add_(-alpha / denom * update.sign() * update.abs().max(dim=-1, keepdim=True)[0])
        return loss
89 |
+
|
90 |
+
# class MaxFactor(torch.optim.Optimizer):
|
91 |
+
# __version__ = "1.0"
|
92 |
+
|
93 |
+
# def __init__(self, params, lr=0.025, beta2_decay=-0.8, eps=(1e-10, 1e-4), d=1.0,
|
94 |
+
# weight_decay=0.025, gamma=0.99, max=False, min_lr=1e-7):
|
95 |
+
|
96 |
+
# print(f"Using MaxFactor optimizer v{self.__version__}")
|
97 |
+
|
98 |
+
# defaults = dict(lr=lr, beta2_decay=beta2_decay, eps=eps, d=d, weight_decay=weight_decay,
|
99 |
+
# gamma=gamma, max=max, min_lr=min_lr)
|
100 |
+
# super().__init__(params=params, defaults=defaults)
|
101 |
+
|
102 |
+
# def get_lr(self):
|
103 |
+
# """Return current learning rates for all parameter groups."""
|
104 |
+
# param_specific_lrs = []
|
105 |
+
|
106 |
+
# for group in self.param_groups:
|
107 |
+
# group_lrs = []
|
108 |
+
# min_lr = group.get("min_lr", 1e-7)
|
109 |
+
# eps1, eps2 = group["eps"]
|
110 |
+
# for p in group["params"]:
|
111 |
+
# if p.grad is None:
|
112 |
+
# continue
|
113 |
+
# state = self.state[p]
|
114 |
+
# if "step" not in state:
|
115 |
+
# continue
|
116 |
+
# step_float = state["step"].item()
|
117 |
+
# # Calculate base learning rate (same as in step method)
|
118 |
+
# rho_t = max(min_lr, min(group["lr"], 1.0 / (step_float ** 0.5)))
|
119 |
+
|
120 |
+
# # Calculate parameter-specific scaling
|
121 |
+
# param_norm = (p.norm() / (p.numel() ** 0.5 + 1e-12)).item()
|
122 |
+
# alpha = max(eps2, param_norm) * rho_t
|
123 |
+
# group_lrs.append(alpha)
|
124 |
+
# if group_lrs:
|
125 |
+
# param_specific_lrs.append(sum(group_lrs) / len(group_lrs))
|
126 |
+
# else:
|
127 |
+
# param_specific_lrs.append(group["lr"])
|
128 |
+
# return param_specific_lrs
|
129 |
+
|
130 |
+
# def get_last_lr(self):
|
131 |
+
# return self.get_lr()
|
132 |
+
|
133 |
+
# @torch.no_grad()
|
134 |
+
# def step(self, closure=None):
|
135 |
+
# loss = None
|
136 |
+
# if closure is not None:
|
137 |
+
# with torch.enable_grad():
|
138 |
+
# loss = closure()
|
139 |
+
|
140 |
+
# for group in self.param_groups:
|
141 |
+
# params_with_grad, grads, row_vars, col_vars, v, state_steps = [], [], [], [], [], []
|
142 |
+
# eps1, eps2 = group["eps"]
|
143 |
+
# min_lr = group.get("min_lr", 1e-7)
|
144 |
+
|
145 |
+
# for p in group["params"]:
|
146 |
+
# if p.grad is None:
|
147 |
+
# continue
|
148 |
+
|
149 |
+
# grad = p.grad
|
150 |
+
# if grad.dtype in {torch.float16, torch.bfloat16}:
|
151 |
+
# grad = grad.float()
|
152 |
+
|
153 |
+
# state = self.state[p]
|
154 |
+
# if len(state) == 0:
|
155 |
+
# state["step"] = torch.tensor(0.0, dtype=torch.float32)
|
156 |
+
# if p.dim() > 1:
|
157 |
+
# row_shape, col_shape = list(p.shape), list(p.shape)
|
158 |
+
# row_shape[-1], col_shape[-2] = 1, 1
|
159 |
+
# state["row_var"] = p.new_zeros(row_shape)
|
160 |
+
# state["col_var"] = p.new_zeros(col_shape)
|
161 |
+
# state["v"] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
162 |
+
|
163 |
+
# row_vars.append(state.get("row_var", None))
|
164 |
+
# col_vars.append(state.get("col_var", None))
|
165 |
+
# v.append(state["v"])
|
166 |
+
# state_steps.append(state["step"])
|
167 |
+
# params_with_grad.append(p)
|
168 |
+
# grads.append(grad)
|
169 |
+
|
170 |
+
# for i, param in enumerate(params_with_grad):
|
171 |
+
# grad = grads[i]
|
172 |
+
# state = self.state[param]
|
173 |
+
|
174 |
+
# if group["max"]:
|
175 |
+
# grad = -grad
|
176 |
+
|
177 |
+
# step_t = state_steps[i]
|
178 |
+
# row_var, col_var, vi = row_vars[i], col_vars[i], v[i]
|
179 |
+
|
180 |
+
# if eps1 is None:
|
181 |
+
# eps1 = torch.finfo(param.dtype).eps
|
182 |
+
|
183 |
+
# step_t += 1
|
184 |
+
# step_float = step_t.item()
|
185 |
+
|
186 |
+
# one_minus_beta2_t = min(0.999, max(0.001, step_float ** group["beta2_decay"]))
|
187 |
+
|
188 |
+
# rho_t = max(min_lr, min(group["lr"], 1.0 / (step_float ** 0.5)))
|
189 |
+
# alpha = max(eps2, (param.norm() / (param.numel() ** 0.5 + 1e-12)).item()) * rho_t
|
190 |
+
|
191 |
+
# if group["weight_decay"] > 0:
|
192 |
+
# param.mul_(1 - group["lr"] * group["weight_decay"])
|
193 |
+
|
194 |
+
# if grad.dim() > 1:
|
195 |
+
# row_mean = torch.norm(grad, dim=-1, keepdim=True).square_()
|
196 |
+
# row_mean.div_(grad.size(-1) + eps1)
|
197 |
+
|
198 |
+
# row_var.lerp_(row_mean, one_minus_beta2_t)
|
199 |
+
|
200 |
+
# col_mean = torch.norm(grad, dim=-2, keepdim=True).square_()
|
201 |
+
# col_mean.div_(grad.size(-2) + eps1)
|
202 |
+
|
203 |
+
# col_var.lerp_(col_mean, one_minus_beta2_t)
|
204 |
+
|
205 |
+
# var_estimate = row_var @ col_var
|
206 |
+
# max_row_var = row_var.max(dim=-2, keepdim=True)[0]
|
207 |
+
# var_estimate.div_(max_row_var.clamp_(min=eps1))
|
208 |
+
# else:
|
209 |
+
# vi.mul_(group["gamma"]).add_(grad.square_(), alpha=1 - group["gamma"])
|
210 |
+
# var_estimate = vi
|
211 |
+
|
212 |
+
# update = var_estimate.clamp_(min=eps1 * eps1).rsqrt_().mul_(grad)
|
213 |
+
|
214 |
+
# inf_norm = torch.norm(update, float('inf'))
|
215 |
+
# if inf_norm > 0:
|
216 |
+
# update.div_(inf_norm.clamp_(min=eps1))
|
217 |
+
|
218 |
+
# denom = max(1.0, update.norm(2).item() / ((update.numel() ** 0.5) * group["d"]))
|
219 |
+
|
220 |
+
# if param.dim() > 1:
|
221 |
+
# max_vals = update.abs().max(dim=-1, keepdim=True)[0]
|
222 |
+
# param.add_(-alpha / denom * update.sign() * max_vals)
|
223 |
+
# else:
|
224 |
+
# param.add_(-alpha / denom * update)
|
225 |
+
|
226 |
+
# state["step"] = step_t
|
227 |
+
|
228 |
+
# return loss
|
229 |
+
|
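
A minimal usage sketch of the active MaxFactor class above. The toy model, loss, data, and training-loop details are illustrative assumptions, not part of this repository; only the constructor arguments shown (lr, weight_decay, and the other defaults) come from the file itself.

import torch
import torch.nn as nn

from opimizer import MaxFactor  # module name matches the uploaded file

# Hypothetical toy setup: any PyTorch module works; MaxFactor is used like any torch.optim optimizer.
model = nn.Linear(16, 4)
criterion = nn.MSELoss()
optimizer = MaxFactor(model.parameters(), lr=0.01, weight_decay=0.01)

for _ in range(100):
    inputs, targets = torch.randn(8, 16), torch.randn(8, 4)
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()  # factored update for the 2-D weight, EMA-of-squared-grads path for the 1-D bias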