Spaces:
Sleeping
Sleeping
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import abc as container_abcs | |
from collections import defaultdict | |
from copy import deepcopy | |
from itertools import chain | |
import torch | |
import bitsandbytes.functional as F | |
class MockArgs:
    """Attribute-style view over a mapping of option names to values."""

    def __init__(self, initial_data):
        """Expose every entry of ``initial_data`` as an instance attribute.

        Parameters
        ----------
        initial_data : dict
            Mapping from attribute name to the value to store under it.
        """
        for name, value in initial_data.items():
            setattr(self, name, value)
class GlobalOptimManager:
    """Singleton registry of per-parameter optimizer config overrides.

    Obtain the shared instance with :meth:`get_instance`; calling the
    constructor directly raises ``RuntimeError``.
    """

    _instance = None

    def __init__(self):
        raise RuntimeError("Call get_instance() instead")

    def initialize(self):
        """Reset every registry held by the manager to its empty state."""
        # id(parameter) -> dict of overridden config entries
        self.pid2config = {}
        # (group_index, param_index) -> the same dict object as in pid2config
        self.index2config = {}
        self.optimizer = None
        self.uses_config_override = False
        # (module, attribute name, config) triples, see register_module_override
        self.module_weight_config_triple = []

    @classmethod
    def get_instance(cls):
        """Return the shared manager, creating and initializing it on first use."""
        if cls._instance is None:
            # __new__ bypasses __init__, which deliberately raises.
            cls._instance = cls.__new__(cls)
            cls._instance.initialize()
        return cls._instance

    def register_parameters(self, params):
        """Index existing overrides by (group index, parameter index).

        Parameters
        ----------
        params : iterable of torch.Tensor or iterable of dict
            Either a flat iterable of parameters or an iterable of
            parameter-group dicts with a ``"params"`` entry. An empty
            iterable is accepted and registers nothing.
        """
        param_groups = list(params)
        if not param_groups:
            return
        if not isinstance(param_groups[0], dict):
            param_groups = [{"params": param_groups}]

        for group_index, group in enumerate(param_groups):
            for p_index, p in enumerate(group["params"]):
                if id(p) in self.pid2config:
                    # Share the dict so later override_config calls are
                    # visible through both lookups.
                    self.index2config[(group_index, p_index)] = (
                        self.pid2config[id(p)]
                    )

    def override_config(
        self, parameters, key=None, value=None, key_value_dict=None
    ):
        """
        Overrides initial optimizer config for specific parameters.

        The key-values of the optimizer config for the input parameters are
        overridden. This can be both optimizer parameters like "betas" or
        "lr", or 8-bit specific parameters like "optim_bits" or
        "percentile_clipping".

        Parameters
        ----------
        parameters : torch.Tensor or list(torch.Tensors)
            The input parameters.
        key : str
            The hyperparameter to override.
        value : object
            The value for the hyperparameter.
        key_value_dict : dict
            A dictionary with multiple key-values to override. Must be
            None when ``key`` and ``value`` are both given.
        """
        self.uses_config_override = True
        # torch.nn.Parameter is a Tensor subclass, so one check covers both.
        if isinstance(parameters, torch.Tensor):
            parameters = [parameters]
        if key is not None and value is not None:
            assert key_value_dict is None
            key_value_dict = {key: value}

        if key_value_dict is not None:
            for p in parameters:
                if id(p) in self.pid2config:
                    self.pid2config[id(p)].update(key_value_dict)
                else:
                    # Store a per-parameter copy: sharing one dict between
                    # parameters would let a later override of one of them
                    # silently change all the others.
                    self.pid2config[id(p)] = dict(key_value_dict)

    def register_module_override(self, module, param_name, config):
        """Record that attribute ``param_name`` of ``module`` should use ``config``."""
        self.module_weight_config_triple.append((module, param_name, config))
class Optimizer8bit(torch.optim.Optimizer):
    """Base class for optimizers whose state may be kept in 8 bits.

    Subclasses must override :meth:`init_state` and :meth:`update_step`,
    and must set ``self.args`` (read by :meth:`get_config`).
    """

    def __init__(self, params, defaults, optim_bits=32, is_paged=False):
        """Set up bookkeeping shared by all 8-bit capable optimizers.

        Arguments:
            params: parameters or parameter groups, forwarded to
                ``torch.optim.Optimizer``.
            defaults (dict): default hyperparameters, forwarded likewise.
            optim_bits (int): when 8, the quantization maps are created
                eagerly.
            is_paged (bool): whether state buffers may be paged tensors.
        """
        super().__init__(params, defaults)
        self.initialized = False
        self.name2qmap = {}
        self.is_paged = is_paged
        self.page_mng = F.GlobalPageManager.get_instance()
        self.mng = GlobalOptimManager.get_instance()
        # State entries that load_state_dict only moves to the parameter's
        # device; their dtype is never changed.
        self.non_castable_tensor_keys = {
            "qmap1",
            "qmap2",
            "max1",
            "max2",
            "new_max1",
            "new_max2",
            "state1",
            "state2",
            "gnorm_vec",
            "absmax1",
            "absmax2",
            "unorm_vec",
        }

        if optim_bits == 8:
            self.fill_qmap()

    def fill_qmap(self):
        """Create the signed ("dynamic") and unsigned ("udynamic") maps."""
        self.name2qmap["dynamic"] = F.create_dynamic_map(signed=True)
        self.name2qmap["udynamic"] = F.create_dynamic_map(signed=False)

    def __setstate__(self, state):
        super().__setstate__(state)

    def load_state_dict(self, state_dict):
        r"""Loads the optimizer state.

        Args:
            state_dict (dict): optimizer state. Should be an object returned
                from a call to :meth:`state_dict`.

        Raises:
            ValueError: if the number of parameter groups, or the size of
                any group, differs from this optimizer's.
        """
        # deepcopy, to be consistent with module API
        state_dict = deepcopy(state_dict)

        # Validate the state_dict
        groups = self.param_groups
        saved_groups = state_dict["param_groups"]

        if len(groups) != len(saved_groups):
            raise ValueError(
                "loaded state dict has a different number of "
                "parameter groups"
            )
        param_lens = (len(g["params"]) for g in groups)
        saved_lens = (len(g["params"]) for g in saved_groups)
        if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
            raise ValueError(
                "loaded state dict contains a parameter group "
                "that doesn't match the size of optimizer's group"
            )

        # Map each saved parameter id to the live parameter at the same
        # position.
        id_map = dict(
            zip(
                chain.from_iterable(g["params"] for g in saved_groups),
                chain.from_iterable(g["params"] for g in groups),
            )
        )

        def cast(param, value):
            r"""Convert ``value`` so that it matches ``param``.

            Floating-point tensors other than uint8 are cast to the dtype
            of ``param``; dict entries listed in
            ``non_castable_tensor_keys`` are only moved to its device.
            Containers are processed recursively.
            """
            if isinstance(value, torch.Tensor):
                # Floating-point types are a bit special here. They are the
                # only ones that are assumed to always match the type of
                # params.
                if param.is_floating_point() and value.dtype != torch.uint8:
                    value = value.to(param.dtype)
                return value
            if isinstance(value, dict):
                for k, v in value.items():
                    if k in self.non_castable_tensor_keys:
                        value[k] = v.to(param.device)
                    else:
                        value[k] = cast(param, v)
                return value
            if isinstance(value, (str, bytes)):
                # Strings are iterables of strings; recursing into them
                # would never terminate.
                return value
            if isinstance(value, container_abcs.Iterable):
                return type(value)(cast(param, v) for v in value)
            return value

        # Copy state assigned to params (and cast tensors to appropriate
        # types). State that is not assigned to params is copied as is
        # (needed for backward compatibility).
        state = defaultdict(dict)
        for k, v in state_dict["state"].items():
            if k in id_map:
                param = id_map[k]
                state[param] = cast(param, v)
            else:
                state[k] = v

        # Update parameter groups, setting their 'params' value
        def update_group(group, new_group):
            new_group["params"] = group["params"]
            return new_group

        param_groups = [
            update_group(g, ng) for g, ng in zip(groups, saved_groups)
        ]
        self.__setstate__({"state": state, "param_groups": param_groups})

    def to_gpu(self):
        """Move every non-paged state tensor to the device of its parameter."""
        for group in self.param_groups:
            for p in group["params"]:
                if p in self.state:
                    values = self.state[p]
                    for k, v in values.items():
                        if isinstance(v, torch.Tensor):
                            is_paged = getattr(v, "is_paged", False)
                            if not is_paged:
                                self.state[p][k] = v.to(p.device)

    def check_overrides(self):
        """Apply overrides registered through ``register_module_override``.

        For each registered (module, attribute, config) triple, the
        attribute is looked up among this optimizer's parameters; the first
        match gets ``config`` stored in the manager under both its id and
        its (group index, parameter index).
        """
        for module, attr, config in self.mng.module_weight_config_triple:
            pmodule = getattr(module, attr)
            assert pmodule is not None
            # torch.nn.Parameter is a Tensor subclass, so this covers both.
            assert isinstance(pmodule, torch.Tensor)
            for gindex, group in enumerate(self.param_groups):
                pindex = next(
                    (
                        i
                        for i, p in enumerate(group["params"])
                        if p is pmodule
                    ),
                    None,
                )
                if pindex is not None:
                    # found the matching parameter: init override
                    self.mng.pid2config[id(pmodule)] = config
                    self.mng.index2config[(gindex, pindex)] = config
                    break

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or None when no closure is
            given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        if not self.initialized:
            self.check_overrides()
            self.to_gpu()  # needed for fairseq pure fp16 training
            self.initialized = True

        for gindex, group in enumerate(self.param_groups):
            for pindex, p in enumerate(group["params"]):
                if p.grad is None:
                    continue
                state = self.state[p]
                if len(state) == 0:
                    self.init_state(group, p, gindex, pindex)

                self.prefetch_state(p)
                self.update_step(group, p, gindex, pindex)
                # NOTE(review): synchronizes after every parameter update;
                # kept as in the original, confirm whether it is required.
                torch.cuda.synchronize()
        if self.is_paged:
            # all paged operation are asynchronous, we need
            # to sync to make sure all tensors are in the right state
            torch.cuda.synchronize()

        return loss

    def get_config(self, gindex, pindex, group):
        """Build the effective config for one parameter.

        Values come from ``group`` and ``self.args``; entries registered in
        the manager for ``(gindex, pindex)`` take precedence.

        Returns:
            dict: the merged configuration.
        """
        config = {
            "betas": group["betas"],
            "eps": group["eps"],
            "weight_decay": group["weight_decay"],
            "lr": group["lr"],
            "optim_bits": self.args.optim_bits,
            "min_8bit_size": self.args.min_8bit_size,
            "percentile_clipping": self.args.percentile_clipping,
            "block_wise": self.args.block_wise,
            "max_unorm": self.args.max_unorm,
            "skip_zeros": self.args.skip_zeros,
        }

        if (gindex, pindex) in self.mng.index2config:
            config.update(self.mng.index2config[(gindex, pindex)])
        return config

    def init_state(self, group, p, gindex, pindex):
        raise NotImplementedError("init_state method needs to be overridden")

    def update_step(self, group, p, gindex, pindex):
        raise NotImplementedError(
            "The update_step method needs to be overridden"
        )

    def get_state_buffer(self, p, dtype=torch.float32):
        """Return a zero-filled buffer with the shape and device of ``p``.

        A regular tensor is returned unless the optimizer is paged and
        ``p`` has at least 1e5 elements; in that case a paged buffer is
        allocated and registered with the page manager.
        """
        if not self.is_paged or p.numel() < 1e5:
            return torch.zeros_like(p, dtype=dtype, device=p.device)
        else:
            # > 1 MB
            buff = F.get_paged(*p.shape, dtype=dtype, device=p.device)
            F.fill(buff, 0)
            self.page_mng.paged_tensors.append(buff)
            return buff

    def prefetch_state(self, p):
        """Prefetch the paged state tensors of ``p`` (paged optimizers only)."""
        if self.is_paged:
            state = self.state[p]
            s1 = state["state1"]
            is_paged = getattr(s1, "is_paged", False)
            if is_paged:
                F.prefetch_tensor(state["state1"])
                if "state2" in state:
                    F.prefetch_tensor(state["state2"])
class Optimizer2State(Optimizer8bit):
    """Optimizer keeping two state tensors (``state1``, ``state2``) per parameter."""

    def __init__(
        self,
        optimizer_name,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0.0,
        optim_bits=32,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
        max_unorm=0.0,
        skip_zeros=False,
        is_paged=False
    ):
        """Validate the hyperparameters and store the optimizer options.

        ``betas`` may also be given as a string of the form
        ``'(beta1, beta2)'``. When ``args`` is provided it is used as-is
        and the individual 8-bit keyword options are ignored.

        Raises:
            ValueError: if ``lr``, ``eps`` or ``weight_decay`` is negative,
                or a beta lies outside ``[0, 1)``.
        """
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if isinstance(betas, str):
            # format: '(beta1, beta2)'
            betas = betas.replace("(", "").replace(")", "").strip().split(",")
            betas = [float(b) for b in betas]
        for i, beta in enumerate(betas):
            if not 0.0 <= beta < 1.0:
                raise ValueError(
                    f"Invalid beta parameter at index {i}: {beta}"
                )
        if not 0.0 <= weight_decay:
            raise ValueError(
                f"Invalid weight_decay value: {weight_decay}"
            )
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults, optim_bits, is_paged)

        if args is None:
            self.args = MockArgs(
                {
                    "optim_bits": optim_bits,
                    "percentile_clipping": percentile_clipping,
                    "min_8bit_size": min_8bit_size,
                    "block_wise": block_wise,
                    "max_unorm": max_unorm,
                    "skip_zeros": skip_zeros,
                }
            )
        else:
            self.args = args

        self.optimizer_name = optimizer_name

    def init_state(self, group, p, gindex, pindex):
        """Allocate the optimizer state for parameter ``p``.

        Raises:
            NotImplementedError: if the configured ``optim_bits`` is
                neither 32 nor 8.
        """
        config = self.get_config(gindex, pindex, group)

        if config["optim_bits"] == 32:
            dtype = torch.float32
        elif config["optim_bits"] == 8:
            dtype = torch.uint8
        else:
            raise NotImplementedError(
                f'Amount of optimizer bits not supported: {config["optim_bits"]}'
            )

        if p.numel() < config["min_8bit_size"]:
            dtype = torch.float32

        state = self.state[p]
        state["step"] = 0

        # dtype is float32 or uint8 here; parameters with fewer than 4096
        # elements always get 32-bit state, whatever min_8bit_size says.
        if dtype == torch.float32 or p.numel() < 4096:
            state["state1"] = self.get_state_buffer(p, dtype=torch.float32)
            state["state2"] = self.get_state_buffer(p, dtype=torch.float32)
        else:
            if "dynamic" not in self.name2qmap:
                self.fill_qmap()
            self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(
                p.device
            )
            self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(
                p.device
            )

            state["state1"] = self.get_state_buffer(p, dtype=torch.uint8)
            state["qmap1"] = self.name2qmap["dynamic"]

            state["state2"] = self.get_state_buffer(p, dtype=torch.uint8)
            state["qmap2"] = self.name2qmap["udynamic"]

            if config["block_wise"]:
                # one absmax entry per started block of 2048 elements
                blocks = (p.numel() + 2047) // 2048
                for key in ("absmax1", "absmax2"):
                    state[key] = torch.zeros(
                        (blocks,), dtype=torch.float32, device=p.device
                    )
            else:
                for key in ("max1", "new_max1", "max2", "new_max2"):
                    state[key] = torch.zeros(
                        (1,), dtype=torch.float32, device=p.device
                    )

        if config["percentile_clipping"] < 100:
            state["gnorm_vec"] = torch.zeros((100,), device=p.device)

        if config["max_unorm"] > 0.0:
            state["unorm_vec"] = torch.zeros((1,), device=p.device)

    def update_step(self, group, p, gindex, pindex):
        """Advance the step counter of ``p`` and run the matching update.

        The update function is selected from the dtype of ``state1`` and
        the ``block_wise`` config entry.
        """
        state = self.state[p]
        grad = p.grad

        config = self.get_config(gindex, pindex, group)

        state["step"] += 1
        step = state["step"]

        if config["percentile_clipping"] < 100:
            # only the gradient-norm scale is used by the updates below
            _, _, gnorm_scale = F.percentile_clipping(
                grad, state["gnorm_vec"], step, config["percentile_clipping"]
            )
        else:
            gnorm_scale = 1.0

        if state["state1"].dtype == torch.float:
            F.optimizer_update_32bit(
                self.optimizer_name,
                grad,
                p,
                state["state1"],
                config["betas"][0],
                config["eps"],
                step,
                config["lr"],
                state["state2"],
                config["betas"][1],
                config["weight_decay"],
                gnorm_scale,
                state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
                max_unorm=config["max_unorm"],
                skip_zeros=config["skip_zeros"],
            )

        elif state["state1"].dtype == torch.uint8 and not config["block_wise"]:
            F.optimizer_update_8bit(
                self.optimizer_name,
                grad,
                p,
                state["state1"],
                state["state2"],
                config["betas"][0],
                config["betas"][1],
                config["eps"],
                step,
                config["lr"],
                state["qmap1"],
                state["qmap2"],
                state["max1"],
                state["max2"],
                state["new_max1"],
                state["new_max2"],
                config["weight_decay"],
                gnorm_scale=gnorm_scale,
                unorm_vec=state["unorm_vec"]
                if config["max_unorm"] > 0.0
                else None,
                max_unorm=config["max_unorm"],
            )

            # swap maxes
            state["max1"], state["new_max1"] = state["new_max1"], state["max1"]
            state["max2"], state["new_max2"] = state["new_max2"], state["max2"]

        elif state["state1"].dtype == torch.uint8 and config["block_wise"]:
            F.optimizer_update_8bit_blockwise(
                self.optimizer_name,
                grad,
                p,
                state["state1"],
                state["state2"],
                config["betas"][0],
                config["betas"][1],
                config["eps"],
                step,
                config["lr"],
                state["qmap1"],
                state["qmap2"],
                state["absmax1"],
                state["absmax2"],
                config["weight_decay"],
                gnorm_scale=gnorm_scale,
                skip_zeros=config["skip_zeros"],
            )
class Optimizer1State(Optimizer8bit):
    """Optimizer keeping a single state tensor (``state1``) per parameter."""

    def __init__(
        self,
        optimizer_name,
        params,
        lr=1e-3,
        betas=(0.9, 0.0),
        eps=1e-8,
        weight_decay=0.0,
        optim_bits=32,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
        max_unorm=0.0,
        skip_zeros=False,
        is_paged=False
    ):
        """Validate the hyperparameters and store the optimizer options.

        ``betas`` may also be given as a string of the form
        ``'(beta1, beta2)'``, as accepted by ``Optimizer2State``. When
        ``args`` is provided it is used as-is and the individual 8-bit
        keyword options are ignored.

        Raises:
            ValueError: if ``lr``, ``eps`` or ``weight_decay`` is negative,
                or a beta lies outside ``[0, 1)``.
        """
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if isinstance(betas, str):
            # format: '(beta1, beta2)'
            betas = betas.replace("(", "").replace(")", "").strip().split(",")
            betas = [float(b) for b in betas]
        for i, beta in enumerate(betas):
            if not 0.0 <= beta < 1.0:
                raise ValueError(
                    f"Invalid beta parameter at index {i}: {beta}"
                )
        if not 0.0 <= weight_decay:
            raise ValueError(
                f"Invalid weight_decay value: {weight_decay}"
            )
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults, optim_bits, is_paged)

        if args is None:
            self.args = MockArgs(
                {
                    "optim_bits": optim_bits,
                    "percentile_clipping": percentile_clipping,
                    "min_8bit_size": min_8bit_size,
                    "block_wise": block_wise,
                    "max_unorm": max_unorm,
                    "skip_zeros": skip_zeros,
                }
            )
        else:
            self.args = args

        self.optimizer_name = optimizer_name

    def init_state(self, group, p, gindex, pindex):
        """Allocate the optimizer state for parameter ``p``.

        Raises:
            NotImplementedError: if the configured ``optim_bits`` is
                neither 32 nor 8.
        """
        config = self.get_config(gindex, pindex, group)

        if config["optim_bits"] == 32:
            dtype = torch.float32
        elif config["optim_bits"] == 8:
            dtype = torch.uint8
        else:
            raise NotImplementedError(
                f'Amount of optimizer bits not supported: {config["optim_bits"]}'
            )

        if p.numel() < config["min_8bit_size"]:
            dtype = torch.float32

        state = self.state[p]
        state["step"] = 0

        # dtype is float32 or uint8 here; parameters with fewer than 4096
        # elements always get 32-bit state, whatever min_8bit_size says.
        if dtype == torch.float32 or p.numel() < 4096:
            state["state1"] = self.get_state_buffer(p, dtype=torch.float32)
        else:
            if "dynamic" not in self.name2qmap:
                self.fill_qmap()
            self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(
                p.device
            )

            state["state1"] = self.get_state_buffer(p, dtype=torch.uint8)
            state["qmap1"] = self.name2qmap["dynamic"]

            if config["block_wise"]:
                # one absmax entry per started block of 2048 elements
                blocks = (p.numel() + 2047) // 2048
                state["absmax1"] = torch.zeros(
                    (blocks,), dtype=torch.float32, device=p.device
                )
            else:
                for key in ("max1", "new_max1"):
                    state[key] = torch.zeros(
                        (1,), dtype=torch.float32, device=p.device
                    )

        if config["percentile_clipping"] < 100:
            state["gnorm_vec"] = torch.zeros((100,), device=p.device)

        if config["max_unorm"] > 0.0:
            state["unorm_vec"] = torch.zeros((1,), device=p.device)

    def update_step(self, group, p, gindex, pindex):
        """Advance the step counter of ``p`` and run the matching update.

        The update function is selected from the dtype of ``state1`` and
        the ``block_wise`` config entry; every second-state argument is
        passed as None.
        """
        state = self.state[p]
        grad = p.grad

        config = self.get_config(gindex, pindex, group)

        state["step"] += 1
        step = state["step"]

        if config["percentile_clipping"] < 100:
            # only the gradient-norm scale is used by the updates below
            _, _, gnorm_scale = F.percentile_clipping(
                grad, state["gnorm_vec"], step, config["percentile_clipping"]
            )
        else:
            gnorm_scale = 1.0

        if state["state1"].dtype == torch.float:
            F.optimizer_update_32bit(
                self.optimizer_name,
                grad,
                p,
                state["state1"],
                config["betas"][0],
                config["eps"],
                step,
                config["lr"],
                None,
                config["betas"][1],
                config["weight_decay"],
                gnorm_scale,
                state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
                max_unorm=config["max_unorm"],
                skip_zeros=config["skip_zeros"],
            )

        elif state["state1"].dtype == torch.uint8 and not config["block_wise"]:
            F.optimizer_update_8bit(
                self.optimizer_name,
                grad,
                p,
                state["state1"],
                None,
                config["betas"][0],
                config["betas"][1],
                config["eps"],
                step,
                config["lr"],
                state["qmap1"],
                None,
                state["max1"],
                None,
                state["new_max1"],
                None,
                config["weight_decay"],
                gnorm_scale,
                state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
                max_unorm=config["max_unorm"],
            )

            # swap maxes
            state["max1"], state["new_max1"] = state["new_max1"], state["max1"]

        elif state["state1"].dtype == torch.uint8 and config["block_wise"]:
            F.optimizer_update_8bit_blockwise(
                self.optimizer_name,
                grad,
                p,
                state["state1"],
                None,
                config["betas"][0],
                config["betas"][1],
                config["eps"],
                step,
                config["lr"],
                state["qmap1"],
                None,
                state["absmax1"],
                None,
                config["weight_decay"],
                gnorm_scale=gnorm_scale,
                skip_zeros=config["skip_zeros"],
            )