# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, TypeVar, Union, overload

import torch
import torch.nn.functional as F
from torch import Tensor, device, dtype, nn

import bitsandbytes as bnb
import bitsandbytes.functional
from bitsandbytes.autograd._functions import undo_layout, get_tile_inds
from bitsandbytes.optim import GlobalOptimManager
from bitsandbytes.utils import OutlierTracer, find_outlier_dims

T = TypeVar("T", bound="torch.nn.Module")

class StableEmbedding(torch.nn.Embedding):
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__(
            num_embeddings,
            embedding_dim,
            padding_idx,
            max_norm,
            norm_type,
            scale_grad_by_freq,
            sparse,
            _weight,
            device,
            dtype,
        )
        self.norm = torch.nn.LayerNorm(embedding_dim, device=device)
        GlobalOptimManager.get_instance().register_module_override(
            self, "weight", {"optim_bits": 32}
        )
    def reset_parameters(self) -> None:
        torch.nn.init.xavier_uniform_(self.weight)
        self._fill_padding_idx_with_zero()

    """ !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
        to make the layer compatible with PyTorch < 1.9.
        This means that if this changes in future PyTorch releases, this needs to change too,
        which is cumbersome. However, with this we can ensure compatibility with previous
        PyTorch releases.
    """

    def _fill_padding_idx_with_zero(self) -> None:
        if self.padding_idx is not None:
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)
    def forward(self, input: Tensor) -> Tensor:
        emb = F.embedding(
            input,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )

        # always apply layer norm in full precision
        emb = emb.to(torch.get_default_dtype())

        return self.norm(emb).to(self.weight.dtype)
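
# Illustrative usage sketch (not part of the library API): StableEmbedding is a drop-in
# replacement for torch.nn.Embedding that layer-normalizes its output in full precision
# and forces 32-bit optimizer state for the embedding weight. Sizes below are arbitrary.
#
#   emb = StableEmbedding(50_000, 768)
#   tokens = torch.randint(0, 50_000, (4, 128))
#   out = emb(tokens)  # shape (4, 128, 768), cast back to emb.weight.dtype
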
class Embedding(torch.nn.Embedding):
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
        device: Optional[device] = None,
    ) -> None:
        super().__init__(
            num_embeddings,
            embedding_dim,
            padding_idx,
            max_norm,
            norm_type,
            scale_grad_by_freq,
            sparse,
            _weight,
            device=device,
        )
        GlobalOptimManager.get_instance().register_module_override(
            self, "weight", {"optim_bits": 32}
        )
    def reset_parameters(self) -> None:
        torch.nn.init.xavier_uniform_(self.weight)
        self._fill_padding_idx_with_zero()

    """ !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
        to make the layer compatible with PyTorch < 1.9.
        This means that if this changes in future PyTorch releases, this needs to change too,
        which is cumbersome. However, with this we can ensure compatibility with previous
        PyTorch releases.
    """

    def _fill_padding_idx_with_zero(self) -> None:
        if self.padding_idx is not None:
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)
    def forward(self, input: Tensor) -> Tensor:
        emb = F.embedding(
            input,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )

        return emb
class Params4bit(torch.nn.Parameter):
    def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64, compress_statistics=True, quant_type='fp4'):
        if data is None:
            data = torch.empty(0)

        self = torch.Tensor._make_subclass(cls, data, requires_grad)
        self.blocksize = blocksize
        self.compress_statistics = compress_statistics
        self.quant_type = quant_type
        self.quant_state = quant_state
        self.data = data
        return self

    def cuda(self, device):
        w = self.data.contiguous().half().cuda(device)
        w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics, quant_type=self.quant_type)
        self.data = w_4bit
        self.quant_state = quant_state

        return self
    @overload
    def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ..., non_blocking: bool = ...,) -> T:
        ...

    @overload
    def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T:
        ...

    @overload
    def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T:
        ...

    def to(self, *args, **kwargs):
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)

        if (device is not None and device.type == "cuda" and self.data.device.type == "cpu"):
            return self.cuda(device)
        else:
            s = self.quant_state
            if s is not None:
                # make sure the quantization state is on the right device
                s[0] = s[0].to(device)
                if self.compress_statistics:
                    # TODO: refactor this. This is a nightmare
                    # for 4-bit:
                    # state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type]
                    # state2 = [absmax, input_shape, A.dtype, blocksize, None, quant_type]
                    #s[-2][0] = s[-2][0].to(device) # offset
                    #s[-2][1][0] = s[-2][1][0].to(device) # nested absmax

                    # for 8-bit
                    s[-3][0] = s[-3][0].to(device) # offset
                    s[-3][1][0] = s[-3][1][0].to(device) # nested quantization state statistics
                    s[-3][1][1] = s[-3][1][1].to(device) # nested quantization codebook
            new_param = Params4bit(super().to(device=device, dtype=dtype, non_blocking=non_blocking),
                                   requires_grad=self.requires_grad, quant_state=self.quant_state,
                                   blocksize=self.blocksize, compress_statistics=self.compress_statistics,
                                   quant_type=self.quant_type)

            return new_param
class Linear4bit(nn.Linear):
    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', device=None):
        super().__init__(input_features, output_features, bias, device)
        self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type)
        self.compute_dtype = compute_dtype

    def forward(self, x: torch.Tensor):
        # weights are cast automatically as Params4bit, but the bias has to be cast manually
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        if getattr(self.weight, 'quant_state', None) is None:
            print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')

        inp_dtype = x.dtype
        if self.compute_dtype is not None:
            x = x.to(self.compute_dtype)

        bias = None if self.bias is None else self.bias.to(self.compute_dtype)
        out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)

        out = out.to(inp_dtype)

        return out
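
# Illustrative usage sketch (assumes a CUDA device is available): the weight stays in full
# precision until the module is moved to the GPU, at which point Params4bit.cuda() quantizes
# it; compute_dtype controls the dtype used for the 4-bit matmul.
#
#   layer = Linear4bit(4096, 4096, bias=False, compute_dtype=torch.float16, quant_type='nf4')
#   layer = layer.cuda()                              # quantizes the weight to 4-bit
#   y = layer(torch.randn(8, 4096, dtype=torch.float16, device='cuda'))
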
class LinearFP4(Linear4bit):
    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None):
        super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', device)
class LinearNF4(Linear4bit):
    ''' Implements the NF4 data type.

        Constructs a quantization data type where each bin has equal area under a standard normal distribution N(0, 1)
        that is normalized into the range [-1, 1].

        For more information read the paper: QLoRA: Efficient Finetuning of Quantized LLMs (https://arxiv.org/abs/2305.14314)

        The implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in
        the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236.
    '''
    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None):
        super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', device)
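
# Illustrative sketch: the NF4 codebook described in the docstring above (equal-probability
# bins of N(0, 1), normalized to [-1, 1]) can be inspected via the create_normal_map function
# referenced there; calling it with defaults is an assumption about its public signature.
#
#   code = bnb.functional.create_normal_map()        # quantization code values, sorted in [-1, 1]
#   layer = LinearNF4(4096, 4096, compute_dtype=torch.bfloat16).cuda()
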
class Int8Params(torch.nn.Parameter):
    def __new__(
        cls,
        data=None,
        requires_grad=True,
        has_fp16_weights=False,
        CB=None,
        SCB=None,
    ):
        cls.has_fp16_weights = has_fp16_weights
        cls.CB = None
        cls.SCB = None
        if data is None:
            data = torch.empty(0)
        return torch.Tensor._make_subclass(cls, data, requires_grad)

    def cuda(self, device):
        if self.has_fp16_weights:
            return super().cuda(device)
        else:
            # we store the 8-bit row-major weight
            # we convert this weight to the turing/ampere weight during the first inference pass
            B = self.data.contiguous().half().cuda(device)
            CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
            del CBt
            del SCBt
            self.data = CB
            setattr(self, "CB", CB)
            setattr(self, "SCB", SCB)

        return self
    @overload
    def to(
        self: T,
        device: Optional[Union[int, device]] = ...,
        dtype: Optional[Union[dtype, str]] = ...,
        non_blocking: bool = ...,
    ) -> T:
        ...

    @overload
    def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T:
        ...

    @overload
    def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T:
        ...

    def to(self, *args, **kwargs):
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
            *args, **kwargs
        )

        if (
            device is not None
            and device.type == "cuda"
            and self.data.device.type == "cpu"
        ):
            return self.cuda(device)
        else:
            new_param = Int8Params(
                super().to(
                    device=device, dtype=dtype, non_blocking=non_blocking
                ),
                requires_grad=self.requires_grad,
                has_fp16_weights=self.has_fp16_weights,
            )
            new_param.CB = self.CB
            new_param.SCB = self.SCB

            return new_param
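
# Illustrative sketch (assumed usage): with has_fp16_weights=False, moving an Int8Params to
# CUDA replaces its data with the row-major int8 matrix CB and stores the row-wise scaling
# constants in SCB, both produced by bnb.functional.double_quant.
#
#   p = Int8Params(torch.randn(4096, 4096), requires_grad=False, has_fp16_weights=False)
#   p = p.cuda(0)
#   p.data.dtype, p.SCB.shape     # torch.int8 weight plus one scaling constant per row
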
def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
    weight = state_dict.get(f"{prefix}weight")
    if weight is None:
        # if the state dict has no weights for this layer (e.g., LoRA finetuning), do nothing
        return

    weight_format = state_dict.pop(f"{prefix}weight_format", "row")

    if weight_format != "row":
        tile_indices = get_tile_inds(weight_format, weight.device)
        state_dict[f"{prefix}weight"] = undo_layout(weight, tile_indices)
class Linear8bitLt(nn.Linear):
    def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
                 memory_efficient_backward=False, threshold=0.0, index=None, device=None):
        super().__init__(input_features, output_features, bias, device)
        assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
        self.state = bnb.MatmulLtState()
        self.index = index

        self.state.threshold = threshold
        self.state.has_fp16_weights = has_fp16_weights
        self.state.memory_efficient_backward = memory_efficient_backward
        if threshold > 0.0 and not has_fp16_weights:
            self.state.use_pool = True

        self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)
        self._register_load_state_dict_pre_hook(maybe_rearrange_weight)
    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super()._save_to_state_dict(destination, prefix, keep_vars)

        # we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data
        scb_name = "SCB"

        # case 1: .cuda was called, SCB is in self.weight
        param_from_weight = getattr(self.weight, scb_name)
        # case 2: self.init_8bit_state was called, SCB is in self.state
        param_from_state = getattr(self.state, scb_name)
        # case 3: SCB is in self.state, weight layout reordered after first forward()
        layout_reordered = self.state.CxB is not None

        key_name = prefix + f"{scb_name}"
        format_name = prefix + "weight_format"

        if not self.state.has_fp16_weights:
            if param_from_weight is not None:
                destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
                destination[format_name] = "row"
            elif param_from_state is not None and not layout_reordered:
                destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
                destination[format_name] = "row"
            elif param_from_state is not None:
                destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
                destination[format_name] = self.state.formatB
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
                                      error_msgs)
        unexpected_copy = list(unexpected_keys)

        for key in unexpected_copy:
            input_name = key[len(prefix):]
            if input_name == "SCB":
                if self.weight.SCB is None:
                    # buffers not yet initialized, can't access them directly without quantizing first
                    raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear8bitLt is "
                                       "not supported. Please call module.cuda() before module.load_state_dict()")

                input_param = state_dict[key]
                self.weight.SCB.copy_(input_param)

                if self.state.SCB is not None:
                    self.state.SCB = self.weight.SCB

                unexpected_keys.remove(key)
    def init_8bit_state(self):
        self.state.CB = self.weight.CB
        self.state.SCB = self.weight.SCB
        self.weight.CB = None
        self.weight.SCB = None

    def forward(self, x: torch.Tensor):
        self.state.is_training = self.training
        if self.weight.CB is not None:
            self.init_8bit_state()

        # weights are cast automatically as Int8Params, but the bias has to be cast manually
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)

        if not self.state.has_fp16_weights:
            if self.state.CB is not None and self.state.CxB is not None:
                # we converted 8-bit row major to turing/ampere format in the first inference pass
                # we no longer need the row-major weight
                del self.state.CB
                self.weight.data = self.state.CxB
        return out
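
# Illustrative usage sketch, assuming fp16 weights already exist on CPU: with
# has_fp16_weights=False the layer keeps only the int8 weight once moved to CUDA, and
# threshold=6.0 enables the mixed-precision outlier decomposition from LLM.int8().
#
#   fp16_linear = torch.nn.Linear(4096, 4096).half()
#   int8_linear = Linear8bitLt(4096, 4096, has_fp16_weights=False, threshold=6.0)
#   int8_linear.load_state_dict(fp16_linear.state_dict())
#   int8_linear = int8_linear.cuda()                  # quantizes the weight to int8
#   y = int8_linear(torch.randn(8, 4096, dtype=torch.float16, device='cuda'))
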
class OutlierAwareLinear(nn.Linear):
    def __init__(self, input_features, output_features, bias=True, device=None):
        super().__init__(input_features, output_features, bias, device)
        self.outlier_dim = None
        self.is_quantized = False

    def forward_with_outliers(self, x, outlier_idx):
        raise NotImplementedError('Please override the `forward_with_outliers(self, x, outlier_idx)` function')

    def quantize_weight(self, w, outlier_idx):
        raise NotImplementedError('Please override the `quantize_weight(self, w, outlier_idx)` function')

    def forward(self, x):
        if self.outlier_dim is None:
            tracer = OutlierTracer.get_instance()
            if not tracer.is_initialized():
                print('Please use OutlierTracer.initialize(model) before using the OutlierAwareLinear layer')
            outlier_idx = tracer.get_outliers(self.weight)
            #print(outlier_idx, tracer.get_hvalue(self.weight))
            self.outlier_dim = outlier_idx

        if not self.is_quantized:
            w = self.quantize_weight(self.weight, self.outlier_dim)
            self.weight.data.copy_(w)
            self.is_quantized = True

        return self.forward_with_outliers(x, self.outlier_dim)
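
# Illustrative sketch of a hypothetical subclass satisfying the two required overrides above.
# The quantization scheme here (absmax fake-quantization that leaves outlier columns in full
# precision) is an assumption for demonstration only, not a scheme provided by the library.
#
#   class NaiveOutlierLinear(OutlierAwareLinear):
#       def quantize_weight(self, w, outlier_idx):
#           scale = w.abs().max() / 127
#           wq = torch.round(w / scale) * scale       # fake-quantize to 8-bit levels
#           wq[:, outlier_idx] = w[:, outlier_idx]    # keep outlier columns unquantized
#           return wq
#
#       def forward_with_outliers(self, x, outlier_idx):
#           return torch.nn.functional.linear(x, self.weight, self.bias)
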
class SwitchBackLinearBnb(nn.Linear):
    def __init__(
        self,
        input_features,
        output_features,
        bias=True,
        has_fp16_weights=True,
        memory_efficient_backward=False,
        threshold=0.0,
        index=None,
        device=None
    ):
        super().__init__(
            input_features, output_features, bias, device
        )
        self.state = bnb.MatmulLtState()
        self.index = index

        self.state.threshold = threshold
        self.state.has_fp16_weights = has_fp16_weights
        self.state.memory_efficient_backward = memory_efficient_backward
        if threshold > 0.0 and not has_fp16_weights:
            self.state.use_pool = True

        self.weight = Int8Params(
            self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights
        )

    def init_8bit_state(self):
        self.state.CB = self.weight.CB
        self.state.SCB = self.weight.SCB
        self.weight.CB = None
        self.weight.SCB = None

    def forward(self, x):
        self.state.is_training = self.training

        if self.weight.CB is not None:
            self.init_8bit_state()

        out = bnb.matmul_mixed(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias