from torch import nn

# weight_quant and activation_quant are assumed to be defined alongside BitLinear
from .quantization import BitLinear, weight_quant, activation_quant


def replace_linears_in_hf(model, name_skip='lm_head'):
    """
    Replaces all instances of nn.Linear in the given model with BitLinear.

    Args:
        model (nn.Module): The model to modify in place.
        name_skip (str): Name of a child module to leave untouched
            (by default the output projection, 'lm_head').

    Returns:
        None
    """
    for name, module in model.named_children():
        if isinstance(module, nn.Linear) and name != name_skip:
            # Replace the nn.Linear with a BitLinear matching its in_features and
            # out_features; note the new layer is freshly initialized, so the
            # pretrained weights are not carried over
            setattr(
                model,
                name,
                BitLinear(
                    in_features=module.in_features,
                    out_features=module.out_features,
                    bias=module.bias is not None,
                ),
            )
        else:
            # Recursively apply to child modules, forwarding name_skip so the
            # skip rule also holds in nested submodules
            replace_linears_in_hf(module, name_skip=name_skip)

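
# A minimal usage sketch (hypothetical model name; assumes the `transformers`
# package is installed):
#
#     from transformers import AutoModelForCausalLM
#
#     model = AutoModelForCausalLM.from_pretrained('gpt2')
#     replace_linears_in_hf(model)  # every nn.Linear except 'lm_head' is now BitLinear
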

def final_quantization(model):
    """
    Quantizes every BitLinear layer's parameters in place, e.g. as a final
    pass once training is finished.
    """
    for _, module in model.named_children():
        if isinstance(module, BitLinear):
            # Quantize the module's weights and biases directly
            module.weight.data = weight_quant(module.weight.data)
            if module.bias is not None:
                module.bias.data = activation_quant(module.bias.data, module.input_bits)
        else:
            # Recursively apply to child modules
            final_quantization(module)
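
# End-to-end sketch (same assumptions as above; training loop elided):
#
#     model = AutoModelForCausalLM.from_pretrained('gpt2')
#     replace_linears_in_hf(model)
#     # ... train or fine-tune the BitLinear model ...
#     final_quantization(model)  # freeze parameters into their quantized form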