# BitNet3-8B-Converted / linear_to_bitlinear.py
# (Uploaded to the Hugging Face Hub by ejbejaranos via huggingface_hub,
# commit 4d061f7, verified.)
from torch import nn

from .quantization import BitLinear, activation_quant, weight_quant
def replace_linears_in_hf(
    model, name_skip = 'lm_head'
):
    """
    Recursively replace every ``nn.Linear`` in *model* with a ``BitLinear``
    of the same shape, in place.

    Args:
        model (nn.Module): The model to modify.
        name_skip (str): Child-attribute name to leave untouched at every
            depth of the tree (default ``'lm_head'``, the usual output
            projection of HF causal-LM models).

    Returns:
        None

    Note:
        The replacement ``BitLinear`` is freshly initialized; the original
        layer's weights are NOT copied over.
    """
    for name, module in model.named_children():
        if isinstance(module, nn.Linear) and name != name_skip:
            # Swap in a BitLinear with matching in/out features and bias.
            setattr(
                model,
                name,
                BitLinear(
                    in_features=module.in_features,
                    out_features=module.out_features,
                    bias=module.bias is not None,
                ),
            )
        else:
            # Recurse into child modules, forwarding name_skip so a
            # caller-supplied skip name is honored at every depth
            # (previously it was dropped and the default re-applied).
            replace_linears_in_hf(module, name_skip=name_skip)
# def final_quantization(
# model,
# ):
# for name, module in model.named_children():
# if isinstance(module, BitLinear):
# module.quantization()
# else:
# # Recursively apply to child modules
# final_quantization(module)
def final_quantization(model):
    """Permanently quantize every ``BitLinear`` layer in *model*, in place.

    Walks the module tree; for each ``BitLinear`` found, overwrites its
    weight tensor (and bias tensor, if present) with quantized values.
    Any non-``BitLinear`` child is recursed into.

    Args:
        model (nn.Module): The model whose BitLinear layers to quantize.

    Returns:
        None
    """
    for name, module in model.named_children():
        if isinstance(module, BitLinear):
            # Quantize the module's weights and biases directly, mutating
            # the parameter tensors in place via ``.data``.
            module.weight.data = weight_quant(module.weight.data)
            if module.bias is not None:
                # NOTE(review): the bias is quantized with the *activation*
                # quantizer at module.input_bits — confirm this matches
                # BitLinear's forward-pass convention in .quantization.
                module.bias.data = activation_quant(module.bias.data, module.input_bits)
        else:
            # Recursively apply to child modules.
            final_quantization(module)