File size: 2,446 Bytes
d5175d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from operator import attrgetter
import torch.distributed as dist
import torch.nn as nn
from ..pq.utils import attrsetter, get_layers
from .modules import ActivationQuantizer, IntConv2d, IntEmbedding, IntLinear
# Lookup table from each supported floating-point module type to the
# scalar-quantized class that replaces it.
MAPPING = {
    nn.Linear: IntLinear,
    nn.Embedding: IntEmbedding,
    nn.Conv2d: IntConv2d,
}
def quantize_model_(model, p=0.2, bits=8, update_step=3000):
    """
    Replaces all supported modules with their scalar quantized counterpart and
    registers hooks to quantize the post-activations of those modules.

    The model is modified in place. A module is supported when its type, or
    one of its base classes, is a key of MAPPING; the most specific mapped
    base class (in method-resolution order) decides the replacement.

    Args:
        - model: a nn.Module
        - p: amount of noise (0 for no noise, 1 to quantize all the weights/activations)
        - bits: number of bits
        - update_step: update quantization parameters every update_step steps

    Returns:
        The list of layer names produced by get_layers(model, "(.*?)").
        NOTE: this includes the names of layers that were skipped because
        they are not supported; it is kept that way for backward
        compatibility with existing callers.
    """
    # book-keeping: only the master process (or a non-distributed run) logs.
    # This cannot change while iterating, so it is computed once.
    is_master_process = not dist.is_initialized() or dist.get_rank() == 0

    # quantize all layers
    quantized_layers = get_layers(model, "(.*?)")

    for layer in quantized_layers:
        # recover module
        module = attrgetter(layer)(model)
        if is_master_process:
            logging.info(
                f"Quantizing layer {layer} with bits={bits} and QuantNoise={p}"
            )

        # Resolve the quantized class through the MRO rather than by exact
        # class, so that subclasses of a supported module do not raise a
        # KeyError on the MAPPING lookup.
        QuantizedModule = next(
            (MAPPING[cls] for cls in type(module).__mro__ if cls in MAPPING),
            None,
        )
        if QuantizedModule is None:
            if is_master_process:
                logging.info(f"Module {module} not yet supported for quantization")
            continue

        # quantization params
        q_params = {
            "p": p,
            "update_step": update_step,
            "bits": bits,
            "method": "histogram",
            "counter": 0,
        }

        # instantiate the quantized counterpart without running __init__:
        # its state is the original module's attributes plus the
        # quantization params. A shallow copy is used so the original
        # module's __dict__ is not polluted with the quantization keys.
        quantized_module = QuantizedModule.__new__(QuantizedModule)
        params = dict(module.__dict__)
        params.update(q_params)
        quantized_module.__dict__.update(params)

        # activation quantization: the instance is not kept; it is created
        # for its side effect on quantized_module (presumably the hook
        # registration mentioned in the docstring -- ActivationQuantizer is
        # defined in .modules).
        ActivationQuantizer(quantized_module, p=0, bits=bits, method="histogram")

        # replace layer by its quantized counterpart
        attrsetter(layer)(model, quantized_module)

    # return the names of all layers that were considered
    return quantized_layers
|