import torch


class MaxFactor(torch.optim.Optimizer):
    def __init__(self, params, lr=0.01, beta2_decay=-0.8, eps=(1e-10, 1e-3), d=1.0,
                 weight_decay=0.01, gamma=0.99, max=False):
        defaults = dict(lr=lr, beta2_decay=beta2_decay, eps=eps, d=d, weight_decay=weight_decay,
                        gamma=gamma, max=max)
        super().__init__(params=params, defaults=defaults)

    @staticmethod
    def _rms(tensor):
        # Root-mean-square of a tensor: ||t||_2 / sqrt(numel).
        return tensor.norm() / (tensor.numel() ** 0.5)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            params_with_grad, grads, row_vars, col_vars, v, state_steps = [], [], [], [], [], []
            eps1, eps2 = group["eps"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()
                state = self.state[p]
                if len(state) == 0:
                    state["step"] = torch.tensor(0.0, dtype=torch.float32)
                    if p.dim() > 1:
                        # Factored second-moment state (Adafactor-style): one
                        # row vector and one column vector per matrix.
                        row_shape, col_shape = list(p.shape), list(p.shape)
                        row_shape[-1], col_shape[-2] = 1, 1
                        state["row_var"], state["col_var"] = p.new_zeros(row_shape), p.new_zeros(col_shape)
                    state["v"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state["RMS"] = self._rms(p).item()
                row_vars.append(state.get("row_var", None))
                col_vars.append(state.get("col_var", None))
                v.append(state["v"])
                state_steps.append(state["step"])
                params_with_grad.append(p)
                grads.append(grad)
            for i, param in enumerate(params_with_grad):
                grad = grads[i]
                state = self.state[param]
                if group["max"]:
                    # Maximize instead of minimize the objective.
                    grad = -grad
                step_t, row_var, col_var, vi = state_steps[i], row_vars[i], col_vars[i], v[i]
                if eps1 is None:
                    eps1 = torch.finfo(param.dtype).eps
                step_t += 1
                step_float = step_t.item()
                one_minus_beta2_t = step_float ** group["beta2_decay"]
                state["RMS"] = self._rms(param).item()
                # Relative step size: decays like 1/sqrt(t), capped at lr,
                # then scaled by the parameter's RMS (floored at eps2).
                rho_t = min(group["lr"], 1 / (step_float ** 0.5))
                alpha = max(eps2, state["RMS"]) * rho_t
                if group["weight_decay"] != 0:
                    # Decoupled (AdamW-style) weight decay.
                    param.mul_(1 - group["lr"] * group["weight_decay"])
                if grad.dim() > 1:
                    # Factored second-moment estimate: per-row and per-column
                    # mean squared gradients, combined by outer product.
                    row_mean = torch.norm(grad, dim=-1, keepdim=True).square_().div_(grad.size(-1))
                    row_var.lerp_(row_mean, one_minus_beta2_t)
                    col_mean = torch.norm(grad, dim=-2, keepdim=True).square_().div_(grad.size(-2))
                    col_var.lerp_(col_mean, one_minus_beta2_t)
                    var_estimate = row_var @ col_var
                    # Normalize by the max row statistic (the "max" in MaxFactor).
                    max_row_var = row_var.max(dim=-2, keepdim=True)[0]
                    var_estimate.div_(max_row_var.clamp_(min=eps1))
                else:
                    # Unfactored EMA of squared gradients for 1-D parameters.
                    vi.mul_(group["gamma"]).add_(grad ** 2, alpha=1 - group["gamma"])
                    var_estimate = vi
                update = var_estimate.clamp_(min=eps1 * eps1).rsqrt_().mul_(grad)
                # Normalize by the infinity norm, then cap the step's RMS at d
                # (trust-ratio style clipping).
                update.div_(torch.norm(update, float('inf')).clamp_(min=eps1))
                denom = max(1.0, update.norm(2).item() / ((update.numel() ** 0.5) * group["d"]))
                if param.dim() > 1:
                    # Sign-based step scaled by each row's max magnitude.
                    param.add_(-alpha / denom * update.sign() * update.abs().max(dim=-1, keepdim=True)[0])
                else:
                    param.add_(-alpha / denom * update)
        return loss

# Commented-out variant (v1.0), kept for reference: it adds a min_lr floor,
# learning-rate introspection via get_lr()/get_last_lr(), and a clamped
# beta2 schedule.
#
# class MaxFactor(torch.optim.Optimizer):
#     __version__ = "1.0"
#
#     def __init__(self, params, lr=0.025, beta2_decay=-0.8, eps=(1e-10, 1e-4), d=1.0,
#                  weight_decay=0.025, gamma=0.99, max=False, min_lr=1e-7):
#         print(f"Using MaxFactor optimizer v{self.__version__}")
#         defaults = dict(lr=lr, beta2_decay=beta2_decay, eps=eps, d=d, weight_decay=weight_decay,
#                         gamma=gamma, max=max, min_lr=min_lr)
#         super().__init__(params=params, defaults=defaults)
#
#     def get_lr(self):
#         """Return current learning rates for all parameter groups."""
#         param_specific_lrs = []
#         for group in self.param_groups:
#             group_lrs = []
#             min_lr = group.get("min_lr", 1e-7)
#             eps1, eps2 = group["eps"]
#             for p in group["params"]:
#                 if p.grad is None:
#                     continue
#                 state = self.state[p]
#                 if "step" not in state:
#                     continue
#                 step_float = state["step"].item()
#                 # Base learning rate (same schedule as in step()).
#                 rho_t = max(min_lr, min(group["lr"], 1.0 / (step_float ** 0.5)))
#                 # Parameter-specific scaling.
#                 param_norm = (p.norm() / (p.numel() ** 0.5 + 1e-12)).item()
#                 alpha = max(eps2, param_norm) * rho_t
#                 group_lrs.append(alpha)
#             if group_lrs:
#                 param_specific_lrs.append(sum(group_lrs) / len(group_lrs))
#             else:
#                 param_specific_lrs.append(group["lr"])
#         return param_specific_lrs
#
#     def get_last_lr(self):
#         return self.get_lr()
#
#     @torch.no_grad()
#     def step(self, closure=None):
#         loss = None
#         if closure is not None:
#             with torch.enable_grad():
#                 loss = closure()
#         for group in self.param_groups:
#             params_with_grad, grads, row_vars, col_vars, v, state_steps = [], [], [], [], [], []
#             eps1, eps2 = group["eps"]
#             min_lr = group.get("min_lr", 1e-7)
#             for p in group["params"]:
#                 if p.grad is None:
#                     continue
#                 grad = p.grad
#                 if grad.dtype in {torch.float16, torch.bfloat16}:
#                     grad = grad.float()
#                 state = self.state[p]
#                 if len(state) == 0:
#                     state["step"] = torch.tensor(0.0, dtype=torch.float32)
#                     if p.dim() > 1:
#                         row_shape, col_shape = list(p.shape), list(p.shape)
#                         row_shape[-1], col_shape[-2] = 1, 1
#                         state["row_var"] = p.new_zeros(row_shape)
#                         state["col_var"] = p.new_zeros(col_shape)
#                     state["v"] = torch.zeros_like(p, memory_format=torch.preserve_format)
#                 row_vars.append(state.get("row_var", None))
#                 col_vars.append(state.get("col_var", None))
#                 v.append(state["v"])
#                 state_steps.append(state["step"])
#                 params_with_grad.append(p)
#                 grads.append(grad)
#             for i, param in enumerate(params_with_grad):
#                 grad = grads[i]
#                 state = self.state[param]
#                 if group["max"]:
#                     grad = -grad
#                 step_t = state_steps[i]
#                 row_var, col_var, vi = row_vars[i], col_vars[i], v[i]
#                 if eps1 is None:
#                     eps1 = torch.finfo(param.dtype).eps
#                 step_t += 1
#                 step_float = step_t.item()
#                 one_minus_beta2_t = min(0.999, max(0.001, step_float ** group["beta2_decay"]))
#                 rho_t = max(min_lr, min(group["lr"], 1.0 / (step_float ** 0.5)))
#                 alpha = max(eps2, (param.norm() / (param.numel() ** 0.5 + 1e-12)).item()) * rho_t
#                 if group["weight_decay"] > 0:
#                     param.mul_(1 - group["lr"] * group["weight_decay"])
#                 if grad.dim() > 1:
#                     row_mean = torch.norm(grad, dim=-1, keepdim=True).square_()
#                     row_mean.div_(grad.size(-1) + eps1)
#                     row_var.lerp_(row_mean, one_minus_beta2_t)
#                     col_mean = torch.norm(grad, dim=-2, keepdim=True).square_()
#                     col_mean.div_(grad.size(-2) + eps1)
#                     col_var.lerp_(col_mean, one_minus_beta2_t)
#                     var_estimate = row_var @ col_var
#                     max_row_var = row_var.max(dim=-2, keepdim=True)[0]
#                     var_estimate.div_(max_row_var.clamp_(min=eps1))
#                 else:
#                     vi.mul_(group["gamma"]).add_(grad.square_(), alpha=1 - group["gamma"])
#                     var_estimate = vi
#                 update = var_estimate.clamp_(min=eps1 * eps1).rsqrt_().mul_(grad)
#                 inf_norm = torch.norm(update, float('inf'))
#                 if inf_norm > 0:
#                     update.div_(inf_norm.clamp_(min=eps1))
#                 denom = max(1.0, update.norm(2).item() / ((update.numel() ** 0.5) * group["d"]))
#                 if param.dim() > 1:
#                     max_vals = update.abs().max(dim=-1, keepdim=True)[0]
#                     param.add_(-alpha / denom * update.sign() * max_vals)
#                 else:
#                     param.add_(-alpha / denom * update)
#                 state["step"] = step_t
#         return loss
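

if __name__ == "__main__":
    # Minimal usage sketch / smoke test. The toy model, data, and step count
    # below are illustrative assumptions, not part of the optimizer itself.
    model = torch.nn.Linear(10, 2)
    optimizer = MaxFactor(model.parameters(), lr=0.01, weight_decay=0.01)
    for _ in range(10):
        optimizer.zero_grad()
        loss = model(torch.randn(8, 10)).pow(2).mean()
        loss.backward()
        optimizer.step()
    print(f"final loss: {loss.item():.4f}")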