Spaces:
Sleeping
Sleeping
File size: 9,843 Bytes
82fea12 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 |
import torch
import torch.nn as nn
import time
from functools import partial
from bitsandbytes.triton.triton_utils import is_triton_available
from bitsandbytes.triton.dequantize_rowwise import dequantize_rowwise
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose
from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze
class _switchback_global(torch.autograd.Function):
@staticmethod
def forward(ctx, X_3D, W, bias):
# reshape input to [N * L, D]
X = X_3D.view(-1, X_3D.size(-1))
# rowwise quantize for X, global quantize for W
X_int8, state_X = quantize_rowwise(X)
W_int8, state_W = quantize_global(W)
# save for backward.
ctx.save_for_backward = X, W
# matmult, fused dequant and add bias
# call "mixed" because we are mixing rowwise quantized and global quantized
return int8_matmul_mixed_dequanitze(
X_int8, W_int8.t(), state_X, state_W, bias
).view(*X_3D.size()[:-1], -1)
@staticmethod
def backward(ctx, G_3D):
# reshape input to [N_out * L, D]
G = G_3D.reshape(-1, G_3D.size(-1))
grad_X = grad_W = grad_bias = None
X, W = ctx.save_for_backward
if ctx.needs_input_grad[0]:
# rowwise quantize for G, global quantize for W
# for W, we also fuse the transpose operation because only A @ B^T is supported
# so we transpose once then call .t() in the matmul
G_int8, state_G = quantize_rowwise(G)
W_int8, state_W = quantize_global_transpose(W)
grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view(
*G_3D.size()[:-1], -1
)
if ctx.needs_input_grad[1]:
# backward pass uses standard weight grad
grad_W = torch.matmul(G.t(), X.to(G.dtype))
if ctx.needs_input_grad[2]:
grad_bias = G.sum(dim=0)
return grad_X, grad_W, grad_bias
class _switchback_vectorrize(torch.autograd.Function):
@staticmethod
def forward(ctx, X_3D, W, bias):
# reshape input to [N * L, D]
X = X_3D.view(-1, X_3D.size(-1))
ctx.save_for_backward = X, W
# rowwise quantize for X
# columnwise quantize for W (first rowwise, transpose later)
X_int8, state_X = quantize_rowwise(X)
W_int8, state_W = quantize_rowwise(W)
# matmult, fused dequant and add bias
# call kernel which expects rowwise quantized X and W
return int8_matmul_rowwise_dequantize(
X_int8, W_int8.t(), state_X, state_W, bias
).view(*X_3D.size()[:-1], -1)
@staticmethod
def backward(ctx, G_3D):
X, W = ctx.save_for_backward
G = G_3D.reshape(-1, G_3D.size(-1))
grad_X = grad_W = grad_bias = None
if ctx.needs_input_grad[0]:
# rowwise quantize for G, columnwise quantize for W and fused transpose
# we call .t() for weight later because only A @ B^T is supported
G_int8, state_G = quantize_rowwise(G)
W_int8, state_W = quantize_columnwise_and_transpose(W)
grad_X = int8_matmul_rowwise_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
*G_3D.size()[:-1], -1
)
if ctx.needs_input_grad[1]:
# backward pass uses standard weight grad
grad_W = torch.matmul(G.t(), X.to(G.dtype))
if ctx.needs_input_grad[2]:
grad_bias = G.sum(dim=0)
return grad_X, grad_W, grad_bias
class _switchback_global_mem_efficient(torch.autograd.Function):
@staticmethod
def forward(ctx, X_3D, W, bias):
# reshape input to [N * L, D]
X = X_3D.view(-1, X_3D.size(-1))
X_3D_sz = X_3D.size()
# rowwise quantize for X, global quantize for W
X_int8, state_X = quantize_rowwise(X)
del X
W_int8, state_W = quantize_global(W)
# save for backward.
ctx.save_for_backward = X_int8, state_X, W_int8, state_W
# matmult, fused dequant and add bias
# call "mixed" because we are mixing rowwise quantized and global quantized
return int8_matmul_mixed_dequanitze(
X_int8, W_int8.t(), state_X, state_W, bias
).view(*X_3D_sz[:-1], -1)
@staticmethod
def backward(ctx, G_3D):
# reshape input to [N_out * L, D]
G = G_3D.reshape(-1, G_3D.size(-1))
G_3D_sz = G_3D.size()
grad_X = grad_W = grad_bias = None
X_int8, state_X, W_int8, state_W = ctx.save_for_backward
if ctx.needs_input_grad[1]:
real_X = dequantize_rowwise(X_int8, state_X)
del X_int8
grad_W = torch.matmul(G.t(), real_X.to(G.dtype))
del real_X
if ctx.needs_input_grad[2]:
grad_bias = G.sum(dim=0)
if ctx.needs_input_grad[0]:
G_int8, state_G = quantize_rowwise(G)
del G
W_int8 = W_int8.t().contiguous()
grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view(
*G_3D_sz[:-1], -1
)
return grad_X, grad_W, grad_bias
class SwitchBackLinear(nn.Linear):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
vector_wise_quantization: bool = False,
mem_efficient : bool = False,
):
super().__init__(in_features, out_features, bias, device, dtype)
if not is_triton_available:
raise ImportError('''Could not import triton. Please install triton to use SwitchBackLinear.
Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower''')
# By default, we use the global quantization.
self.vector_wise_quantization = vector_wise_quantization
if self.vector_wise_quantization:
self._fn = _switchback_vectorrize
if mem_efficient:
print('mem efficient is not supported for vector-wise quantization.')
exit(1)
else:
if mem_efficient:
self._fn = _switchback_global_mem_efficient
else:
self._fn = _switchback_global
def prepare_for_eval(self):
# If we just want to do eval, we can pre-quantize the weights instead of doing it on the forward pass.
# Note this is experimental and not tested thoroughly.
# Note this needs to be explicitly called with something like
# def cond_prepare(m):
# if hasattr(m, "prepare_for_eval"):
# m.prepare_for_eval()
# model.apply(cond_prepare)
print('=> preparing for eval.')
if self.vector_wise_quantization:
W_int8, state_W = quantize_rowwise(self.weight)
else:
W_int8, state_W = quantize_global(self.weight)
self.register_buffer("W_int8", W_int8)
self.register_buffer("state_W", state_W)
del self.weight
def forward(self, x):
if self.training:
return self._fn.apply(x, self.weight, self.bias)
else:
# If it hasn't been "prepared for eval", run the standard forward pass.
if not hasattr(self, "W_int8"):
return self._fn.apply(x, self.weight, self.bias)
# Otherwise, use pre-computed weights.
X = x.view(-1, x.size(-1))
X_int8, state_X = quantize_rowwise(X)
if self.vector_wise_quantization:
return int8_matmul_rowwise_dequantize(
X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
).view(*x.size()[:-1], -1)
else:
return int8_matmul_mixed_dequanitze(
X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
).view(*x.size()[:-1], -1)
SwitchBackLinearGlobal = partial(SwitchBackLinear, vector_wise_quantization=False)
SwitchBackLinearGlobalMemEfficient = partial(SwitchBackLinear, vector_wise_quantization=False, mem_efficient=True)
SwitchBackLinearVectorwise = partial(SwitchBackLinear, vector_wise_quantization=True)
# This is just the standard linear function.
class StandardLinearFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, weight, bias=None):
X = input.view(-1, input.size(-1))
ctx.save_for_backward(X, weight, bias)
output = input.matmul(weight.t())
if bias is not None:
output += bias.unsqueeze(0).expand_as(output)
return output.view(*input.size()[:-1], -1)
@staticmethod
def backward(ctx, grad_output_3D):
input, weight, bias = ctx.saved_tensors
grad_output = grad_output_3D.reshape(-1, grad_output_3D.size(-1))
grad_input = grad_weight = grad_bias = None
if ctx.needs_input_grad[0]:
grad_input = grad_output.matmul(weight.to(grad_output.dtype)).view(*grad_output_3D.size()[:-1], -1)
if ctx.needs_input_grad[1]:
grad_weight = grad_output.t().matmul(input.to(grad_output.dtype))
if bias is not None and ctx.needs_input_grad[2]:
grad_bias = grad_output.sum(0)
return grad_input, grad_weight, grad_bias
class StandardLinear(nn.Linear):
def forward(self, x):
return StandardLinearFunction.apply(x, self.weight, self.bias)
|