from typing import (
    Any,
    Callable,
    Dict,
    Iterator,
    Mapping,
    Optional,
    Set,
    Tuple,
    TypeVar,
    Union,
    overload,
)

import torch
import torch.nn.functional as F
from torch import Tensor, device, dtype, nn
from torch.nn.parameter import Parameter

import bitsandbytes as bnb
from bitsandbytes.optim import GlobalOptimManager

T = TypeVar("T", bound="torch.nn.Module")


class StableEmbedding(torch.nn.Embedding):
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
    ) -> None:
        super().__init__(
            num_embeddings,
            embedding_dim,
            padding_idx,
            max_norm,
            norm_type,
            scale_grad_by_freq,
            sparse,
            _weight,
        )
        # Stabilize embedding training: normalize the output with LayerNorm and
        # keep 32-bit optimizer states for this layer even under 8-bit optimizers.
        self.norm = torch.nn.LayerNorm(embedding_dim)
        GlobalOptimManager.get_instance().register_module_override(
            self, "weight", {"optim_bits": 32}
        )

    def reset_parameters(self) -> None:
        torch.nn.init.xavier_uniform_(self.weight)
        self._fill_padding_idx_with_zero()

    """ !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
    to make the Layer compatible with PyTorch < 1.9.
    This means that if this changes in future PyTorch releases this needs to change too,
    which is cumbersome. However, with this we can ensure compatibility with previous
    PyTorch releases.
    """

    def _fill_padding_idx_with_zero(self) -> None:
        if self.padding_idx is not None:
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)

    def forward(self, input: Tensor) -> Tensor:
        emb = F.embedding(
            input,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )

        return self.norm(emb)
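
# Usage sketch (illustrative, not part of the library API): StableEmbedding is a
# drop-in replacement for torch.nn.Embedding; the sizes below are arbitrary
# example values.
#
#     emb = StableEmbedding(num_embeddings=10000, embedding_dim=512, padding_idx=0)
#     tokens = torch.randint(0, 10000, (8, 16))           # (batch, sequence)
#     out = emb(tokens)                                    # LayerNorm-ed embeddings
#     assert out.shape == (8, 16, 512)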


class Embedding(torch.nn.Embedding):
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
    ) -> None:
        super().__init__(
            num_embeddings,
            embedding_dim,
            padding_idx,
            max_norm,
            norm_type,
            scale_grad_by_freq,
            sparse,
            _weight,
        )
        # Keep 32-bit optimizer states for the embedding weight even when an
        # 8-bit optimizer is used for the rest of the model.
        GlobalOptimManager.get_instance().register_module_override(
            self, "weight", {"optim_bits": 32}
        )

    def reset_parameters(self) -> None:
        torch.nn.init.xavier_uniform_(self.weight)
        self._fill_padding_idx_with_zero()

    """ !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
    to make the Layer compatible with PyTorch < 1.9.
    This means that if this changes in future PyTorch releases this needs to change too,
    which is cumbersome. However, with this we can ensure compatibility with previous
    PyTorch releases.
    """

    def _fill_padding_idx_with_zero(self) -> None:
        if self.padding_idx is not None:
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)

    def forward(self, input: Tensor) -> Tensor:
        emb = F.embedding(
            input,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )

        return emb
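
# Usage sketch (illustrative): this Embedding behaves like torch.nn.Embedding but
# registers its weight for 32-bit optimizer states; the sizes are arbitrary.
#
#     emb = Embedding(num_embeddings=30000, embedding_dim=256)
#     ids = torch.randint(0, 30000, (4, 32))
#     vectors = emb(ids)                                   # shape (4, 32, 256)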


class Int8Params(torch.nn.Parameter):
    def __new__(
        cls,
        data=None,
        requires_grad=True,
        has_fp16_weights=False,
        CB=None,
        SCB=None,
    ):
        if data is None:
            data = torch.empty(0)
        obj = torch.Tensor._make_subclass(cls, data, requires_grad)
        # Store quantization state on the instance (not the class) so that
        # separate parameters do not share CB/SCB/has_fp16_weights.
        obj.has_fp16_weights = has_fp16_weights
        obj.CB = CB
        obj.SCB = SCB
        return obj

    def cuda(self, device):
        if self.has_fp16_weights:
            return super().cuda(device)
        else:
            # Store the 8-bit row-major weight and its row-wise scaling
            # constants; the column-wise statistics are not needed here.
            B = self.data.contiguous().half().cuda(device)
            CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
            del CBt
            del SCBt
            self.data = CB
            self.CB = CB
            self.SCB = SCB

        return self

    @overload
    def to(
        self: T,
        device: Optional[Union[int, device]] = ...,
        dtype: Optional[Union[dtype, str]] = ...,
        non_blocking: bool = ...,
    ) -> T:
        ...

    @overload
    def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T:
        ...

    @overload
    def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T:
        ...

    def to(self, *args, **kwargs):
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
            *args, **kwargs
        )

        if (
            device is not None
            and device.type == "cuda"
            and self.data.device.type == "cpu"
        ):
            return self.cuda(device)
        else:
            new_param = Int8Params(
                super().to(
                    device=device, dtype=dtype, non_blocking=non_blocking
                ),
                requires_grad=self.requires_grad,
                has_fp16_weights=self.has_fp16_weights,
            )
            new_param.CB = self.CB
            new_param.SCB = self.SCB

            return new_param
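
# Behavior sketch (illustrative): with has_fp16_weights=False, moving an Int8Params
# to a CUDA device quantizes the data, storing the int8 matrix in .CB and its
# row-wise scaling constants in .SCB. Requires a CUDA-capable GPU.
#
#     w = Int8Params(torch.randn(256, 256), requires_grad=False, has_fp16_weights=False)
#     w = w.to("cuda")                                     # triggers double_quant
#     print(w.CB.dtype, w.SCB.shape)                       # torch.int8, row-wise scales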


class Linear8bitLt(nn.Linear):
    def __init__(
        self,
        input_features,
        output_features,
        bias=True,
        has_fp16_weights=True,
        memory_efficient_backward=False,
        threshold=0.0,
        index=None,
    ):
        super().__init__(input_features, output_features, bias)
        self.state = bnb.MatmulLtState()
        self.index = index

        self.state.threshold = threshold
        self.state.has_fp16_weights = has_fp16_weights
        self.state.memory_efficient_backward = memory_efficient_backward
        if threshold > 0.0 and not has_fp16_weights:
            self.state.use_pool = True

        self.weight = Int8Params(
            self.weight.data,
            has_fp16_weights=has_fp16_weights,
            requires_grad=has_fp16_weights,
        )

    def init_8bit_state(self):
        self.state.CB = self.weight.CB
        self.state.SCB = self.weight.SCB
        self.weight.CB = None
        self.weight.SCB = None

    def forward(self, x):
        self.state.is_training = self.training

        if self.weight.CB is not None:
            self.init_8bit_state()

        # The weight is cast automatically as Int8Params, but the bias has to be
        # cast manually.
        if self.bias is not None and self.bias.dtype != torch.float16:
            self.bias.data = self.bias.data.half()

        out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)

        if not self.state.has_fp16_weights:
            if not self.state.memory_efficient_backward and self.state.CB is not None:
                # The 8-bit row-major weight was converted to the Turing/Ampere
                # layout during the first inference pass; the row-major copy is
                # no longer needed.
                del self.state.CB
                self.weight.data = self.state.CxB
            elif self.state.memory_efficient_backward and self.state.CxB is not None:
                # For the memory-efficient backward, the conversion to the
                # Turing/Ampere layout is redone on each inference pass, so CxB
                # can be freed here.
                del self.state.CxB

        return out
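
# Usage sketch (illustrative): a common pattern is to swap nn.Linear for
# Linear8bitLt with has_fp16_weights=False for int8 inference; threshold=6.0
# enables the mixed-precision decomposition for outlier features. The sizes and
# threshold below are example values.
#
#     linear = Linear8bitLt(1024, 4096, bias=True, has_fp16_weights=False, threshold=6.0)
#     linear = linear.cuda()                               # weight is quantized on the way to the GPU
#     x = torch.randn(2, 8, 1024, dtype=torch.float16, device="cuda")
#     y = linear(x)                                        # int8 matmul via bnb.matmul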