# https://github.com/MinusZoneAI/ComfyUI-CogVideoX-MZ/blob/9616415220fd09388622f40f6609e4ed81f048a5/mz_gguf_loader.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class quantize_lazy_load:
    """Context manager that sets the default device to "meta" so modules are
    created without allocating real weight storage."""

    def __init__(self):
        self.device = None

    def __enter__(self):
        self.device = torch.device("meta")
        self.device.__enter__()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.device.__exit__(exc_type, exc_value, traceback)


def quantize_load_state_dict(model, state_dict, device="cpu", cublas_ops=False):
    if cublas_ops:
        try:
            from cublas_ops import cublas_half_matmul
            linear_ops = cublas_half_matmul
            setattr(model, "cublas_half_matmul", True)
            print("Using cublas_ops")
        except ImportError:
            print("Failed to load cublas_ops")
            raise ImportError(
                "Install cublas_ops (https://github.com/aredden/torch-cublas-hgemm) to use cublas_ops")
    else:
        linear_ops = F.linear
        setattr(model, "cublas_half_matmul", False)

    # Map each quantized module name to its qtype, so checkpoints that mix
    # Q4_0 and Q8_0 tensors are tagged correctly per layer.
    quant_keys = {}
    for key in state_dict.keys():
        if key.endswith(".Q4_0_qweight"):
            quant_keys[key.replace(".Q4_0_qweight", "")] = "Q4_0"
        elif key.endswith(".Q8_0_qweight"):
            quant_keys[key.replace(".Q8_0_qweight", "")] = "Q8_0"

    # Swap every matching nn.Linear for a WQLinear_GGUF holding the raw
    # quantized bytes; dequantization happens on the fly in forward().
    for name, module in model.named_modules():
        if name in quant_keys:
            # print(name)
            q_linear = WQLinear_GGUF.from_linear(
                linear=module,
                device=device,
                qtype=quant_keys[name],
                linear_ops=linear_ops,
            )
            set_op_by_name(model, name, q_linear)

    model.to_empty(device=device)
    model.load_state_dict(state_dict, strict=False)
    return model
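
# A minimal usage sketch (an assumption, not from the original file; "MyModel"
# and the checkpoint path are hypothetical, and the state dict is expected to
# hold ".Q4_0_qweight" / ".Q8_0_qweight" tensors in the format this loader
# consumes):
#
#   with quantize_lazy_load():
#       model = MyModel()                    # parameters created on "meta"
#   sd = torch.load("model_q4_0.pt", map_location="cpu")
#   model = quantize_load_state_dict(model, sd, device="cuda")

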
def set_op_by_name(layer, name, new_module):
    levels = name.split(".")
    if len(levels) > 1:
        mod_ = layer
        for l_idx in range(len(levels) - 1):
            if levels[l_idx].isdigit():
                mod_ = mod_[int(levels[l_idx])]
            else:
                mod_ = getattr(mod_, levels[l_idx])
        setattr(mod_, levels[-1], new_module)
    else:
        setattr(layer, name, new_module)
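
# Illustrative note (the module path below is hypothetical): set_op_by_name
# resolves a dotted name such as "blocks.3.attn.to_q" by walking getattr for
# attribute segments and integer indexing for digit segments (e.g. into an
# nn.ModuleList), then setattr replaces the final "to_q" child in place.

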
class WQLinear_GGUF(nn.Module):
    def __init__(
        self, in_features, out_features, bias, dev, qtype, linear_ops
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.qtype = qtype
        self.linear_ops = linear_ops

        # Raw quantized weight bytes, registered under e.g. "Q4_0_qweight" so
        # the buffer name lines up with the checkpoint keys.
        qweight_shape = quant_shape_to_byte_shape(
            (out_features, in_features), qtype
        )
        self.register_buffer(
            f"{qtype}_qweight",
            torch.zeros(
                qweight_shape,
                dtype=torch.uint8,
                device=dev,
            ),
        )

        if bias:
            self.register_buffer(
                "bias",
                torch.zeros(
                    (out_features,),
                    dtype=torch.float16,
                    device=dev,
                ),
            )
        else:
            self.bias = None

    @classmethod
    def from_linear(
        cls, linear,
        device="cpu",
        qtype="Q4_0",
        linear_ops=F.linear,
    ):
        q_linear = cls(
            linear.in_features,
            linear.out_features,
            linear.bias is not None,
            device,
            qtype=qtype,
            linear_ops=linear_ops,
        )
        return q_linear

    def extra_repr(self) -> str:
        # Report qtype here; this module has no w_bit/group_size attributes,
        # so referencing them would raise AttributeError.
        return "in_features={}, out_features={}, bias={}, qtype={}".format(
            self.in_features,
            self.out_features,
            self.bias is not None,
            self.qtype,
        )

    @torch.no_grad()
    def forward(self, x):
        # Dequantize the stored bytes to the activation dtype, then run a
        # regular linear through the selected backend.
        if self.qtype == "Q4_0":
            dequant = dequantize_blocks_Q4_0(self.Q4_0_qweight, x.dtype)
        elif self.qtype == "Q8_0":
            dequant = dequantize_blocks_Q8_0(self.Q8_0_qweight, x.dtype)
        else:
            raise ValueError(f"Unknown qtype: {self.qtype}")
        bias = self.bias.to(x.dtype) if self.bias is not None else None
        return self.linear_ops(x, dequant, bias=bias)
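
# Design note: the quantized bytes stay resident and are expanded to a dense
# weight on every forward call, trading extra compute for weight memory
# (roughly 4.5 bits/weight for Q4_0 and 8.5 bits/weight for Q8_0, versus 16
# bits for an fp16 weight).

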
def split_block_dims(blocks, *args):
    # Split the trailing byte dimension into the given widths, with whatever
    # remains as the final chunk.
    n_max = blocks.shape[1]
    dims = list(args) + [n_max - sum(args)]
    return torch.split(blocks, dims, dim=1)


def quant_shape_to_byte_shape(shape, qtype) -> tuple[int, ...]:
    # shape = shape[::-1]
    block_size, type_size = GGML_QUANT_SIZES[qtype]
    if shape[-1] % block_size != 0:
        raise ValueError(
            f"Quantized tensor row size ({shape[-1]}) is not a multiple of {qtype} block size ({block_size})")
    return (*shape[:-1], shape[-1] // block_size * type_size)


def quant_shape_from_byte_shape(shape, qtype) -> tuple[int, ...]:
    # shape = shape[::-1]
    block_size, type_size = GGML_QUANT_SIZES[qtype]
    if shape[-1] % type_size != 0:
        raise ValueError(
            f"Quantized tensor bytes per row ({shape[-1]}) is not a multiple of {qtype} type size ({type_size})")
    return (*shape[:-1], shape[-1] // type_size * block_size)


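# (block_size, type_size) per GGML quant type: a Q4_0 block packs 32 weights
# into 18 bytes (a float16 scale plus 16 bytes of 4-bit codes), and a Q8_0
# block packs 32 weights into 34 bytes (a float16 scale plus 32 int8 codes).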
GGML_QUANT_SIZES = {
    "Q4_0": (32, 2 + 16),
    "Q8_0": (32, 2 + 32),
}


def dequantize_blocks_Q4_0(data, dtype=torch.float16):
    block_size, type_size = GGML_QUANT_SIZES["Q4_0"]

    data = data.to(torch.uint8)
    shape = data.shape

    rows = data.reshape(
        (-1, data.shape[-1])
    ).view(torch.uint8)

    n_blocks = rows.numel() // type_size
    blocks = data.reshape((n_blocks, type_size))

    # Each 18-byte block is a float16 scale followed by 16 bytes of nibbles.
    d, qs = split_block_dims(blocks, 2)
    d = d.view(torch.float16)

    # Unpack low nibbles (elements 0..15), then high nibbles (elements 16..31),
    # and shift the unsigned 4-bit codes into the signed range [-8, 7].
    qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
        [0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
    qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8

    out = d * qs
    out = out.reshape(quant_shape_from_byte_shape(
        shape,
        qtype="Q4_0",
    )).to(dtype)
    return out
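
# Worked example (illustrative): a quant byte 0x3A under scale d decodes to two
# weights, d * (0x0A - 8) = 2d from the low nibble (element j) and
# d * (0x03 - 8) = -5d from the high nibble (element j + 16).

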
def dequantize_blocks_Q8_0(data, dtype=torch.float16):
    block_size, type_size = GGML_QUANT_SIZES["Q8_0"]

    data = data.to(torch.uint8)
    shape = data.shape

    rows = data.reshape(
        (-1, data.shape[-1])
    ).view(torch.uint8)

    n_blocks = rows.numel() // type_size
    blocks = data.reshape((n_blocks, type_size))

    # Each 34-byte block is a float16 scale followed by 32 signed int8 codes.
    d, qs = split_block_dims(blocks, 2)
    d = d.view(torch.float16).to(torch.float32)
    qs = qs.view(torch.int8).to(torch.float32)

    out = d * qs
    out = out.reshape(quant_shape_from_byte_shape(
        shape,
        qtype="Q8_0",
    )).to(dtype)
    return out
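

# A minimal round-trip sketch (an assumption, not part of the original file):
# hand-build one Q8_0 block and check that dequantize_blocks_Q8_0 recovers
# d * q for it.
if __name__ == "__main__":
    d = torch.tensor([0.5], dtype=torch.float16)   # per-block fp16 scale
    q = torch.arange(-16, 16, dtype=torch.int8)    # 32 int8 codes
    block = torch.cat([d.view(torch.uint8), q.view(torch.uint8)]).reshape(1, 34)
    out = dequantize_blocks_Q8_0(block, dtype=torch.float32)
    expected = (d.to(torch.float32) * q.to(torch.float32)).reshape(1, 32)
    assert torch.allclose(out, expected)
    print("Q8_0 dequantize round-trip OK")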