# Copyright 2024 The HuggingFace Team and City96. All rights reserved.
# #
# # Licensed under the Apache License, Version 2.0 (the "License");
# # you may not use this file except in compliance with the License.
# # You may obtain a copy of the License at
# #
# #     http://www.apache.org/licenses/LICENSE-2.0
# #
# # Unless required by applicable law or agreed to in writing, software
# # distributed under the License is distributed on an "AS IS" BASIS,
# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# # See the License for the specific language governing permissions and
# # limitations under the License.


import inspect
from contextlib import nullcontext

import gguf
import torch
import torch.nn as nn

from ...utils import is_accelerate_available


if is_accelerate_available():
    import accelerate
    from accelerate import init_empty_weights
    from accelerate.hooks import add_hook_to_module, remove_hook_from_module


# Copied from diffusers.quantizers.bitsandbytes.utils._create_accelerate_new_hook
def _create_accelerate_new_hook(old_hook):
    r"""
    Creates a new hook based on the old hook. Use it only if you know what you are doing ! This method is a copy of:
    https://github.com/huggingface/peft/blob/748f7968f3a31ec06a1c2b0328993319ad9a150a/src/peft/utils/other.py#L245 with
    some changes
    """
    old_hook_cls = getattr(accelerate.hooks, old_hook.__class__.__name__)
    old_hook_attr = old_hook.__dict__
    filtered_old_hook_attr = {}
    old_hook_init_signature = inspect.signature(old_hook_cls.__init__)
    for k in old_hook_attr.keys():
        if k in old_hook_init_signature.parameters:
            filtered_old_hook_attr[k] = old_hook_attr[k]
    new_hook = old_hook_cls(**filtered_old_hook_attr)
    return new_hook


def _replace_with_gguf_linear(model, compute_dtype, state_dict, prefix="", modules_to_not_convert=None):
    """
    Recursively swap `nn.Linear` children of `model` for `GGUFLinear` modules.

    A child is replaced only when `state_dict` holds a `GGUFParameter` under the key `<prefix><name>.weight` and the
    child's (unqualified) name is not listed in `modules_to_not_convert`.

    Args:
        model: Module whose children are inspected and replaced in place.
        compute_dtype: Dtype stored on each created `GGUFLinear` and used by its forward pass.
        state_dict: Mapping from parameter names to tensors; consulted to decide which layers are quantized.
        prefix: Dotted name prefix of `model` inside `state_dict` (including the trailing dot).
        modules_to_not_convert: Child names to leave untouched. `None` means "convert everything eligible".

    Returns:
        `model` itself, or `None` when `model` has no children.
    """
    if modules_to_not_convert is None:
        modules_to_not_convert = []

    def _should_convert_to_gguf(state_dict, prefix):
        weight_key = prefix + "weight"
        return weight_key in state_dict and isinstance(state_dict[weight_key], GGUFParameter)

    has_children = list(model.children())
    if not has_children:
        return

    for name, module in model.named_children():
        module_prefix = prefix + name + "."
        # Descend first so nested linears are handled before the child itself is considered.
        _replace_with_gguf_linear(module, compute_dtype, state_dict, module_prefix, modules_to_not_convert)

        if (
            isinstance(module, nn.Linear)
            and _should_convert_to_gguf(state_dict, module_prefix)
            and name not in modules_to_not_convert
        ):
            ctx = init_empty_weights if is_accelerate_available() else nullcontext
            with ctx():
                model._modules[name] = GGUFLinear(
                    module.in_features,
                    module.out_features,
                    module.bias is not None,
                    compute_dtype=compute_dtype,
                )
            model._modules[name].source_cls = type(module)
            # Force requires_grad to False to avoid unexpected errors
            model._modules[name].requires_grad_(False)

    return model


def _dequantize_gguf_and_restore_linear(model, modules_to_not_convert=None):
    """
    Recursively replace `GGUFLinear` children of `model` with plain `nn.Linear` modules holding dequantized weights.

    Args:
        model: Module whose children are inspected and replaced in place.
        modules_to_not_convert: Child names to leave untouched. `None` means "restore every `GGUFLinear`".

    Returns:
        `model` itself.
    """
    if modules_to_not_convert is None:
        modules_to_not_convert = []

    for name, module in model.named_children():
        if isinstance(module, GGUFLinear) and name not in modules_to_not_convert:
            device = module.weight.device
            bias = getattr(module, "bias", None)

            ctx = init_empty_weights if is_accelerate_available() else nullcontext
            with ctx():
                new_module = nn.Linear(
                    module.in_features,
                    module.out_features,
                    module.bias is not None,
                    device=device,
                )
            new_module.weight = nn.Parameter(dequantize_gguf_tensor(module.weight))
            if bias is not None:
                new_module.bias = bias

            # Create a new hook and attach it in case we use accelerate
            if hasattr(module, "_hf_hook"):
                old_hook = module._hf_hook
                new_hook = _create_accelerate_new_hook(old_hook)

                remove_hook_from_module(module)
                add_hook_to_module(new_module, new_hook)

            new_module.to(device)
            model._modules[name] = new_module

        # Note: this recurses into the *original* child, also when it was just replaced above.
        has_children = list(module.children())
        if has_children:
            _dequantize_gguf_and_restore_linear(module, modules_to_not_convert)

    return model


# dequantize operations based on torch ports of GGUF dequantize_functions
# from City96
# more info: https://github.com/city96/ComfyUI-GGUF/blob/main/dequant.py


QK_K = 256
K_SCALE_SIZE = 12


def to_uint32(x):
    """Combine the first four byte columns of `x` (little-endian) into one int32 column of shape `(rows, 1)`."""
    x = x.view(torch.uint8).to(torch.int32)
    return (x[:, 0] | x[:, 1] << 8 | x[:, 2] << 16 | x[:, 3] << 24).unsqueeze(1)


def split_block_dims(blocks, *args):
    """
    Split `blocks` along dim 1 into chunks of the widths given in `args`, plus one final chunk with the remainder.
    """
    n_max = blocks.shape[1]
    dims = list(args) + [n_max - sum(args)]
    return torch.split(blocks, dims, dim=1)


def get_scale_min(scales):
    """
    Unpack the packed scale bytes used by the Q4_K / Q5_K dequantizers.

    `scales` is viewed as uint8 and reshaped to `(n_blocks, 3, 4)`; the 6-bit values are reassembled from the three
    4-byte groups.

    Returns:
        Tuple `(sc, mn)`, each of shape `(n_blocks, 8)`.
    """
    n_blocks = scales.shape[0]
    scales = scales.view(torch.uint8)
    scales = scales.reshape((n_blocks, 3, 4))

    d, m, m_d = torch.split(scales, scales.shape[-2] // 3, dim=-2)

    sc = torch.cat([d & 0x3F, (m_d & 0x0F) | ((d >> 2) & 0x30)], dim=-1)
    # Named `mn` (not `min`) to avoid shadowing the builtin.
    mn = torch.cat([m & 0x3F, (m_d >> 4) | ((m >> 2) & 0x30)], dim=-1)

    return (sc.reshape((n_blocks, 8)), mn.reshape((n_blocks, 8)))


def dequantize_blocks_Q8_0(blocks, block_size, type_size, dtype=None):
    """Dequantize Q8_0 blocks: a 2-byte float16 scale followed by int8 values; returns `scale * values`."""
    d, x = split_block_dims(blocks, 2)
    d = d.view(torch.float16).to(dtype)
    x = x.view(torch.int8)
    return d * x


def dequantize_blocks_Q5_1(blocks, block_size, type_size, dtype=None):
    """
    Dequantize Q5_1 blocks: float16 scale, float16 offset, 4 bytes of high bits, then packed low nibbles. Returns
    `scale * q + offset`.
    """
    n_blocks = blocks.shape[0]

    d, m, qh, qs = split_block_dims(blocks, 2, 2, 4)
    d = d.view(torch.float16).to(dtype)
    m = m.view(torch.float16).to(dtype)
    qh = to_uint32(qh)

    # Spread the 32 high bits over 32 columns and the two nibbles of every byte over a new axis.
    qh = qh.reshape((n_blocks, 1)) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
    ql = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
        [0, 4], device=d.device, dtype=torch.uint8
    ).reshape(1, 1, 2, 1)
    qh = (qh & 1).to(torch.uint8)
    ql = (ql & 0x0F).reshape((n_blocks, -1))

    qs = ql | (qh << 4)
    return (d * qs) + m


def dequantize_blocks_Q5_0(blocks, block_size, type_size, dtype=None):
    """
    Dequantize Q5_0 blocks: float16 scale, 4 bytes of high bits, then packed low nibbles. Returns
    `scale * (q - 16)`.
    """
    n_blocks = blocks.shape[0]

    d, qh, qs = split_block_dims(blocks, 2, 4)
    d = d.view(torch.float16).to(dtype)
    qh = to_uint32(qh)

    qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
    ql = qs.reshape(n_blocks, -1, 1, block_size // 2) >> torch.tensor(
        [0, 4], device=d.device, dtype=torch.uint8
    ).reshape(1, 1, 2, 1)

    qh = (qh & 1).to(torch.uint8)
    ql = (ql & 0x0F).reshape(n_blocks, -1)

    qs = (ql | (qh << 4)).to(torch.int8) - 16
    return d * qs


def dequantize_blocks_Q4_1(blocks, block_size, type_size, dtype=None):
    """Dequantize Q4_1 blocks: float16 scale, float16 offset, packed nibbles. Returns `scale * q + offset`."""
    n_blocks = blocks.shape[0]

    d, m, qs = split_block_dims(blocks, 2, 2)
    d = d.view(torch.float16).to(dtype)
    m = m.view(torch.float16).to(dtype)

    qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
        [0, 4], device=d.device, dtype=torch.uint8
    ).reshape(1, 1, 2, 1)
    qs = (qs & 0x0F).reshape(n_blocks, -1)

    return (d * qs) + m


def dequantize_blocks_Q4_0(blocks, block_size, type_size, dtype=None):
    """Dequantize Q4_0 blocks: float16 scale followed by packed nibbles. Returns `scale * (q - 8)`."""
    n_blocks = blocks.shape[0]

    d, qs = split_block_dims(blocks, 2)
    d = d.view(torch.float16).to(dtype)

    qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
        [0, 4], device=d.device, dtype=torch.uint8
    ).reshape((1, 1, 2, 1))
    qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8

    return d * qs


def dequantize_blocks_Q6_K(blocks, block_size, type_size, dtype=None):
    """
    Dequantize Q6_K blocks: low nibbles, 2-bit high parts, int8 sub-block scales and a float16 super-scale. Returns a
    tensor of shape `(n_blocks, QK_K)`.
    """
    n_blocks = blocks.shape[0]

    (
        ql,
        qh,
        scales,
        d,
    ) = split_block_dims(blocks, QK_K // 2, QK_K // 4, QK_K // 16)

    scales = scales.view(torch.int8).to(dtype)
    d = d.view(torch.float16).to(dtype)
    d = (d * scales).reshape((n_blocks, QK_K // 16, 1))

    ql = ql.reshape((n_blocks, -1, 1, 64)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 2, 1)
    )
    ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
    qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 4, 1)
    )
    qh = (qh & 0x03).reshape((n_blocks, -1, 32))
    q = (ql | (qh << 4)).to(torch.int8) - 32
    q = q.reshape((n_blocks, QK_K // 16, -1))

    return (d * q).reshape((n_blocks, QK_K))


def dequantize_blocks_Q5_K(blocks, block_size, type_size, dtype=None):
    """
    Dequantize Q5_K blocks: float16 `d` / `dmin`, packed scales (see `get_scale_min`), high bits and low nibbles.
    Returns a tensor of shape `(n_blocks, QK_K)`.
    """
    n_blocks = blocks.shape[0]

    d, dmin, scales, qh, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE, QK_K // 8)

    d = d.view(torch.float16).to(dtype)
    dmin = dmin.view(torch.float16).to(dtype)

    sc, m = get_scale_min(scales)

    d = (d * sc).reshape((n_blocks, -1, 1))
    dm = (dmin * m).reshape((n_blocks, -1, 1))

    ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 2, 1)
    )
    qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.arange(0, 8, device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 8, 1)
    )
    ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
    qh = (qh & 0x01).reshape((n_blocks, -1, 32))
    q = ql | (qh << 4)

    return (d * q - dm).reshape((n_blocks, QK_K))


def dequantize_blocks_Q4_K(blocks, block_size, type_size, dtype=None):
    """
    Dequantize Q4_K blocks: float16 `d` / `dmin`, packed scales (see `get_scale_min`) and packed nibbles. Returns a
    tensor of shape `(n_blocks, QK_K)`.
    """
    n_blocks = blocks.shape[0]

    d, dmin, scales, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE)
    d = d.view(torch.float16).to(dtype)
    dmin = dmin.view(torch.float16).to(dtype)

    sc, m = get_scale_min(scales)

    d = (d * sc).reshape((n_blocks, -1, 1))
    dm = (dmin * m).reshape((n_blocks, -1, 1))

    qs = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 2, 1)
    )
    qs = (qs & 0x0F).reshape((n_blocks, -1, 32))

    return (d * qs - dm).reshape((n_blocks, QK_K))


def dequantize_blocks_Q3_K(blocks, block_size, type_size, dtype=None):
    """
    Dequantize Q3_K blocks: high-bit mask, 2-bit low parts, 12 bytes of packed 6-bit scales and a float16
    super-scale. Returns a tensor of shape `(n_blocks, QK_K)`.
    """
    n_blocks = blocks.shape[0]

    hmask, qs, scales, d = split_block_dims(blocks, QK_K // 8, QK_K // 4, 12)
    d = d.view(torch.float16).to(dtype)

    # The 12 scale bytes hold 16 six-bit scales: 4 low bits in the first 8 bytes, 2 high bits in the last 4.
    lscales, hscales = scales[:, :8], scales[:, 8:]
    lscales = lscales.reshape((n_blocks, 1, 8)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
        (1, 2, 1)
    )
    lscales = lscales.reshape((n_blocks, 16))
    hscales = hscales.reshape((n_blocks, 1, 4)) >> torch.tensor(
        [0, 2, 4, 6], device=d.device, dtype=torch.uint8
    ).reshape((1, 4, 1))
    hscales = hscales.reshape((n_blocks, 16))
    scales = (lscales & 0x0F) | ((hscales & 0x03) << 4)
    scales = scales.to(torch.int8) - 32

    dl = (d * scales).reshape((n_blocks, 16, 1))

    ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 4, 1)
    )
    qh = hmask.reshape(n_blocks, -1, 1, 32) >> torch.arange(0, 8, device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 8, 1)
    )
    ql = ql.reshape((n_blocks, 16, QK_K // 16)) & 3
    # The mask bit is inverted: a cleared bit subtracts 4 from the 2-bit value.
    qh = (qh.reshape((n_blocks, 16, QK_K // 16)) & 1) ^ 1
    q = ql.to(torch.int8) - (qh << 2).to(torch.int8)

    return (dl * q).reshape((n_blocks, QK_K))


def dequantize_blocks_Q2_K(blocks, block_size, type_size, dtype=None):
    """
    Dequantize Q2_K blocks: per-sub-block scale/min nibbles, 2-bit values and float16 `d` / `dmin`. Returns a tensor
    with `n_blocks` rows.
    """
    n_blocks = blocks.shape[0]

    scales, qs, d, dmin = split_block_dims(blocks, QK_K // 16, QK_K // 4, 2)
    d = d.view(torch.float16).to(dtype)
    dmin = dmin.view(torch.float16).to(dtype)

    # (n_blocks, 16, 1)
    dl = (d * (scales & 0xF)).reshape((n_blocks, QK_K // 16, 1))
    ml = (dmin * (scales >> 4)).reshape((n_blocks, QK_K // 16, 1))

    shift = torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 1, 4, 1))

    qs = (qs.reshape((n_blocks, -1, 1, 32)) >> shift) & 3
    qs = qs.reshape((n_blocks, QK_K // 16, 16))
    qs = dl * qs - ml

    return qs.reshape((n_blocks, -1))


def dequantize_blocks_BF16(blocks, block_size, type_size, dtype=None):
    """Reinterpret 16-bit words as the upper half of float32 values (`dtype` is ignored)."""
    return (blocks.view(torch.int16).to(torch.int32) << 16).view(torch.float32)


GGML_QUANT_SIZES = gguf.GGML_QUANT_SIZES
# Maps each supported GGML quantization type to the function that dequantizes its blocks.
dequantize_functions = {
    gguf.GGMLQuantizationType.BF16: dequantize_blocks_BF16,
    gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0,
    gguf.GGMLQuantizationType.Q5_1: dequantize_blocks_Q5_1,
    gguf.GGMLQuantizationType.Q5_0: dequantize_blocks_Q5_0,
    gguf.GGMLQuantizationType.Q4_1: dequantize_blocks_Q4_1,
    gguf.GGMLQuantizationType.Q4_0: dequantize_blocks_Q4_0,
    gguf.GGMLQuantizationType.Q6_K: dequantize_blocks_Q6_K,
    gguf.GGMLQuantizationType.Q5_K: dequantize_blocks_Q5_K,
    gguf.GGMLQuantizationType.Q4_K: dequantize_blocks_Q4_K,
    gguf.GGMLQuantizationType.Q3_K: dequantize_blocks_Q3_K,
    gguf.GGMLQuantizationType.Q2_K: dequantize_blocks_Q2_K,
}
SUPPORTED_GGUF_QUANT_TYPES = list(dequantize_functions.keys())


def _quant_shape_from_byte_shape(shape, type_size, block_size):
    """Convert a byte-level shape into the element-level shape: the last dim holds `type_size`-byte blocks."""
    return (*shape[:-1], shape[-1] // type_size * block_size)


def dequantize_gguf_tensor(tensor):
    """
    Dequantize `tensor` according to its `quant_type` attribute.

    Tensors without a `quant_type` attribute are returned unchanged. Otherwise the raw bytes are regrouped into
    `(n_blocks, type_size)`, passed to the matching entry of `dequantize_functions`, and reshaped to the element-level
    shape.

    Raises:
        KeyError: If `tensor.quant_type` has no entry in `dequantize_functions`.
    """
    if not hasattr(tensor, "quant_type"):
        return tensor

    quant_type = tensor.quant_type
    dequant_fn = dequantize_functions[quant_type]

    block_size, type_size = GGML_QUANT_SIZES[quant_type]

    tensor = tensor.view(torch.uint8)
    shape = _quant_shape_from_byte_shape(tensor.shape, type_size, block_size)

    n_blocks = tensor.numel() // type_size
    blocks = tensor.reshape((n_blocks, type_size))

    dequant = dequant_fn(blocks, block_size, type_size)
    dequant = dequant.reshape(shape)

    # Strip the GGUFParameter wrapper so the result no longer advertises a quant_type.
    return dequant.as_tensor()


class GGUFParameter(torch.nn.Parameter):
    """`nn.Parameter` subclass that carries a `quant_type` attribute alongside the raw tensor data."""

    def __new__(cls, data, requires_grad=False, quant_type=None):
        data = data if data is not None else torch.empty(0)
        self = torch.Tensor._make_subclass(cls, data, requires_grad)
        self.quant_type = quant_type

        return self

    def as_tensor(self):
        """Return the same data as a plain `torch.Tensor` (no `quant_type`)."""
        return torch.Tensor._make_subclass(torch.Tensor, self, self.requires_grad)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        result = super().__torch_function__(func, types, args, kwargs)

        # When converting from original format checkpoints we often use splits, cats etc on tensors
        # this method ensures that the returned tensor type from those operations remains GGUFParameter
        # so that we preserve quant_type information
        quant_type = None
        for arg in args:
            # Sequence arguments (e.g. `torch.cat([a, b])`): take the quant_type of the first element, but only
            # when the sequence is non-empty and that element really is a GGUFParameter.
            if isinstance(arg, (list, tuple)) and arg and isinstance(arg[0], GGUFParameter):
                quant_type = arg[0].quant_type
                break
            if isinstance(arg, GGUFParameter):
                quant_type = arg.quant_type
                break
        if isinstance(result, torch.Tensor):
            return cls(result, quant_type=quant_type)
        # Handle tuples and lists
        elif isinstance(result, (tuple, list)):
            # Preserve the original type (tuple or list)
            wrapped = [cls(x, quant_type=quant_type) if isinstance(x, torch.Tensor) else x for x in result]
            return type(result)(wrapped)
        else:
            return result


class GGUFLinear(nn.Linear):
    """`nn.Linear` whose weight is dequantized on the fly and cast to `compute_dtype` in every forward pass."""

    def __init__(
        self,
        in_features,
        out_features,
        bias=False,
        compute_dtype=None,
        device=None,
    ) -> None:
        super().__init__(in_features, out_features, bias, device)
        self.compute_dtype = compute_dtype

    def forward(self, inputs):
        weight = dequantize_gguf_tensor(self.weight)
        weight = weight.to(self.compute_dtype)
        # `self.bias` is None when the layer was created with `bias=False` (the default here).
        bias = self.bias.to(self.compute_dtype) if self.bias is not None else None

        output = torch.nn.functional.linear(inputs, weight, bias)
        return output