# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
import math
from typing import Callable, Optional, Tuple
import numpy as np
import torch
from came_pytorch import CAME
from mmcv import Config
from mmcv.runner import OPTIMIZER_BUILDERS, OPTIMIZERS, DefaultOptimizerConstructor
from mmcv.runner import build_optimizer as mm_build_optimizer
from mmcv.utils import _BatchNorm, _InstanceNorm
from torch.nn import GroupNorm, LayerNorm
from torch.optim.optimizer import Optimizer
from .logger import get_root_logger
def auto_scale_lr(effective_bs, optimizer_cfg, rule="linear", base_batch_size=256):
assert rule in ["linear", "sqrt"]
logger = get_root_logger()
    # scale lr by the ratio of effective batch size to base batch size
if rule == "sqrt":
scale_ratio = math.sqrt(effective_bs / base_batch_size)
elif rule == "linear":
scale_ratio = effective_bs / base_batch_size
optimizer_cfg["lr"] *= scale_ratio
logger.info(f'Automatically adapt lr to {optimizer_cfg["lr"]:.5f} (using {rule} scaling rule).')
return scale_ratio
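# Illustrative example (config values assumed): with the default base_batch_size=256 and
# an effective batch size of 1024, rule="linear" multiplies lr by 4.0 and rule="sqrt" by 2.0:
#   cfg = dict(type="AdamW", lr=1e-4)
#   auto_scale_lr(1024, cfg)  # linear rule: cfg["lr"] becomes 4e-4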
@OPTIMIZER_BUILDERS.register_module()
class MyOptimizerConstructor(DefaultOptimizerConstructor):
def add_params(self, params, module, prefix="", is_dcn_module=None):
"""Add all parameters of module to the params list.
The parameters of the given module will be added to the list of param
groups, with specific rules defined by paramwise_cfg.
Args:
params (list[dict]): A list of param groups, it will be modified
in place.
module (nn.Module): The module to be added.
            prefix (str): The prefix of the module.
            is_dcn_module (int|float|None): If the current module is a
                submodule of DCN, `is_dcn_module` will be passed to
                control the conv_offset layer's learning rate. Defaults to None.
        """
# get param-wise options
custom_keys = self.paramwise_cfg.get("custom_keys", {})
# first sort with alphabet order and then sort with reversed len of str
# sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True)
bias_lr_mult = self.paramwise_cfg.get("bias_lr_mult", 1.0)
bias_decay_mult = self.paramwise_cfg.get("bias_decay_mult", 1.0)
norm_decay_mult = self.paramwise_cfg.get("norm_decay_mult", 1.0)
bypass_duplicate = self.paramwise_cfg.get("bypass_duplicate", False)
# special rules for norm layers and depth-wise conv layers
is_norm = isinstance(module, (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm))
for name, param in module.named_parameters(recurse=False):
base_lr = self.base_lr
if name == "bias" and not (is_norm or is_dcn_module):
base_lr *= bias_lr_mult
# apply weight decay policies
base_wd = self.base_wd
if self.base_wd is not None:
# norm decay
if is_norm:
base_wd *= norm_decay_mult
# bias lr and decay
elif name == "bias" and not is_dcn_module:
                    # TODO: current bias_decay_mult will have an effect on DCN
base_wd *= bias_decay_mult
param_group = {"params": [param]}
if not param.requires_grad:
param_group["requires_grad"] = False
params.append(param_group)
continue
if bypass_duplicate and self._is_in(param_group, params):
logger = get_root_logger()
                logger.warning(f"{prefix} is duplicate. It is skipped since bypass_duplicate={bypass_duplicate}")
continue
# if the parameter match one of the custom keys, ignore other rules
is_custom = False
for key in custom_keys:
if isinstance(key, tuple):
scope, key_name = key
else:
scope, key_name = None, key
if scope is not None and scope not in f"{prefix}":
continue
if key_name in f"{prefix}.{name}":
is_custom = True
if "lr_mult" in custom_keys[key]:
# if 'base_classes' in f'{prefix}.{name}' or 'attn_base' in f'{prefix}.{name}':
# param_group['lr'] = self.base_lr
# else:
param_group["lr"] = self.base_lr * custom_keys[key]["lr_mult"]
elif "lr" not in param_group:
param_group["lr"] = base_lr
if self.base_wd is not None:
if "decay_mult" in custom_keys[key]:
param_group["weight_decay"] = self.base_wd * custom_keys[key]["decay_mult"]
elif "weight_decay" not in param_group:
param_group["weight_decay"] = base_wd
if not is_custom:
# bias_lr_mult affects all bias parameters
# except for norm.bias dcn.conv_offset.bias
if base_lr != self.base_lr:
param_group["lr"] = base_lr
if base_wd != self.base_wd:
param_group["weight_decay"] = base_wd
params.append(param_group)
for child_name, child_mod in module.named_children():
child_prefix = f"{prefix}.{child_name}" if prefix else child_name
self.add_params(params, child_mod, prefix=child_prefix, is_dcn_module=is_dcn_module)
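# Note on custom_keys matching in MyOptimizerConstructor.add_params above: keys are matched
# by substring against the qualified parameter name f"{prefix}.{name}", and a (scope, key)
# tuple additionally requires `scope` to occur in the module prefix. A hypothetical
# paramwise_cfg illustrating both forms:
#   paramwise_cfg = dict(custom_keys={
#       "pos_embed": dict(decay_mult=0.0),        # no weight decay for params matching "pos_embed"
#       ("backbone", "norm"): dict(lr_mult=0.1),  # 0.1x lr for "norm" params under a "backbone" prefix
#   })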
def build_optimizer(model, optimizer_cfg):
# default parameter-wise config
logger = get_root_logger()
if hasattr(model, "module"):
model = model.module
# set optimizer constructor
optimizer_cfg.setdefault("constructor", "MyOptimizerConstructor")
# parameter-wise setting: cancel weight decay for some specific modules
custom_keys = dict()
for name, module in model.named_modules():
if hasattr(module, "zero_weight_decay"):
custom_keys.update({(name, key): dict(decay_mult=0) for key in module.zero_weight_decay})
paramwise_cfg = Config(dict(cfg=dict(custom_keys=custom_keys)))
given_cfg = optimizer_cfg.get("paramwise_cfg")
if given_cfg:
paramwise_cfg.merge_from_dict(dict(cfg=given_cfg))
optimizer_cfg["paramwise_cfg"] = paramwise_cfg.cfg
# build optimizer
optimizer = mm_build_optimizer(model, optimizer_cfg)
weight_decay_groups = dict()
lr_groups = dict()
for group in optimizer.param_groups:
if not group.get("requires_grad", True):
continue
lr_groups.setdefault(group["lr"], []).append(group)
weight_decay_groups.setdefault(group["weight_decay"], []).append(group)
learnable_count, fix_count = 0, 0
for p in model.parameters():
if p.requires_grad:
learnable_count += 1
else:
fix_count += 1
    fix_info = f"{learnable_count} are learnable, {fix_count} are fixed"
lr_info = "Lr group: " + ", ".join([f"{len(group)} params with lr {lr:.5f}" for lr, group in lr_groups.items()])
wd_info = "Weight decay group: " + ", ".join(
[f"{len(group)} params with weight decay {wd}" for wd, group in weight_decay_groups.items()]
)
opt_info = f"{optimizer.__class__.__name__} Optimizer: total {len(optimizer.param_groups)} param groups, {fix_info}. {lr_info}; {wd_info}."
logger.info(opt_info)
return optimizer
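# Typical call (config values are illustrative only):
#   optimizer = build_optimizer(model, dict(type="CAMEWrapper", lr=1e-4, weight_decay=0.03))
# Modules may also expose a `zero_weight_decay` attribute (an iterable of key names);
# build_optimizer turns each entry into a (module_name, key) custom rule with decay_mult=0.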
@OPTIMIZERS.register_module()
class Lion(Optimizer):
def __init__(
self,
params,
lr: float = 1e-4,
betas: Tuple[float, float] = (0.9, 0.99),
weight_decay: float = 0.0,
):
assert lr > 0.0
assert all([0.0 <= beta <= 1.0 for beta in betas])
defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
super().__init__(params, defaults)
@staticmethod
def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2):
# stepweight decay
p.data.mul_(1 - lr * wd)
# weight update
update = exp_avg.clone().lerp_(grad, 1 - beta1).sign_()
p.add_(update, alpha=-lr)
# decay the momentum running average coefficient
exp_avg.lerp_(grad, 1 - beta2)
@staticmethod
def exists(val):
return val is not None
@torch.no_grad()
def step(self, closure: Optional[Callable] = None):
loss = None
if self.exists(closure):
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
for p in filter(lambda p: self.exists(p.grad), group["params"]):
grad, lr, wd, beta1, beta2, state = (
p.grad,
group["lr"],
group["weight_decay"],
*group["betas"],
self.state[p],
)
# init state - exponential moving average of gradient values
if len(state) == 0:
state["exp_avg"] = torch.zeros_like(p)
exp_avg = state["exp_avg"]
self.update_fn(p, grad, exp_avg, lr, wd, beta1, beta2)
return loss
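# Lion update per parameter, as implemented by update_fn/step above:
#   update   = sign(beta1 * exp_avg + (1 - beta1) * grad)
#   p       <- p * (1 - lr * weight_decay) - lr * update
#   exp_avg <- beta2 * exp_avg + (1 - beta2) * grad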
@OPTIMIZERS.register_module()
class CAMEWrapper(torch.optim.Optimizer):
"""Implements CAME algorithm.
This implementation is based on:
`CAME: Confidence-guided Adaptive Memory Efficient Optimization`
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
        lr (float): external learning rate; must be provided and positive
eps (tuple[float, float]): regularization constants for square gradient
and instability respectively (default: (1e-30, 1e-16))
clip_threshold (float): threshold of root-mean-square of
final gradient update (default: 1.0)
        betas (tuple[float, float, float]): coefficients used for computing running averages of
            update, square gradient and instability (default: (0.9, 0.999, 0.9999))
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
"""
def __init__(
self,
params,
lr=None,
eps=(1e-30, 1e-16),
clip_threshold=1.0,
betas=(0.9, 0.999, 0.9999),
weight_decay=0.0,
):
        assert lr is not None and lr > 0.0
assert all([0.0 <= beta <= 1.0 for beta in betas])
defaults = dict(
lr=lr,
eps=eps,
clip_threshold=clip_threshold,
betas=betas,
weight_decay=weight_decay,
)
super().__init__(params, defaults)
@property
def supports_memory_efficient_fp16(self):
return True
@property
def supports_flat_params(self):
return False
def _get_options(self, param_shape):
if len(param_shape) == 4: # Conv layer
if param_shape[2] == 1 and param_shape[3] == 1: # 1x1 conv
return True, "1x1_conv"
else: # 3x3 conv or others
return False, "conv"
elif len(param_shape) == 2: # Linear layer, exactly 2D
return True, "linear"
return False, "other"
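    # Shape -> (factored, layer_type) examples for _get_options (shapes are illustrative):
    #   (512, 512)       -> (True,  "linear")
    #   (256, 128, 1, 1) -> (True,  "1x1_conv")
    #   (64, 64, 3, 3)   -> (False, "conv")
    #   (768,)           -> (False, "other")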
def _rms(self, tensor):
return tensor.norm(2) / (tensor.numel() ** 0.5)
def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):
r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
return torch.mul(r_factor, c_factor)
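    # _approx_sq_grad reconstructs an Adafactor-style factored second moment: with
    #   V[i, j] ~= row[i] * col[j] / mean(row)
    # it returns the elementwise V^(-1/2), i.e. rsqrt(row / mean(row))[:, None] * rsqrt(col)[None, :].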
def step(self, closure=None):
"""Performs a single optimization step.
Args:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad.data
if grad.dtype in {torch.float16, torch.bfloat16}:
grad = grad.float()
if grad.is_sparse:
raise RuntimeError("CAME does not support sparse gradients.")
state = self.state[p]
grad_shape = grad.shape
# factored = self._get_options(grad_shape)
factored, layer_type = self._get_options(grad_shape)
# State Initialization
if len(state) == 0:
state["step"] = 0
state["exp_avg"] = torch.zeros_like(grad)
if factored:
if layer_type == "1x1_conv" or layer_type == "linear":
# 1x1 conv and linear layers can be handled the same way
state["exp_avg_sq_row"] = torch.zeros(grad_shape[0]).type_as(grad)
state["exp_avg_sq_col"] = torch.zeros(grad_shape[1]).type_as(grad)
state["exp_avg_res_row"] = torch.zeros(grad_shape[0]).type_as(grad)
state["exp_avg_res_col"] = torch.zeros(grad_shape[1]).type_as(grad)
else:
state["exp_avg_sq"] = torch.zeros_like(grad)
else:
state["exp_avg_sq"] = torch.zeros_like(grad)
state["RMS"] = 0
state["step"] += 1
state["RMS"] = self._rms(p.data)
update = (grad**2) + group["eps"][0]
if factored:
exp_avg_sq_row = state["exp_avg_sq_row"]
exp_avg_sq_col = state["exp_avg_sq_col"]
if layer_type == "1x1_conv" or layer_type == "linear":
# Handle dimensions
if len(grad_shape) == 4: # 1x1 conv
update_reshaped = update.squeeze(-1).squeeze(-1) # Remove last two dimensions
else:
update_reshaped = update
exp_avg_sq_row.mul_(group["betas"][1]).add_(
update_reshaped.mean(dim=1), alpha=1.0 - group["betas"][1]
)
exp_avg_sq_col.mul_(group["betas"][1]).add_(
update_reshaped.mean(dim=0), alpha=1.0 - group["betas"][1]
)
# Approximate calculation
update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
if layer_type == "1x1_conv":
# Need to reshape back to 4D
update = update.view(grad_shape[0], grad_shape[1], 1, 1)
update.mul_(grad)
else:
# 3x3 conv or other cases: use standard AdamW approach
exp_avg_sq = state["exp_avg_sq"]
exp_avg_sq.mul_(group["betas"][1]).add_(update, alpha=1.0 - group["betas"][1])
update = exp_avg_sq.rsqrt().mul_(grad)
update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
exp_avg = state["exp_avg"]
exp_avg.mul_(group["betas"][0]).add_(update, alpha=1 - group["betas"][0])
# Confidence-guided strategy
# Calculation of instability
res = (update - exp_avg) ** 2 + group["eps"][1]
if factored:
exp_avg_res_row = state["exp_avg_res_row"]
exp_avg_res_col = state["exp_avg_res_col"]
if layer_type == "1x1_conv" or layer_type == "linear":
# Handle dimensions
if len(grad_shape) == 4: # 1x1 conv
res_reshaped = res.squeeze(-1).squeeze(-1) # Remove last two dimensions
else:
res_reshaped = res
# Update residual statistics
exp_avg_res_row.mul_(group["betas"][2]).add_(
res_reshaped.mean(dim=1), alpha=1.0 - group["betas"][2]
)
exp_avg_res_col.mul_(group["betas"][2]).add_(
res_reshaped.mean(dim=0), alpha=1.0 - group["betas"][2]
)
# Approximate calculation
res_approx = self._approx_sq_grad(exp_avg_res_row, exp_avg_res_col)
if layer_type == "1x1_conv":
                            # Need to reshape back to 4D
res_approx = res_approx.view(grad_shape[0], grad_shape[1], 1, 1)
update = res_approx.mul_(exp_avg)
else:
update = exp_avg.clone()
if group["weight_decay"] != 0:
p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"])
update.mul_(group["lr"])
p.data.add_(-update)
return loss
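# Minimal usage sketch for CAMEWrapper (hyperparameters are illustrative):
#   optim = CAMEWrapper(model.parameters(), lr=2e-4, weight_decay=0.0)
#   loss.backward()
#   optim.step()
#   optim.zero_grad()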
@OPTIMIZERS.register_module()
class CAME8BitWrapper(torch.optim.Optimizer):
"""Implements 8bit-CAME algorithm.
Args:
params (iterable): parameters to optimize or dicts defining parameter groups
        lr (float): external learning rate; must be provided and positive
eps (tuple[float, float]): regularization constants for square gradient
and instability respectively (default: (1e-30, 1e-16))
clip_threshold (float): threshold of root-mean-square of
final gradient update (default: 1.0)
        betas (tuple[float, float, float]): coefficients used for computing running averages of
            update, square gradient and instability (default: (0.9, 0.999, 0.9999))
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        block_size (int): quantization block size; larger blocks are more memory
            efficient but may reduce accuracy (default: 2048)
        min_8bit_size (int): minimum parameter count for 8bit quantization; only
            layers with more parameters than this are quantized (default: 16384)
Note:
        1. 8bit quantization is only applied to large Linear layers and 1x1 Conv layers
        2. All statistics (exp_avg_sq_row, etc.) are kept in 32bit to ensure stability
        3. A simple per-block min-max quantization strategy is used
"""
def __init__(
self,
params,
lr=None,
eps=(1e-30, 1e-16),
clip_threshold=1.0,
betas=(0.9, 0.999, 0.9999),
weight_decay=0.0,
block_size=2048,
min_8bit_size=16384,
):
        assert lr is not None and lr > 0.0
assert all([0.0 <= beta <= 1.0 for beta in betas])
logger = get_root_logger()
logger.info(f"Initializing CAME8bit with block_size={block_size}, min_8bit_size={min_8bit_size}")
defaults = dict(
lr=lr,
eps=eps,
clip_threshold=clip_threshold,
betas=betas,
weight_decay=weight_decay,
block_size=block_size,
min_8bit_size=min_8bit_size,
)
super().__init__(params, defaults)
def print_layer_info(self, param_shape, use_8bit):
"""Print layer information, including parameter size and whether 8bit quantization is used
Args:
param_shape (tuple): parameter shape
use_8bit (bool): whether 8bit quantization is used
"""
size = np.prod(param_shape)
layer_type = "unknown"
if len(param_shape) == 1:
layer_type = "1D Layer"
elif len(param_shape) == 2:
layer_type = "Linear"
elif len(param_shape) == 4:
if param_shape[2] == 1 and param_shape[3] == 1:
layer_type = "1x1 Conv"
else:
layer_type = "Conv"
status = "8bit" if use_8bit else "32bit"
print(f"{layer_type} layer with shape {param_shape}: {size:,} params -> using {status}")
def _should_use_8bit(self, param_shape):
"""Determine if a parameter should be quantized to 8bit
Rules:
1. linear layers: parameter size > min_8bit_size
2. 1x1 conv layers: parameter size > min_8bit_size
3. other layers: use 32bit
"""
if len(param_shape) == 2: # linear layer
return param_shape[0] * param_shape[1] > self.defaults["min_8bit_size"]
elif len(param_shape) == 4 and param_shape[2] == 1 and param_shape[3] == 1:
return param_shape[0] * param_shape[1] > self.defaults["min_8bit_size"]
return False # other layers are not quantized
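    # Example (shapes are illustrative): a (1024, 1024) Linear weight has 1,048,576
    # params > min_8bit_size=16384, so it is quantized; a 1-D LayerNorm weight never is.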
def _quantize_state(self, state_tensor, block_size=2048):
"""Quantize a state tensor to 8bit
Args:
state_tensor: tensor to be quantized
block_size: quantization block size
Returns:
list of quantized data blocks, each block contains:
- data: uint8 data
- scale: quantization scale
- min: minimum value
"""
if state_tensor.numel() <= 1:
return state_tensor
quantized_chunks = []
for chunk in state_tensor.split(block_size):
# Calculate quantization parameters
chunk_min = chunk.min()
chunk_max = chunk.max()
            scale = (chunk_max - chunk_min) / 255
            # Avoid division by zero for constant blocks (e.g. the all-zero initial state)
            scale = torch.clamp(scale, min=1e-12)
            # Quantize to 0-255 range
            quantized_chunk = ((chunk - chunk_min) / scale).round().byte()
quantized_chunks.append({"data": quantized_chunk, "scale": scale, "min": chunk_min})
return quantized_chunks
def _dequantize_state(self, quantized_chunks):
"""Dequantize 8bit quantized data to 32bit float
Args:
quantized_chunks: list of quantized data blocks
Returns:
dequantized 32bit float tensor
"""
if not isinstance(quantized_chunks, list):
return quantized_chunks
chunks = []
for chunk_dict in quantized_chunks:
# Dequantize: value = data * scale + min
chunk = chunk_dict["data"].float() * chunk_dict["scale"] + chunk_dict["min"]
chunks.append(chunk)
return torch.cat(chunks)
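    # Round-trip sketch: for a tensor t with more than one element,
    # _dequantize_state(self._quantize_state(t)) recovers t to within half a quantization
    # step per block, i.e. (chunk.max() - chunk.min()) / 510 for each chunk of t.split(block_size).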
def _dequantize_state_first_step(self, quantized_chunks):
"""Efficient dequantization for the first step"""
if not isinstance(quantized_chunks, list):
return quantized_chunks
# 1. Dequantize all chunks to CPU
dequantized_chunks = []
for chunk_dict in quantized_chunks:
chunk = chunk_dict["data"].float() * chunk_dict["scale"] + chunk_dict["min"]
dequantized_chunks.append(chunk)
del chunk_dict["data"]
torch.cuda.empty_cache()
# 2. Concatenate all chunks
result = torch.cat(dequantized_chunks)
del dequantized_chunks
torch.cuda.empty_cache()
return result
def _get_options(self, param_shape):
if len(param_shape) == 4:
if param_shape[2] == 1 and param_shape[3] == 1:
return True, "1x1_conv"
else:
return False, "conv"
elif len(param_shape) == 2:
return True, "linear"
return False, "other"
def _rms(self, tensor):
return tensor.norm(2) / (tensor.numel() ** 0.5)
def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):
r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
return torch.mul(r_factor, c_factor)
def step(self, closure=None):
"""Perform a single optimization step
Main steps:
1. Determine if 8bit quantization is needed
2. Update first and second moment estimates
3. Compute update step
4. Apply confidence-guided strategy
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad.data
if grad.dtype in {torch.float16, torch.bfloat16}:
grad = grad.float()
if grad.is_sparse:
raise RuntimeError("CAME8bit does not support sparse gradients.")
state = self.state[p]
grad_shape = grad.shape
factored, layer_type = self._get_options(grad_shape)
# Determine if 8bit quantization is used
use_8bit = self._should_use_8bit(grad_shape)
# State Initialization
if len(state) == 0:
self.print_layer_info(grad_shape, use_8bit)
state["step"] = 0
# Only use 8bit quantization for large matrices
if use_8bit:
state["exp_avg"] = self._quantize_state(torch.zeros_like(grad), group["block_size"])
else:
state["exp_avg"] = torch.zeros_like(grad)
if factored:
if layer_type == "1x1_conv" or layer_type == "linear":
# Keep row and column statistics in 32bit
state["exp_avg_sq_row"] = torch.zeros(grad_shape[0]).type_as(grad)
state["exp_avg_sq_col"] = torch.zeros(grad_shape[1]).type_as(grad)
state["exp_avg_res_row"] = torch.zeros(grad_shape[0]).type_as(grad)
state["exp_avg_res_col"] = torch.zeros(grad_shape[1]).type_as(grad)
else:
if use_8bit:
state["exp_avg_sq"] = self._quantize_state(torch.zeros_like(grad), group["block_size"])
else:
state["exp_avg_sq"] = torch.zeros_like(grad)
else:
if use_8bit:
state["exp_avg_sq"] = self._quantize_state(torch.zeros_like(grad), group["block_size"])
else:
state["exp_avg_sq"] = torch.zeros_like(grad)
state["RMS"] = 0
state["step"] += 1
state["RMS"] = self._rms(p.data)
exp_avg = self._dequantize_state(state["exp_avg"]) if use_8bit else state["exp_avg"]
update = (grad**2) + group["eps"][0]
if factored:
exp_avg_sq_row = state["exp_avg_sq_row"] # 32bit
exp_avg_sq_col = state["exp_avg_sq_col"] # 32bit
if layer_type == "1x1_conv" or layer_type == "linear":
if len(grad_shape) == 4:
update_reshaped = update.squeeze(-1).squeeze(-1)
else:
update_reshaped = update
# Update row and column statistics
exp_avg_sq_row.mul_(group["betas"][1]).add_(
update_reshaped.mean(dim=1), alpha=1.0 - group["betas"][1]
)
exp_avg_sq_col.mul_(group["betas"][1]).add_(
update_reshaped.mean(dim=0), alpha=1.0 - group["betas"][1]
)
update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
if layer_type == "1x1_conv":
update = update.view(grad_shape[0], grad_shape[1], 1, 1)
update.mul_(grad)
else:
exp_avg_sq = self._dequantize_state(state["exp_avg_sq"]) if use_8bit else state["exp_avg_sq"]
exp_avg_sq.mul_(group["betas"][1]).add_(update, alpha=1.0 - group["betas"][1])
if use_8bit:
state["exp_avg_sq"] = self._quantize_state(exp_avg_sq, group["block_size"])
else:
state["exp_avg_sq"] = exp_avg_sq
update = exp_avg_sq.rsqrt().mul_(grad)
# Gradient clipping
update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
# Update first moment
exp_avg.mul_(group["betas"][0]).add_(update, alpha=1 - group["betas"][0])
# Re-quantize (if needed)
if use_8bit:
state["exp_avg"] = self._quantize_state(exp_avg, group["block_size"])
else:
state["exp_avg"] = exp_avg
# Confidence-guided strategy
res = (update - exp_avg) ** 2 + group["eps"][1]
if factored:
exp_avg_res_row = state["exp_avg_res_row"] # 32bit
exp_avg_res_col = state["exp_avg_res_col"] # 32bit
if layer_type == "1x1_conv" or layer_type == "linear":
if len(grad_shape) == 4:
res_reshaped = res.squeeze(-1).squeeze(-1)
else:
res_reshaped = res
# Update residual statistics
exp_avg_res_row.mul_(group["betas"][2]).add_(
res_reshaped.mean(dim=1), alpha=1.0 - group["betas"][2]
)
exp_avg_res_col.mul_(group["betas"][2]).add_(
res_reshaped.mean(dim=0), alpha=1.0 - group["betas"][2]
)
res_approx = self._approx_sq_grad(exp_avg_res_row, exp_avg_res_col)
if layer_type == "1x1_conv":
res_approx = res_approx.view(grad_shape[0], grad_shape[1], 1, 1)
update = res_approx.mul_(exp_avg)
else:
update = exp_avg.clone()
# Weight decay
if group["weight_decay"] != 0:
p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"])
# Apply update
update.mul_(group["lr"])
p.data.add_(-update)
return loss
def load_state_dict(self, state_dict):
"""Load state dict and convert relevant states to 8bit"""
super().load_state_dict(state_dict)
for state in self.state.values():
for key in [
"exp_avg",
"exp_avg_sq",
"exp_avg_sq_row",
"exp_avg_sq_col",
"exp_avg_res_row",
"exp_avg_res_col",
]:
if key in state:
if isinstance(state[key], list):
state[key] = [
{
"data": exp["data"].byte(), # Convert data to 8bit directly
"scale": exp["scale"], # Keep scale unchanged
"min": exp["min"], # Keep min unchanged
}
for exp in state[key]
]
elif isinstance(state[key], torch.Tensor):
# If tensor, keep as 32bit
state[key] = state[key].float() # Ensure 32bit
del state_dict
torch.cuda.empty_cache()