# bitsandbytes/nn/triton_based_modules.py
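"""SwitchBack linear layers backed by Triton int8 kernels.

Activations are quantized rowwise (one absmax scale per row). Weights are
quantized either globally (one scale for the whole tensor) or vector-wise
(one scale per output feature). The int8 matmul kernels fuse dequantization
and the bias add into their epilogue; the memory-efficient variant saves the
quantized activations for backward instead of the full-precision ones.
"""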
import torch
import torch.nn as nn

from functools import partial

from bitsandbytes.triton.triton_utils import is_triton_available
from bitsandbytes.triton.dequantize_rowwise import dequantize_rowwise
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose

# Note: "dequanitze" is the spelling used by the upstream kernel module.
from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze
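
# For reference, a plain-PyTorch sketch of the rowwise quantization semantics
# assumed for the Triton kernels imported above: symmetric int8 with a per-row
# absmax scale. These helpers are illustrative only and are not used below.
def _rowwise_quant_reference(x):
    # One scale per row: the row's maximum absolute value.
    state = x.abs().max(dim=1).values
    x_int8 = torch.round(x * (127.0 / state.clamp(min=1e-8))[:, None]).to(torch.int8)
    return x_int8, state


def _rowwise_dequant_reference(x_int8, state):
    # Invert the scaling: int8 value times (row absmax / 127).
    return x_int8.float() * (state / 127.0)[:, None]
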
class _switchback_global(torch.autograd.Function):

    @staticmethod
    def forward(ctx, X_3D, W, bias):
        # Reshape input to [N * L, D].
        X = X_3D.view(-1, X_3D.size(-1))

        # Rowwise quantize for X, global quantize for W.
        X_int8, state_X = quantize_rowwise(X)
        W_int8, state_W = quantize_global(W)

        # Save for backward.
        ctx.save_for_backward(X, W)

        # Matmul with fused dequantize and bias add. Called "mixed" because a
        # rowwise-quantized operand is mixed with a globally quantized one.
        return int8_matmul_mixed_dequanitze(
            X_int8, W_int8.t(), state_X, state_W, bias
        ).view(*X_3D.size()[:-1], -1)

    @staticmethod
    def backward(ctx, G_3D):
        # Reshape incoming gradient to [N_out * L, D].
        G = G_3D.reshape(-1, G_3D.size(-1))

        grad_X = grad_W = grad_bias = None

        X, W = ctx.saved_tensors
        if ctx.needs_input_grad[0]:
            # Rowwise quantize for G, global quantize for W.
            # For W, the transpose is fused into the quantization because only
            # A @ B^T is supported, so we transpose once and then call .t() in
            # the matmul.
            G_int8, state_G = quantize_rowwise(G)
            W_int8, state_W = quantize_global_transpose(W)
            grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view(
                *G_3D.size()[:-1], -1
            )
        if ctx.needs_input_grad[1]:
            # The weight gradient uses a standard full-precision matmul.
            grad_W = torch.matmul(G.t(), X.to(G.dtype))
        if ctx.needs_input_grad[2]:
            grad_bias = G.sum(dim=0)

        return grad_X, grad_W, grad_bias
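
# For reference, a plain-PyTorch sketch of the dequantization we assume the
# "mixed" kernel fuses into its epilogue: X carries one scale per row and W a
# single global scale. This helper is illustrative only and is not used above.
def _mixed_matmul_reference(X_int8, W_int8_t, state_X, state_W, bias=None):
    # X_int8: [N, D] int8 with per-row absmax scales state_X ([N]).
    # W_int8_t: [D, M] int8 with one global absmax scale state_W (scalar).
    out = torch.matmul(X_int8.float(), W_int8_t.float())
    out = out * state_X[:, None] * (state_W / (127.0 * 127.0))
    return out if bias is None else out + bias
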
class _switchback_vectorize(torch.autograd.Function):

    @staticmethod
    def forward(ctx, X_3D, W, bias):
        # Reshape input to [N * L, D].
        X = X_3D.view(-1, X_3D.size(-1))

        ctx.save_for_backward(X, W)

        # Rowwise quantize for X. W is also quantized rowwise here, which is
        # columnwise with respect to the transposed weight used in the matmul.
        X_int8, state_X = quantize_rowwise(X)
        W_int8, state_W = quantize_rowwise(W)

        # Matmul with fused dequantize and bias add; this kernel expects both
        # X and W to be rowwise quantized.
        return int8_matmul_rowwise_dequantize(
            X_int8, W_int8.t(), state_X, state_W, bias
        ).view(*X_3D.size()[:-1], -1)

    @staticmethod
    def backward(ctx, G_3D):
        X, W = ctx.saved_tensors

        G = G_3D.reshape(-1, G_3D.size(-1))

        grad_X = grad_W = grad_bias = None
        if ctx.needs_input_grad[0]:
            # Rowwise quantize for G; columnwise quantize for W with a fused
            # transpose. The weight gets another .t() in the matmul because
            # only A @ B^T is supported.
            G_int8, state_G = quantize_rowwise(G)
            W_int8, state_W = quantize_columnwise_and_transpose(W)
            grad_X = int8_matmul_rowwise_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
                *G_3D.size()[:-1], -1
            )
        if ctx.needs_input_grad[1]:
            # The weight gradient uses a standard full-precision matmul.
            grad_W = torch.matmul(G.t(), X.to(G.dtype))
        if ctx.needs_input_grad[2]:
            grad_bias = G.sum(dim=0)

        return grad_X, grad_W, grad_bias
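
# For reference, the rowwise kernel's assumed dequantization: both operands
# carry per-vector scales, so the output is rescaled by the outer product of
# the two scale vectors. Illustrative only; not used by the layers above.
def _rowwise_matmul_reference(X_int8, W_int8_t, state_X, state_W, bias=None):
    # X_int8: [N, D] int8 with per-row scales state_X ([N]).
    # W_int8_t: [D, M] int8 with per-column scales state_W ([M]).
    out = torch.matmul(X_int8.float(), W_int8_t.float())
    out = out * state_X[:, None] * state_W[None, :] / (127.0 * 127.0)
    return out if bias is None else out + bias
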
class _switchback_global_mem_efficient(torch.autograd.Function):

    @staticmethod
    def forward(ctx, X_3D, W, bias):
        # Reshape input to [N * L, D].
        X = X_3D.view(-1, X_3D.size(-1))
        X_3D_sz = X_3D.size()

        # Rowwise quantize for X, global quantize for W.
        X_int8, state_X = quantize_rowwise(X)
        del X
        W_int8, state_W = quantize_global(W)

        # Save the quantized tensors for backward rather than the full-precision
        # ones; X is re-materialized from int8 if the weight gradient is needed.
        ctx.save_for_backward(X_int8, state_X, W_int8, state_W)

        # Matmul with fused dequantize and bias add. Called "mixed" because a
        # rowwise-quantized operand is mixed with a globally quantized one.
        return int8_matmul_mixed_dequanitze(
            X_int8, W_int8.t(), state_X, state_W, bias
        ).view(*X_3D_sz[:-1], -1)

    @staticmethod
    def backward(ctx, G_3D):
        # Reshape incoming gradient to [N_out * L, D].
        G = G_3D.reshape(-1, G_3D.size(-1))
        G_3D_sz = G_3D.size()

        grad_X = grad_W = grad_bias = None

        X_int8, state_X, W_int8, state_W = ctx.saved_tensors
        if ctx.needs_input_grad[1]:
            real_X = dequantize_rowwise(X_int8, state_X)
            del X_int8
            grad_W = torch.matmul(G.t(), real_X.to(G.dtype))
            del real_X
        if ctx.needs_input_grad[2]:
            grad_bias = G.sum(dim=0)
        if ctx.needs_input_grad[0]:
            G_int8, state_G = quantize_rowwise(G)
            del G
            W_int8 = W_int8.t().contiguous()
            grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view(
                *G_3D_sz[:-1], -1
            )

        return grad_X, grad_W, grad_bias
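
# The memory saving from the class above is roughly 2x for fp16 activations:
# storing an [N, D] activation in fp16 costs 2*N*D bytes, while the int8 copy
# plus one scale per row costs about N*D + 4*N bytes (assuming fp32 scales;
# the exact figure depends on the dtypes the kernels use).
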
class SwitchBackLinear(nn.Linear):
    def __init__(
            self,
            in_features: int,
            out_features: int,
            bias: bool = True,
            device=None,
            dtype=None,
            vector_wise_quantization: bool = False,
            mem_efficient: bool = False,
        ):
        super().__init__(in_features, out_features, bias, device, dtype)

        if not is_triton_available():
            raise ImportError('''Could not import triton. Please install triton to use SwitchBackLinear.
                               Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower''')

        # By default, we use global quantization.
        self.vector_wise_quantization = vector_wise_quantization
        if self.vector_wise_quantization:
            self._fn = _switchback_vectorize
            if mem_efficient:
                raise NotImplementedError('mem_efficient is not supported for vector-wise quantization.')
        else:
            if mem_efficient:
                self._fn = _switchback_global_mem_efficient
            else:
                self._fn = _switchback_global
    def prepare_for_eval(self):
        # If we just want to do eval, we can pre-quantize the weights instead of
        # quantizing on every forward pass.
        # Note this is experimental and not tested thoroughly.
        # Note this needs to be explicitly called with something like
        #     def cond_prepare(m):
        #         if hasattr(m, "prepare_for_eval"):
        #             m.prepare_for_eval()
        #     model.apply(cond_prepare)
        print('=> preparing for eval.')
        if self.vector_wise_quantization:
            W_int8, state_W = quantize_rowwise(self.weight)
        else:
            W_int8, state_W = quantize_global(self.weight)

        self.register_buffer("W_int8", W_int8)
        self.register_buffer("state_W", state_W)

        del self.weight
    def forward(self, x):
        if self.training:
            return self._fn.apply(x, self.weight, self.bias)
        else:
            # If it hasn't been "prepared for eval", run the standard forward pass.
            if not hasattr(self, "W_int8"):
                return self._fn.apply(x, self.weight, self.bias)

            # Otherwise, use the pre-quantized weights.
            X = x.view(-1, x.size(-1))
            X_int8, state_X = quantize_rowwise(X)

            if self.vector_wise_quantization:
                return int8_matmul_rowwise_dequantize(
                    X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
                ).view(*x.size()[:-1], -1)
            else:
                return int8_matmul_mixed_dequanitze(
                    X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
                ).view(*x.size()[:-1], -1)
SwitchBackLinearGlobal = partial(SwitchBackLinear, vector_wise_quantization=False)
SwitchBackLinearGlobalMemEfficient = partial(SwitchBackLinear, vector_wise_quantization=False, mem_efficient=True)
SwitchBackLinearVectorwise = partial(SwitchBackLinear, vector_wise_quantization=True)
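
# Illustrative usage (assumes a CUDA device with Triton installed; the shapes
# and dtype below are arbitrary examples, not requirements):
#
#     linear = SwitchBackLinearVectorwise(1024, 4096, device="cuda", dtype=torch.float16)
#     y = linear(torch.randn(8, 128, 1024, device="cuda", dtype=torch.float16))
#
#     # For inference, the weights can be pre-quantized once:
#     linear.eval()
#     linear.prepare_for_eval()
#     y = linear(torch.randn(8, 128, 1024, device="cuda", dtype=torch.float16))
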
# This is just the standard linear function.
class StandardLinearFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, weight, bias=None):
        X = input.view(-1, input.size(-1))

        ctx.save_for_backward(X, weight, bias)

        output = input.matmul(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)

        return output.view(*input.size()[:-1], -1)

    @staticmethod
    def backward(ctx, grad_output_3D):
        input, weight, bias = ctx.saved_tensors

        grad_output = grad_output_3D.reshape(-1, grad_output_3D.size(-1))

        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = grad_output.matmul(weight.to(grad_output.dtype)).view(*grad_output_3D.size()[:-1], -1)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().matmul(input.to(grad_output.dtype))
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias


class StandardLinear(nn.Linear):
    def forward(self, x):
        return StandardLinearFunction.apply(x, self.weight, self.bias)
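

if __name__ == "__main__":
    # A minimal sanity check (illustrative): StandardLinearFunction should agree
    # with torch.nn.functional.linear, and gradcheck exercises the custom
    # backward pass. Run manually; nothing here executes on import.
    x = torch.randn(2, 3, 8, dtype=torch.float64, requires_grad=True)
    w = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
    b = torch.randn(4, dtype=torch.float64, requires_grad=True)

    out = StandardLinearFunction.apply(x, w, b)
    ref = torch.nn.functional.linear(x, w, b)
    assert torch.allclose(out, ref)

    assert torch.autograd.gradcheck(StandardLinearFunction.apply, (x, w, b))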