|
|
|
|
|
|
|
import types |
|
import logging |
|
import torch |
|
from . import default_layers |
|
|
|
_logger = logging.getLogger(__name__) |
|
|
|
|
|
class LayerInfo: |
|
def __init__(self, name, module): |
|
self.module = module |
|
self.name = name |
|
self.type = type(module).__name__ |
|
|
|
def _setattr(model, name, module): |
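    """Replace a (possibly nested) submodule, e.g. ``layer1.0.conv1``,
    by walking down to its parent and setting the final attribute."""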
|
name_list = name.split(".") |
|
for name in name_list[:-1]: |
|
model = getattr(model, name) |
|
setattr(model, name_list[-1], module) |
|
|
|
|
|
class Compressor: |
|
""" |
|
Abstract base PyTorch compressor |
|
""" |
|
|
|
def __init__(self, model, config_list, optimizer=None): |
|
""" |
|
Record necessary info in class members |
|
|
|
Parameters |
|
---------- |
|
model : pytorch model |
|
the model user wants to compress |
|
config_list : list |
|
the configurations that users specify for compression |
|
        optimizer : pytorch optimizer
|
optimizer used to train the model |
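
        Examples
        --------
        An illustrative `config_list` (the exact keys depend on the concrete
        subclass)::

            config_list = [
                {'sparsity': 0.8, 'op_types': ['default']},
                {'exclude': True, 'op_names': ['fc']}
            ]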
|
""" |
|
assert isinstance(model, torch.nn.Module) |
|
self.validate_config(model, config_list) |
|
|
|
self.bound_model = model |
|
self.config_list = config_list |
|
self.optimizer = optimizer |
|
|
|
self.modules_to_compress = None |
|
self.modules_wrapper = [] |
|
self.is_wrapped = False |
|
|
|
self._fwd_hook_handles = {} |
|
self._fwd_hook_id = 0 |
|
|
|
self.reset() |
|
|
|
if not self.modules_wrapper: |
|
_logger.warning('Nothing is configured to compress, please check your model and config_list') |
|
|
|
def validate_config(self, model, config_list): |
|
""" |
|
        Subclasses can optionally override this method to check whether `config_list` is valid.
|
""" |
|
pass |
|
|
|
def reset(self, checkpoint=None): |
|
""" |
|
        Reset the model state dict (optionally loading from `checkpoint`) and rebuild the module wrappers.
|
""" |
|
self._unwrap_model() |
|
if checkpoint is not None: |
|
self.bound_model.load_state_dict(checkpoint) |
|
|
|
self.modules_to_compress = None |
|
self.modules_wrapper = [] |
|
|
|
for layer, config in self._detect_modules_to_compress(): |
|
wrapper = self._wrap_modules(layer, config) |
|
self.modules_wrapper.append(wrapper) |
|
|
|
self._wrap_model() |
|
|
|
def _detect_modules_to_compress(self): |
|
""" |
|
        Detect all modules that should be compressed, and save the result in `self.modules_to_compress`.
        The model will be instrumented; users should never edit it after calling this method.
|
""" |
|
if self.modules_to_compress is None: |
|
self.modules_to_compress = [] |
|
for name, module in self.bound_model.named_modules(): |
|
                if module is self.bound_model:
|
continue |
|
layer = LayerInfo(name, module) |
|
config = self.select_config(layer) |
|
if config is not None: |
|
self.modules_to_compress.append((layer, config)) |
|
return self.modules_to_compress |
|
|
|
def _wrap_model(self): |
|
""" |
|
        Wrap all modules that need to be compressed.
        """
|
for wrapper in reversed(self.get_modules_wrapper()): |
|
_setattr(self.bound_model, wrapper.name, wrapper) |
|
self.is_wrapped = True |
|
|
|
def _unwrap_model(self): |
|
""" |
|
        Unwrap all modules that were wrapped for compression.
        """
|
for wrapper in self.get_modules_wrapper(): |
|
_setattr(self.bound_model, wrapper.name, wrapper.module) |
|
self.is_wrapped = False |
|
|
|
def compress(self): |
|
""" |
|
        Compress the model with the algorithm implemented by the subclass.

        The model will be instrumented; users should never edit it after calling this method.
        `self.modules_to_compress` records all the to-be-compressed layers.
|
|
|
Returns |
|
------- |
|
torch.nn.Module |
|
model with specified modules compressed. |
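
        Examples
        --------
        Typical usage, with ``MyPruner`` standing in for any concrete subclass::

            pruner = MyPruner(model, config_list)
            model = pruner.compress()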
|
""" |
|
return self.bound_model |
|
|
|
def set_wrappers_attribute(self, name, value): |
|
""" |
|
        Register an attribute to be used in the wrapped modules' forward method.
        If `value` is a `torch.Tensor`, it is registered as a buffer in each wrapper,
        and will therefore be saved by `model.state_dict()`. Otherwise, it is set as a regular attribute of the wrapper.
|
|
|
Parameters |
|
---------- |
|
name : str |
|
name of the variable |
|
value: any |
|
value of the variable |
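
        Examples
        --------
        A sketch registering a per-wrapper flag (the name ``if_calculated`` is
        illustrative); tensor values become buffers saved in ``state_dict``::

            self.set_wrappers_attribute('if_calculated', torch.tensor(False))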
|
""" |
|
for wrapper in self.get_modules_wrapper(): |
|
if isinstance(value, torch.Tensor): |
|
wrapper.register_buffer(name, value.clone()) |
|
else: |
|
setattr(wrapper, name, value) |
|
|
|
def get_modules_to_compress(self): |
|
""" |
|
To obtain all the to-be-compressed modules. |
|
|
|
Returns |
|
------- |
|
list |
|
            a list of the layers, each of which is a tuple (`layer`, `config`)
            where `layer` is a `LayerInfo` and `config` is a `dict`
|
""" |
|
return self.modules_to_compress |
|
|
|
def get_modules_wrapper(self): |
|
""" |
|
To obtain all the wrapped modules. |
|
|
|
Returns |
|
------- |
|
list |
|
a list of the wrapped modules |
|
""" |
|
return self.modules_wrapper |
|
|
|
def select_config(self, layer): |
|
""" |
|
Find the configuration for `layer` by parsing `self.config_list` |
|
|
|
Parameters |
|
---------- |
|
layer : LayerInfo |
|
one layer |
|
|
|
Returns |
|
------- |
|
        dict or None
            the retrieved configuration for this layer; if None, this layer should
            not be compressed
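
        Examples
        --------
        With ``config_list = [{'sparsity': 0.5, 'op_types': ['default']},
        {'exclude': True, 'op_names': ['fc']}]``, a ``Conv2d`` layer matches the
        first entry (``'default'`` expands to all weighted module types), while
        the layer named ``fc`` matches the second entry and returns None.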
|
""" |
|
ret = None |
|
for config in self.config_list: |
|
config = config.copy() |
|
|
|
if 'op_types' in config and 'default' in config['op_types']: |
|
expanded_op_types = [] |
|
for op_type in config['op_types']: |
|
if op_type == 'default': |
|
expanded_op_types.extend(default_layers.weighted_modules) |
|
else: |
|
expanded_op_types.append(op_type) |
|
config['op_types'] = expanded_op_types |
|
|
|
|
|
if 'op_types' in config and layer.type not in config['op_types']: |
|
continue |
|
if 'op_names' in config and layer.name not in config['op_names']: |
|
continue |
|
|
|
ret = config |
|
if ret is None or 'exclude' in ret: |
|
return None |
|
return ret |
|
|
|
def update_epoch(self, epoch): |
|
""" |
|
        Subclasses can override this method to update the model every epoch.
        It should be called at the beginning of each epoch.
|
|
|
Parameters |
|
---------- |
|
        epoch : int
|
the current epoch number |
|
""" |
|
pass |
|
|
|
def _wrap_modules(self, layer, config): |
|
""" |
|
This method is implemented in the subclasses, i.e., `Pruner` and `Quantizer` |
|
|
|
Parameters |
|
---------- |
|
layer : LayerInfo |
|
the layer to instrument the compression operation |
|
config : dict |
|
the configuration for compressing this layer |
|
""" |
|
raise NotImplementedError() |
|
|
|
|
|
def add_activation_collector(self, collector): |
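        """
        Register `collector` as a forward hook on every wrapped module and return
        an id that can later be passed to :meth:`remove_activation_collector`.

        Examples
        --------
        An illustrative collector that records module outputs::

            activations = []
            hook_id = compressor.add_activation_collector(
                lambda module, inp, out: activations.append(out.detach()))
            # ... run forward passes, then detach the hooks ...
            compressor.remove_activation_collector(hook_id)
        """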
|
self._fwd_hook_id += 1 |
|
self._fwd_hook_handles[self._fwd_hook_id] = [] |
|
for wrapper in self.get_modules_wrapper(): |
|
handle = wrapper.register_forward_hook(collector) |
|
self._fwd_hook_handles[self._fwd_hook_id].append(handle) |
|
return self._fwd_hook_id |
|
|
|
def remove_activation_collector(self, fwd_hook_id): |
|
if fwd_hook_id not in self._fwd_hook_handles: |
|
raise ValueError("%s is not a valid collector id" % str(fwd_hook_id)) |
|
for handle in self._fwd_hook_handles[fwd_hook_id]: |
|
handle.remove() |
|
del self._fwd_hook_handles[fwd_hook_id] |
|
|
|
def patch_optimizer(self, *tasks): |
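        """
        Patch `self.optimizer.step()` so that every callable in `tasks` runs
        after each optimizer step (used, e.g., to refresh masks during training).
        """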
|
def patch_step(old_step): |
|
def new_step(_, *args, **kwargs): |
|
|
|
output = old_step(*args, **kwargs) |
|
|
|
for task in tasks: |
|
task() |
|
return output |
|
return new_step |
|
if self.optimizer is not None: |
|
self.optimizer.step = types.MethodType(patch_step(self.optimizer.step), self.optimizer) |
|
|
|
class PrunerModuleWrapper(torch.nn.Module): |
|
def __init__(self, module, module_name, module_type, config, pruner): |
|
""" |
|
        Wrap a module to enable data parallelism, forward method customization and buffer registration.
|
|
|
Parameters |
|
---------- |
|
module : pytorch module |
|
the module user wants to compress |
|
config : dict |
|
the configurations that users specify for compression |
|
module_name : str |
|
            the name of the module to compress; the wrapper module shares the same name
|
module_type : str |
|
the type of the module to compress |
|
pruner : Pruner |
|
the pruner used to calculate mask |
|
""" |
|
super().__init__() |
|
|
|
self.module = module |
|
self.name = module_name |
|
self.type = module_type |
|
|
|
self.config = config |
|
self.pruner = pruner |
|
|
|
|
|
self.register_buffer("weight_mask", torch.ones(self.module.weight.shape)) |
|
if hasattr(self.module, 'bias') and self.module.bias is not None: |
|
self.register_buffer("bias_mask", torch.ones(self.module.bias.shape)) |
|
else: |
|
self.register_buffer("bias_mask", None) |
|
|
|
def forward(self, *inputs): |
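        # Apply the masks (element-wise multiply, in place) before delegating to
        # the wrapped module; for the usual 0/1 masks this is idempotent.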
|
|
|
self.module.weight.data = self.module.weight.data.mul_(self.weight_mask) |
|
if hasattr(self.module, 'bias') and self.module.bias is not None: |
|
self.module.bias.data = self.module.bias.data.mul_(self.bias_mask) |
|
return self.module(*inputs) |
|
|
|
class Pruner(Compressor): |
|
""" |
|
    Abstract base class for pruners.

    Masks are stored as buffers (`weight_mask` / `bias_mask`) on each
    `PrunerModuleWrapper`; each mask has the same shape as the corresponding
    weight or bias tensor of the wrapped layer.
    """
|
|
|
def __init__(self, model, config_list, optimizer=None): |
|
super().__init__(model, config_list, optimizer) |
|
        if optimizer is not None:
            # recompute masks after every optimizer step
            self.patch_optimizer(self.update_mask)
|
|
|
def compress(self): |
|
self.update_mask() |
|
return self.bound_model |
|
|
|
def update_mask(self): |
|
for wrapper_idx, wrapper in enumerate(self.get_modules_wrapper()): |
|
masks = self.calc_mask(wrapper, wrapper_idx=wrapper_idx) |
|
if masks is not None: |
|
for k in masks: |
|
assert hasattr(wrapper, k), "there is no attribute '%s' in wrapper" % k |
|
setattr(wrapper, k, masks[k]) |
|
|
|
def calc_mask(self, wrapper, **kwargs): |
|
""" |
|
        Pruners should overload this method to provide masks for weight tensors.
        The mask must have the same shape and type as the weight.
        It will be applied with a `mul()` operation on the weight.
        This method is effectively hooked to the `forward()` method of the model.
|
|
|
Parameters |
|
---------- |
|
wrapper : Module |
|
calculate mask for `wrapper.module`'s weight |
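
        Returns
        -------
        dict or None
            a dict mapping wrapper buffer names (e.g. ``weight_mask``) to new mask
            tensors, or None to leave the masks unchanged

        Examples
        --------
        A sketch of an override that prunes the smaller half of the weights by
        magnitude (illustrative only)::

            def calc_mask(self, wrapper, **kwargs):
                weight = wrapper.module.weight.data
                threshold = weight.abs().flatten().median()
                return {'weight_mask': (weight.abs() > threshold).type_as(weight)}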
|
""" |
|
raise NotImplementedError("Pruners must overload calc_mask()") |
|
|
|
def _wrap_modules(self, layer, config): |
|
""" |
|
Create a wrapper module to replace the original one. |
|
|
|
Parameters |
|
---------- |
|
layer : LayerInfo |
|
the layer to instrument the mask |
|
config : dict |
|
the configuration for generating the mask |
|
""" |
|
_logger.debug("Module detected to compress : %s.", layer.name) |
|
wrapper = PrunerModuleWrapper(layer.module, layer.name, layer.type, config, self) |
|
assert hasattr(layer.module, 'weight'), "module %s does not have 'weight' attribute" % layer.name |
|
|
|
wrapper.to(layer.module.weight.device) |
|
return wrapper |
|
|
|
def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=None, device=None): |
|
""" |
|
        Export pruned model weights, masks, and optionally an ONNX model.
|
|
|
Parameters |
|
---------- |
|
model_path : str |
|
path to save pruned model state_dict |
|
mask_path : str |
|
(optional) path to save mask dict |
|
onnx_path : str |
|
(optional) path to save onnx model |
|
input_shape : list or tuple |
|
input shape to onnx model |
|
device : torch.device |
|
            device of the model, used to place the dummy input tensor when exporting
            the ONNX file; the tensor is placed on CPU if `device` is None
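
        Examples
        --------
        A typical call (paths and shape are illustrative)::

            pruner.export_model('pruned.pth', mask_path='mask.pth',
                                onnx_path='model.onnx', input_shape=[1, 3, 224, 224])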
|
""" |
|
assert model_path is not None, 'model_path must be specified' |
|
mask_dict = {} |
|
self._unwrap_model() |
|
|
|
for wrapper in self.get_modules_wrapper(): |
|
weight_mask = wrapper.weight_mask |
|
bias_mask = wrapper.bias_mask |
|
if weight_mask is not None: |
|
mask_sum = weight_mask.sum().item() |
|
mask_num = weight_mask.numel() |
|
_logger.debug('Layer: %s Sparsity: %.4f', wrapper.name, 1 - mask_sum / mask_num) |
|
wrapper.module.weight.data = wrapper.module.weight.data.mul(weight_mask) |
|
if bias_mask is not None: |
|
wrapper.module.bias.data = wrapper.module.bias.data.mul(bias_mask) |
|
|
|
mask_dict[wrapper.name] = {"weight": weight_mask, "bias": bias_mask} |
|
|
|
torch.save(self.bound_model.state_dict(), model_path) |
|
_logger.info('Model state_dict saved to %s', model_path) |
|
if mask_path is not None: |
|
torch.save(mask_dict, mask_path) |
|
_logger.info('Mask dict saved to %s', mask_path) |
|
if onnx_path is not None: |
|
assert input_shape is not None, 'input_shape must be specified to export onnx model' |
|
|
|
if device is None: |
|
device = torch.device('cpu') |
|
            # dummy input for ONNX tracing; torch.Tensor(*input_shape) allocates an
            # uninitialized tensor, so only its shape (not its values) matters here
            input_data = torch.Tensor(*input_shape)
|
torch.onnx.export(self.bound_model, input_data.to(device), onnx_path) |
|
_logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path) |
|
|
|
self._wrap_model() |
|
|
|
def load_model_state_dict(self, model_state): |
|
""" |
|
Load the state dict saved from unwrapped model. |
|
|
|
        Parameters
        ----------
|
model_state : dict |
|
state dict saved from unwrapped model |
|
""" |
|
if self.is_wrapped: |
|
self._unwrap_model() |
|
self.bound_model.load_state_dict(model_state) |
|
self._wrap_model() |
|
else: |
|
self.bound_model.load_state_dict(model_state) |
|
|
|
class QuantizerModuleWrapper(torch.nn.Module): |
|
def __init__(self, module, module_name, module_type, config, quantizer): |
|
""" |
|
        Wrap a module to enable data parallelism, forward method customization and buffer registration.
|
|
|
Parameters |
|
---------- |
|
module : pytorch module |
|
the module user wants to compress |
|
config : dict |
|
the configurations that users specify for compression |
|
module_name : str |
|
            the name of the module to compress; the wrapper module shares the same name
|
module_type : str |
|
the type of the module to compress |
|
        quantizer : Quantizer
            the quantizer used to quantize the module
|
""" |
|
super().__init__() |
|
|
|
self.module = module |
|
self.name = module_name |
|
self.type = module_type |
|
|
|
self.config = config |
|
self.quantizer = quantizer |
|
|
|
|
|
|
|
|
|
|
|
        # Re-register 'weight': keep the original float parameter trainable as
        # 'old_weight' and expose 'weight' as a buffer that will hold quantized values.
        if 'weight' in config['quant_types']:
|
if not _check_weight(self.module): |
|
_logger.warning('Module %s does not have parameter "weight"', self.name) |
|
else: |
|
self.module.register_parameter('old_weight', torch.nn.Parameter(self.module.weight)) |
|
delattr(self.module, 'weight') |
|
self.module.register_buffer('weight', self.module.old_weight) |
|
|
|
def forward(self, *inputs): |
|
if 'input' in self.config['quant_types']: |
|
inputs = self.quantizer.quant_grad.apply( |
|
inputs, |
|
QuantType.QUANT_INPUT, |
|
self) |
|
|
|
if 'weight' in self.config['quant_types'] and _check_weight(self.module): |
|
self.quantizer.quant_grad.apply( |
|
self.module.old_weight, |
|
QuantType.QUANT_WEIGHT, |
|
self) |
|
result = self.module(*inputs) |
|
else: |
|
result = self.module(*inputs) |
|
|
|
if 'output' in self.config['quant_types']: |
|
result = self.quantizer.quant_grad.apply( |
|
result, |
|
QuantType.QUANT_OUTPUT, |
|
self) |
|
        return result
|
|
|
class Quantizer(Compressor): |
|
""" |
|
    Abstract base class for PyTorch quantizers.
|
""" |
|
|
|
def __init__(self, model, config_list, optimizer=None): |
|
super().__init__(model, config_list, optimizer) |
|
self.quant_grad = QuantGrad |
|
if self.optimizer is not None: |
|
self.patch_optimizer(self.step_with_optimizer) |
|
for wrapper in self.get_modules_wrapper(): |
|
                if 'weight' in wrapper.config['quant_types']:
                    # 'old_weight' replaced 'weight' as the trainable parameter,
                    # so it must be added to the optimizer explicitly
                    self.optimizer.add_param_group({"params": wrapper.module.old_weight})
|
|
|
def quantize_weight(self, weight, wrapper, **kwargs): |
|
""" |
|
        Quantizers should overload this method to quantize weight.
|
This method is effectively hooked to :meth:`forward` of the model. |
|
Parameters |
|
---------- |
|
weight : Tensor |
|
weight that needs to be quantized |
|
wrapper : QuantizerModuleWrapper |
|
the wrapper for origin module |
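
        Examples
        --------
        A naive symmetric 8-bit override (an illustrative sketch, not a production
        quantization scheme)::

            def quantize_weight(self, weight, wrapper, **kwargs):
                scale = weight.abs().max() / 127
                wrapper.module.weight = torch.round(weight / scale) * scale
                return wrapper.module.weight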
|
""" |
|
raise NotImplementedError('Quantizer must overload quantize_weight()') |
|
|
|
def quantize_output(self, output, wrapper, **kwargs): |
|
""" |
|
        Quantizers should overload this method to quantize output.
|
This method is effectively hooked to :meth:`forward` of the model. |
|
Parameters |
|
---------- |
|
output : Tensor |
|
output that needs to be quantized |
|
wrapper : QuantizerModuleWrapper |
|
the wrapper for origin module |
|
""" |
|
raise NotImplementedError('Quantizer must overload quantize_output()') |
|
|
|
def quantize_input(self, *inputs, wrapper, **kwargs): |
|
""" |
|
        Quantizers should overload this method to quantize input.
|
This method is effectively hooked to :meth:`forward` of the model. |
|
Parameters |
|
---------- |
|
inputs : Tensor |
|
            inputs that need to be quantized
|
wrapper : QuantizerModuleWrapper |
|
the wrapper for origin module |
|
""" |
|
raise NotImplementedError('Quantizer must overload quantize_input()') |
|
|
|
|
|
def _wrap_modules(self, layer, config): |
|
""" |
|
        Create a wrapper module to replace the original one.
|
Parameters |
|
---------- |
|
layer : LayerInfo |
|
the layer to instrument the mask |
|
config : dict |
|
the configuration for quantization |
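
        Examples
        --------
        An illustrative config entry accepted by this method::

            {'quant_types': ['weight', 'output'],
             'quant_bits': {'weight': 8, 'output': 8},
             'op_types': ['Conv2d', 'Linear']}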
|
""" |
|
assert 'quant_types' in config, 'must provide quant_types in config' |
|
assert isinstance(config['quant_types'], list), 'quant_types must be list type' |
|
assert 'quant_bits' in config, 'must provide quant_bits in config' |
|
        assert isinstance(config['quant_bits'], (int, dict)), 'quant_bits must be dict type or int type'
|
|
|
if isinstance(config['quant_bits'], dict): |
|
for quant_type in config['quant_types']: |
|
                assert quant_type in config['quant_bits'], 'bit length for %s must be specified in quant_bits dict' % quant_type
|
|
|
return QuantizerModuleWrapper(layer.module, layer.name, layer.type, config, self) |
|
|
|
def step_with_optimizer(self): |
|
pass |
|
|
|
class QuantType: |
|
""" |
|
    Enum-like class holding constants for the quantization type.
|
""" |
|
QUANT_INPUT = 0 |
|
QUANT_WEIGHT = 1 |
|
QUANT_OUTPUT = 2 |
|
|
|
|
|
class QuantGrad(torch.autograd.Function): |
|
""" |
|
Base class for overriding backward function of quantization operation. |
|
""" |
|
@staticmethod |
|
def quant_backward(tensor, grad_output, quant_type): |
|
""" |
|
        This method should be overridden by subclasses to provide a customized backward function;
        the default implementation is the Straight-Through Estimator (STE), which passes
        the gradient through unchanged.
|
Parameters |
|
---------- |
|
tensor : Tensor |
|
input of quantization operation |
|
grad_output : Tensor |
|
gradient of the output of quantization operation |
|
quant_type : QuantType |
|
the type of quantization, it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT`, `QuantType.QUANT_OUTPUT`, |
|
you can define different behavior for different types. |
|
Returns |
|
------- |
|
tensor |
|
gradient of the input of quantization operation |
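
        Examples
        --------
        An override that clips the straight-through gradient to the region where
        ``|tensor| <= 1`` (illustrative)::

            @staticmethod
            def quant_backward(tensor, grad_output, quant_type):
                return grad_output * (tensor.abs() <= 1).type_as(grad_output)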
|
""" |
|
return grad_output |
|
|
|
@staticmethod |
|
def forward(ctx, tensor, quant_type, wrapper, **kwargs): |
|
ctx.save_for_backward(tensor, torch.Tensor([quant_type])) |
|
if quant_type == QuantType.QUANT_INPUT: |
|
return wrapper.quantizer.quantize_input(tensor, wrapper, **kwargs) |
|
elif quant_type == QuantType.QUANT_WEIGHT: |
|
            return wrapper.quantizer.quantize_weight(tensor, wrapper, **kwargs)
|
elif quant_type == QuantType.QUANT_OUTPUT: |
|
return wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs) |
|
else: |
|
raise ValueError("unrecognized QuantType.") |
|
|
|
@classmethod |
|
def backward(cls, ctx, grad_output): |
|
        tensor, quant_type = ctx.saved_tensors
|
output = cls.quant_backward(tensor, grad_output, quant_type) |
|
        # forward takes (tensor, quant_type, wrapper), so backward returns three gradients
        return output, None, None
|
|
|
def _check_weight(module): |
|
try: |
|
return isinstance(module.weight.data, torch.Tensor) |
|
except AttributeError: |
|
return False |
|
|