# (scraper metadata removed: file size 1,704 bytes, commit 4d061f7)
from torch import nn

# NOTE(review): weight_quant / activation_quant are used by
# final_quantization below; assumed to live in .quantization alongside
# BitLinear — confirm the names exist there.
from .quantization import BitLinear, activation_quant, weight_quant
def replace_linears_in_hf(
    model, name_skip = 'lm_head'
):
    """
    Recursively replace every nn.Linear in *model* with a BitLinear layer.

    The replacement preserves each layer's in/out features and bias flag.
    The model is modified in place.

    Args:
        model (nn.Module): The model to modify.
        name_skip (str): Child-module name to leave untouched at any depth
            (default 'lm_head', the output head of HF language models).

    Returns:
        None
    """
    for name, module in model.named_children():
        if isinstance(module, nn.Linear) and name != name_skip:
            # Swap in a BitLinear with matching in_features/out_features
            # and the same bias configuration.
            setattr(
                model,
                name,
                BitLinear(
                    in_features=module.in_features,
                    out_features=module.out_features,
                    bias=module.bias is not None,
                ),
            )
        else:
            # Recursively apply to child modules. Bug fix: propagate
            # name_skip — the original recursion dropped it, so a
            # caller-supplied skip name was ignored below depth 1.
            replace_linears_in_hf(module, name_skip=name_skip)
# (removed: commented-out earlier draft of final_quantization that
# called module.quantization(); superseded by the version below)
def final_quantization(model):
    """
    Apply a final, in-place quantization pass to every BitLinear layer.

    Walks the module tree; each BitLinear's weight tensor is replaced by
    weight_quant(weight), and its bias (if present) by
    activation_quant(bias, input_bits).

    Args:
        model (nn.Module): The model to quantize in place.

    Returns:
        None
    """
    for module in model.children():
        if isinstance(module, BitLinear):
            # Quantize the module's weights and bias directly in place.
            module.weight.data = weight_quant(module.weight.data)
            if module.bias is not None:
                # NOTE(review): bias is quantized with the activation
                # quantizer at module.input_bits — confirm this matches
                # BitLinear's training-time scheme.
                module.bias.data = activation_quant(module.bias.data, module.input_bits)
        else:
            # Recursively apply to child modules.
            final_quantization(module)
|