|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
class quantize_lazy_load(): |
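    """Context manager that enters torch's "meta" device so that modules
    constructed inside the block are created without allocating real weight
    storage; `quantize_load_state_dict` later materializes them via
    `to_empty` and fills them from the checkpoint."""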
|
def __init__(self): |
|
self.device = None |
|
|
|
def __enter__(self): |
|
self.device = torch.device("meta") |
|
self.device.__enter__() |
|
return self |
|
|
|
def __exit__(self, exc_type, exc_value, traceback): |
|
self.device.__exit__(exc_type, exc_value, traceback) |
|
|
|
|
|
def quantize_load_state_dict(model, state_dict, device="cpu", cublas_ops=False):
    if cublas_ops:
        try:
            from cublas_ops import cublas_half_matmul

            linear_ops = cublas_half_matmul
            setattr(model, "cublas_half_matmul", True)
            print("Using cublas_ops")
        except ImportError as e:
            raise ImportError(
                "Install cublas_ops (https://github.com/aredden/torch-cublas-hgemm) to use cublas_ops"
            ) from e
    else:
        linear_ops = F.linear
        setattr(model, "cublas_half_matmul", False)
|
|
|
    # Map each quantized module name to its qtype so that checkpoints mixing
    # Q4_0 and Q8_0 layers resolve to the right layout per module.
    quant_keys = {}
    for key in state_dict.keys():
        if key.endswith(".Q4_0_qweight"):
            quant_keys[key.replace(".Q4_0_qweight", "")] = "Q4_0"
        elif key.endswith(".Q8_0_qweight"):
            quant_keys[key.replace(".Q8_0_qweight", "")] = "Q8_0"

    for name, module in model.named_modules():
        if name in quant_keys:
            q_linear = WQLinear_GGUF.from_linear(
                linear=module,
                device=device,
                qtype=quant_keys[name],
                linear_ops=linear_ops,
            )
            set_op_by_name(model, name, q_linear)
|
|
|
    # Materialize any meta tensors on the target device, then load the weights;
    # strict=False tolerates keys that no longer line up one-to-one after the
    # linear layers have been swapped for their quantized counterparts.
    model.to_empty(device=device)
    model.load_state_dict(state_dict, strict=False)
|
return model |
|
|
|
|
|
def set_op_by_name(layer, name, new_module): |
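    """Replace the submodule addressed by a dotted `name` (as yielded by
    `named_modules()`) with `new_module`; e.g. a hypothetical path
    "blocks.3.attn.qkv" walks layer.blocks[3].attn and reassigns `qkv`."""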
|
levels = name.split(".") |
|
if len(levels) > 1: |
|
mod_ = layer |
|
for l_idx in range(len(levels) - 1): |
|
if levels[l_idx].isdigit(): |
|
mod_ = mod_[int(levels[l_idx])] |
|
else: |
|
mod_ = getattr(mod_, levels[l_idx]) |
|
setattr(mod_, levels[-1], new_module) |
|
else: |
|
setattr(layer, name, new_module) |
|
|
|
class WQLinear_GGUF(nn.Module): |
|
def __init__( |
|
self, in_features, out_features, bias, dev, qtype, linear_ops |
|
): |
|
super().__init__() |
|
|
|
self.in_features = in_features |
|
self.out_features = out_features |
|
self.qtype = qtype |
|
self.linear_ops = linear_ops |
|
|
|
qweight_shape = quant_shape_to_byte_shape( |
|
(out_features, in_features), qtype |
|
) |
|
self.register_buffer( |
|
f"{qtype}_qweight", |
|
torch.zeros( |
|
qweight_shape, |
|
dtype=torch.uint8, |
|
device=dev, |
|
), |
|
) |
|
if bias: |
|
self.register_buffer( |
|
"bias", |
|
torch.zeros( |
|
(out_features), |
|
dtype=torch.float16, |
|
device=dev, |
|
), |
|
) |
|
else: |
|
self.bias = None |
|
|
|
@classmethod |
|
def from_linear( |
|
cls, linear, |
|
device="cpu", |
|
qtype="Q4_0", |
|
linear_ops=F.linear |
|
): |
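        # Only shape and bias metadata come from `linear`; the quantized
        # buffers start as zeros and are filled later by `load_state_dict`.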
|
q_linear = cls( |
|
linear.in_features, |
|
linear.out_features, |
|
linear.bias is not None, |
|
device, |
|
qtype=qtype, |
|
linear_ops=linear_ops |
|
) |
|
return q_linear |
|
|
|
    def extra_repr(self) -> str:
        return "in_features={}, out_features={}, bias={}, qtype={}".format(
            self.in_features,
            self.out_features,
            self.bias is not None,
            self.qtype,
        )
|
|
|
|
|
@torch.no_grad() |
|
def forward(self, x): |
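        # Weights stay in their quantized byte layout; dequantize to the input
        # dtype on the fly, then run the configured linear/matmul op.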
|
if self.qtype == "Q4_0": |
|
dequant = dequantize_blocks_Q4_0(self.Q4_0_qweight, x.dtype) |
|
elif self.qtype == "Q8_0": |
|
dequant = dequantize_blocks_Q8_0(self.Q8_0_qweight, x.dtype) |
|
else: |
|
raise ValueError(f"Unknown qtype: {self.qtype}") |
|
|
|
        bias = self.bias.to(x.dtype) if self.bias is not None else None
        return self.linear_ops(x, dequant, bias=bias)
|
|
|
|
|
def split_block_dims(blocks, *args): |
|
n_max = blocks.shape[1] |
|
dims = list(args) + [n_max - sum(args)] |
|
return torch.split(blocks, dims, dim=1) |
|
|
|
|
|
def quant_shape_to_byte_shape(shape, qtype) -> tuple[int, ...]: |
|
|
|
block_size, type_size = GGML_QUANT_SIZES[qtype] |
|
if shape[-1] % block_size != 0: |
|
raise ValueError( |
|
f"Quantized tensor row size ({shape[-1]}) is not a multiple of {qtype} block size ({block_size})") |
|
return (*shape[:-1], shape[-1] // block_size * type_size) |
|
|
|
|
|
def quant_shape_from_byte_shape(shape, qtype) -> tuple[int, ...]: |
|
|
|
block_size, type_size = GGML_QUANT_SIZES[qtype] |
|
if shape[-1] % type_size != 0: |
|
raise ValueError( |
|
f"Quantized tensor bytes per row ({shape[-1]}) is not a multiple of {qtype} type size ({type_size})") |
|
return (*shape[:-1], shape[-1] // type_size * block_size) |
|
|
|
|
|
GGML_QUANT_SIZES = { |
|
"Q4_0": (32, 2 + 16), |
|
"Q8_0": (32, 2 + 32), |
|
} |
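
# Block layouts (per GGML): Q4_0 packs 32 weights per block as a float16 scale
# (2 bytes) plus 16 bytes of 4-bit values; Q8_0 packs a float16 scale plus
# 32 int8 values. Illustrative shape check, assuming a hypothetical 4096x4096
# linear weight (not used by the loader itself):
#   quant_shape_to_byte_shape((4096, 4096), "Q4_0")   -> (4096, 4096 // 32 * 18) == (4096, 2304)
#   quant_shape_from_byte_shape((4096, 2304), "Q4_0") -> (4096, 4096)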
|
def dequantize_blocks_Q4_0(data, dtype=torch.float16): |
|
block_size, type_size = GGML_QUANT_SIZES["Q4_0"] |
|
|
|
data = data.to(torch.uint8) |
|
shape = data.shape |
|
|
|
    # Flatten into one row per quantized block of `type_size` bytes.
    n_blocks = data.numel() // type_size
    blocks = data.reshape((n_blocks, type_size))
|
|
|
    # First 2 bytes of each block are the float16 scale, the rest are the
    # packed 4-bit quantized values.
    d, qs = split_block_dims(blocks, 2)
    d = d.view(torch.float16)

    # Each byte packs two 4-bit values: shift by 0 and by 4 to expose the low
    # and high nibble, mask to 4 bits, then recenter from [0, 15] to [-8, 7].
    qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
        [0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
    qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8
|
|
|
out = (d * qs) |
|
|
|
out = out.reshape(quant_shape_from_byte_shape( |
|
shape, |
|
qtype="Q4_0", |
|
)).to(dtype) |
|
return out |
|
def dequantize_blocks_Q8_0(data, dtype=torch.float16): |
|
block_size, type_size = GGML_QUANT_SIZES["Q8_0"] |
|
|
|
data = data.to(torch.uint8) |
|
shape = data.shape |
|
|
|
    # Flatten into one row per quantized block of `type_size` bytes.
    n_blocks = data.numel() // type_size
    blocks = data.reshape((n_blocks, type_size))
|
|
|
    # First 2 bytes of each block are the float16 scale, followed by 32 int8
    # quantized values.
    d, qs = split_block_dims(blocks, 2)
    d = d.view(torch.float16).to(torch.float32)

    qs = qs.view(torch.int8).to(torch.float32)
|
|
|
out = (d * qs) |
|
|
|
out = out.reshape(quant_shape_from_byte_shape( |
|
shape, |
|
qtype="Q8_0", |
|
)).to(dtype) |
|
return out |
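

# Usage sketch (illustrative, not part of the module): assumes a hypothetical
# `MyModel` class and a checkpoint saved with `<name>.Q4_0_qweight` /
# `<name>.Q8_0_qweight` buffers in place of the original linear weights.
#
#     with quantize_lazy_load():
#         model = MyModel()  # parameters land on the meta device, no allocation
#     state_dict = torch.load("quantized_checkpoint.pt", map_location="cpu")
#     model = quantize_load_state_dict(model, state_dict, device="cuda")
#     model.eval()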