|
from torch import nn

from .quantization import BitLinear, activation_quant, weight_quant
|
|
|
|
|
def replace_linears_in_hf(
    model, name_skip='lm_head'
):
    """
    Recursively replace every ``nn.Linear`` in ``model`` with a ``BitLinear``
    of matching dimensions, modifying the model in place.

    Args:
        model (nn.Module): The model (or sub-module) to modify in place.
        name_skip (str): Child-module name to leave untouched at every depth.
            Defaults to ``'lm_head'``, the conventional output projection in
            Hugging Face causal-LM models.

    Returns:
        None
    """
    for name, module in model.named_children():
        if isinstance(module, nn.Linear):
            if name == name_skip:
                # Leave the skipped projection (e.g. the LM head) at full
                # precision; nothing to recurse into either — nn.Linear is
                # a leaf module.
                continue
            # NOTE(review): the replacement BitLinear is freshly initialized
            # (original weights are not copied, and device/dtype are the
            # BitLinear defaults) — confirm this is intended for the
            # training-from-scratch use case.
            setattr(
                model,
                name,
                BitLinear(
                    in_features=module.in_features,
                    out_features=module.out_features,
                    bias=module.bias is not None,
                ),
            )
        else:
            # Bug fix: propagate name_skip into the recursion. Previously the
            # recursive call used the default, so a caller-supplied skip name
            # was silently ignored below the top level.
            replace_linears_in_hf(module, name_skip=name_skip)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def final_quantization(model):
    """
    Permanently quantize every ``BitLinear`` layer in ``model`` in place.

    Walks the module tree recursively and, for each ``BitLinear``, overwrites
    its weight tensor with ``weight_quant(...)`` and, when a bias exists, its
    bias tensor with ``activation_quant(...)``.

    Args:
        model (nn.Module): The model (or sub-module) to quantize in place.

    Returns:
        None
    """
    # Bug fix context: weight_quant / activation_quant were previously
    # undefined in this module (NameError on first BitLinear hit); they are
    # now imported from .quantization alongside BitLinear.
    for name, module in model.named_children():
        if isinstance(module, BitLinear):
            module.weight.data = weight_quant(module.weight.data)
            if module.bias is not None:
                # NOTE(review): the bias is quantized with the *activation*
                # quantizer at module.input_bits — confirm this matches the
                # BitLinear forward-pass convention.
                module.bias.data = activation_quant(module.bias.data, module.input_bits)
        else:
            # Non-BitLinear containers: descend into their children.
            final_quantization(module)
|
|