File size: 22,999 Bytes
b84549f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 |
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import types
import logging
import torch
from . import default_layers
_logger = logging.getLogger(__name__)
class LayerInfo:
def __init__(self, name, module):
self.module = module
self.name = name
self.type = type(module).__name__
def _setattr(model, name, module):
name_list = name.split(".")
for name in name_list[:-1]:
model = getattr(model, name)
setattr(model, name_list[-1], module)
class Compressor:
"""
Abstract base PyTorch compressor
"""
def __init__(self, model, config_list, optimizer=None):
"""
Record necessary info in class members
Parameters
----------
model : pytorch model
the model user wants to compress
config_list : list
the configurations that users specify for compression
optimizer: pytorch optimizer
optimizer used to train the model
"""
assert isinstance(model, torch.nn.Module)
self.validate_config(model, config_list)
self.bound_model = model
self.config_list = config_list
self.optimizer = optimizer
self.modules_to_compress = None
self.modules_wrapper = []
self.is_wrapped = False
self._fwd_hook_handles = {}
self._fwd_hook_id = 0
self.reset()
if not self.modules_wrapper:
_logger.warning('Nothing is configured to compress, please check your model and config_list')
def validate_config(self, model, config_list):
"""
subclass can optionally implement this method to check if config_list if valid
"""
pass
def reset(self, checkpoint=None):
"""
reset model state dict and model wrapper
"""
self._unwrap_model()
if checkpoint is not None:
self.bound_model.load_state_dict(checkpoint)
self.modules_to_compress = None
self.modules_wrapper = []
for layer, config in self._detect_modules_to_compress():
wrapper = self._wrap_modules(layer, config)
self.modules_wrapper.append(wrapper)
self._wrap_model()
def _detect_modules_to_compress(self):
"""
detect all modules should be compressed, and save the result in `self.modules_to_compress`.
The model will be instrumented and user should never edit it after calling this method.
"""
if self.modules_to_compress is None:
self.modules_to_compress = []
for name, module in self.bound_model.named_modules():
if module == self.bound_model:
continue
layer = LayerInfo(name, module)
config = self.select_config(layer)
if config is not None:
self.modules_to_compress.append((layer, config))
return self.modules_to_compress
def _wrap_model(self):
"""
wrap all modules that needed to be compressed
"""
for wrapper in reversed(self.get_modules_wrapper()):
_setattr(self.bound_model, wrapper.name, wrapper)
self.is_wrapped = True
def _unwrap_model(self):
"""
unwrap all modules that needed to be compressed
"""
for wrapper in self.get_modules_wrapper():
_setattr(self.bound_model, wrapper.name, wrapper.module)
self.is_wrapped = False
def compress(self):
"""
Compress the model with algorithm implemented by subclass.
The model will be instrumented and user should never edit it after calling this method.
`self.modules_to_compress` records all the to-be-compressed layers
Returns
-------
torch.nn.Module
model with specified modules compressed.
"""
return self.bound_model
def set_wrappers_attribute(self, name, value):
"""
To register attributes used in wrapped module's forward method.
If the type of the value is Torch.tensor, then this value is registered as a buffer in wrapper,
which will be saved by model.state_dict. Otherwise, this value is just a regular variable in wrapper.
Parameters
----------
name : str
name of the variable
value: any
value of the variable
"""
for wrapper in self.get_modules_wrapper():
if isinstance(value, torch.Tensor):
wrapper.register_buffer(name, value.clone())
else:
setattr(wrapper, name, value)
def get_modules_to_compress(self):
"""
To obtain all the to-be-compressed modules.
Returns
-------
list
a list of the layers, each of which is a tuple (`layer`, `config`),
`layer` is `LayerInfo`, `config` is a `dict`
"""
return self.modules_to_compress
def get_modules_wrapper(self):
"""
To obtain all the wrapped modules.
Returns
-------
list
a list of the wrapped modules
"""
return self.modules_wrapper
def select_config(self, layer):
"""
Find the configuration for `layer` by parsing `self.config_list`
Parameters
----------
layer : LayerInfo
one layer
Returns
-------
config or None
the retrieved configuration for this layer, if None, this layer should
not be compressed
"""
ret = None
for config in self.config_list:
config = config.copy()
# expand config if key `default` is in config['op_types']
if 'op_types' in config and 'default' in config['op_types']:
expanded_op_types = []
for op_type in config['op_types']:
if op_type == 'default':
expanded_op_types.extend(default_layers.weighted_modules)
else:
expanded_op_types.append(op_type)
config['op_types'] = expanded_op_types
# check if condition is satisified
if 'op_types' in config and layer.type not in config['op_types']:
continue
if 'op_names' in config and layer.name not in config['op_names']:
continue
ret = config
if ret is None or 'exclude' in ret:
return None
return ret
def update_epoch(self, epoch):
"""
If user want to update model every epoch, user can override this method.
This method should be called at the beginning of each epoch
Parameters
----------
epoch : num
the current epoch number
"""
pass
def _wrap_modules(self, layer, config):
"""
This method is implemented in the subclasses, i.e., `Pruner` and `Quantizer`
Parameters
----------
layer : LayerInfo
the layer to instrument the compression operation
config : dict
the configuration for compressing this layer
"""
raise NotImplementedError()
def add_activation_collector(self, collector):
self._fwd_hook_id += 1
self._fwd_hook_handles[self._fwd_hook_id] = []
for wrapper in self.get_modules_wrapper():
handle = wrapper.register_forward_hook(collector)
self._fwd_hook_handles[self._fwd_hook_id].append(handle)
return self._fwd_hook_id
def remove_activation_collector(self, fwd_hook_id):
if fwd_hook_id not in self._fwd_hook_handles:
raise ValueError("%s is not a valid collector id" % str(fwd_hook_id))
for handle in self._fwd_hook_handles[fwd_hook_id]:
handle.remove()
del self._fwd_hook_handles[fwd_hook_id]
def patch_optimizer(self, *tasks):
def patch_step(old_step):
def new_step(_, *args, **kwargs):
# call origin optimizer step method
output = old_step(*args, **kwargs)
# calculate mask
for task in tasks:
task()
return output
return new_step
if self.optimizer is not None:
self.optimizer.step = types.MethodType(patch_step(self.optimizer.step), self.optimizer)
class PrunerModuleWrapper(torch.nn.Module):
def __init__(self, module, module_name, module_type, config, pruner):
"""
Wrap an module to enable data parallel, forward method customization and buffer registeration.
Parameters
----------
module : pytorch module
the module user wants to compress
config : dict
the configurations that users specify for compression
module_name : str
the name of the module to compress, wrapper module shares same name
module_type : str
the type of the module to compress
pruner : Pruner
the pruner used to calculate mask
"""
super().__init__()
# origin layer information
self.module = module
self.name = module_name
self.type = module_type
# config and pruner
self.config = config
self.pruner = pruner
# register buffer for mask
self.register_buffer("weight_mask", torch.ones(self.module.weight.shape))
if hasattr(self.module, 'bias') and self.module.bias is not None:
self.register_buffer("bias_mask", torch.ones(self.module.bias.shape))
else:
self.register_buffer("bias_mask", None)
def forward(self, *inputs):
# apply mask to weight, bias
self.module.weight.data = self.module.weight.data.mul_(self.weight_mask)
if hasattr(self.module, 'bias') and self.module.bias is not None:
self.module.bias.data = self.module.bias.data.mul_(self.bias_mask)
return self.module(*inputs)
class Pruner(Compressor):
"""
Prune to an exact pruning level specification
Attributes
----------
mask_dict : dict
Dictionary for saving masks, `key` should be layer name and
`value` should be a tensor which has the same shape with layer's weight
"""
def __init__(self, model, config_list, optimizer=None):
super().__init__(model, config_list, optimizer)
if optimizer is not None:
self.patch_optimizer(self.update_mask)
def compress(self):
self.update_mask()
return self.bound_model
def update_mask(self):
for wrapper_idx, wrapper in enumerate(self.get_modules_wrapper()):
masks = self.calc_mask(wrapper, wrapper_idx=wrapper_idx)
if masks is not None:
for k in masks:
assert hasattr(wrapper, k), "there is no attribute '%s' in wrapper" % k
setattr(wrapper, k, masks[k])
def calc_mask(self, wrapper, **kwargs):
"""
Pruners should overload this method to provide mask for weight tensors.
The mask must have the same shape and type comparing to the weight.
It will be applied with `mul()` operation on the weight.
This method is effectively hooked to `forward()` method of the model.
Parameters
----------
wrapper : Module
calculate mask for `wrapper.module`'s weight
"""
raise NotImplementedError("Pruners must overload calc_mask()")
def _wrap_modules(self, layer, config):
"""
Create a wrapper module to replace the original one.
Parameters
----------
layer : LayerInfo
the layer to instrument the mask
config : dict
the configuration for generating the mask
"""
_logger.debug("Module detected to compress : %s.", layer.name)
wrapper = PrunerModuleWrapper(layer.module, layer.name, layer.type, config, self)
assert hasattr(layer.module, 'weight'), "module %s does not have 'weight' attribute" % layer.name
# move newly registered buffers to the same device of weight
wrapper.to(layer.module.weight.device)
return wrapper
def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=None, device=None):
"""
Export pruned model weights, masks and onnx model(optional)
Parameters
----------
model_path : str
path to save pruned model state_dict
mask_path : str
(optional) path to save mask dict
onnx_path : str
(optional) path to save onnx model
input_shape : list or tuple
input shape to onnx model
device : torch.device
device of the model, used to place the dummy input tensor for exporting onnx file.
the tensor is placed on cpu if ```device``` is None
"""
assert model_path is not None, 'model_path must be specified'
mask_dict = {}
self._unwrap_model() # used for generating correct state_dict name without wrapper state
for wrapper in self.get_modules_wrapper():
weight_mask = wrapper.weight_mask
bias_mask = wrapper.bias_mask
if weight_mask is not None:
mask_sum = weight_mask.sum().item()
mask_num = weight_mask.numel()
_logger.debug('Layer: %s Sparsity: %.4f', wrapper.name, 1 - mask_sum / mask_num)
wrapper.module.weight.data = wrapper.module.weight.data.mul(weight_mask)
if bias_mask is not None:
wrapper.module.bias.data = wrapper.module.bias.data.mul(bias_mask)
# save mask to dict
mask_dict[wrapper.name] = {"weight": weight_mask, "bias": bias_mask}
torch.save(self.bound_model.state_dict(), model_path)
_logger.info('Model state_dict saved to %s', model_path)
if mask_path is not None:
torch.save(mask_dict, mask_path)
_logger.info('Mask dict saved to %s', mask_path)
if onnx_path is not None:
assert input_shape is not None, 'input_shape must be specified to export onnx model'
# input info needed
if device is None:
device = torch.device('cpu')
input_data = torch.Tensor(*input_shape)
torch.onnx.export(self.bound_model, input_data.to(device), onnx_path)
_logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path)
self._wrap_model()
def load_model_state_dict(self, model_state):
"""
Load the state dict saved from unwrapped model.
Parameters:
-----------
model_state : dict
state dict saved from unwrapped model
"""
if self.is_wrapped:
self._unwrap_model()
self.bound_model.load_state_dict(model_state)
self._wrap_model()
else:
self.bound_model.load_state_dict(model_state)
class QuantizerModuleWrapper(torch.nn.Module):
def __init__(self, module, module_name, module_type, config, quantizer):
"""
Wrap an module to enable data parallel, forward method customization and buffer registeration.
Parameters
----------
module : pytorch module
the module user wants to compress
config : dict
the configurations that users specify for compression
module_name : str
the name of the module to compress, wrapper module shares same name
module_type : str
the type of the module to compress
quantizer :quantizer
the quantizer used to calculate mask
"""
super().__init__()
# origin layer information
self.module = module
self.name = module_name
self.type = module_type
# config and pruner
self.config = config
self.quantizer = quantizer
# register buffer and parameter
# old_weight is used to store origin weight and weight is used to store quantized weight
# the reason why weight is buffer instead of parameter is because in pytorch parameter is used as leaf
# if weight is leaf , then old_weight can not be updated.
if 'weight' in config['quant_types']:
if not _check_weight(self.module):
_logger.warning('Module %s does not have parameter "weight"', self.name)
else:
self.module.register_parameter('old_weight', torch.nn.Parameter(self.module.weight))
delattr(self.module, 'weight')
self.module.register_buffer('weight', self.module.old_weight)
def forward(self, *inputs):
if 'input' in self.config['quant_types']:
inputs = self.quantizer.quant_grad.apply(
inputs,
QuantType.QUANT_INPUT,
self)
if 'weight' in self.config['quant_types'] and _check_weight(self.module):
self.quantizer.quant_grad.apply(
self.module.old_weight,
QuantType.QUANT_WEIGHT,
self)
result = self.module(*inputs)
else:
result = self.module(*inputs)
if 'output' in self.config['quant_types']:
result = self.quantizer.quant_grad.apply(
result,
QuantType.QUANT_OUTPUT,
self)
return result
return result
class Quantizer(Compressor):
"""
Base quantizer for pytorch quantizer
"""
def __init__(self, model, config_list, optimizer=None):
super().__init__(model, config_list, optimizer)
self.quant_grad = QuantGrad
if self.optimizer is not None:
self.patch_optimizer(self.step_with_optimizer)
for wrapper in self.get_modules_wrapper():
if 'weight' in wrapper.config['quant_types']:
# old_weight is registered to keep track of weight before quantization
# and it is trainable, therefore, it should be added to optimizer.
self.optimizer.add_param_group({"params": wrapper.module.old_weight})
def quantize_weight(self, weight, wrapper, **kwargs):
"""
quantize should overload this method to quantize weight.
This method is effectively hooked to :meth:`forward` of the model.
Parameters
----------
weight : Tensor
weight that needs to be quantized
wrapper : QuantizerModuleWrapper
the wrapper for origin module
"""
raise NotImplementedError('Quantizer must overload quantize_weight()')
def quantize_output(self, output, wrapper, **kwargs):
"""
quantize should overload this method to quantize output.
This method is effectively hooked to :meth:`forward` of the model.
Parameters
----------
output : Tensor
output that needs to be quantized
wrapper : QuantizerModuleWrapper
the wrapper for origin module
"""
raise NotImplementedError('Quantizer must overload quantize_output()')
def quantize_input(self, *inputs, wrapper, **kwargs):
"""
quantize should overload this method to quantize input.
This method is effectively hooked to :meth:`forward` of the model.
Parameters
----------
inputs : Tensor
inputs that needs to be quantized
wrapper : QuantizerModuleWrapper
the wrapper for origin module
"""
raise NotImplementedError('Quantizer must overload quantize_input()')
def _wrap_modules(self, layer, config):
"""
Create a wrapper forward function to replace the original one.
Parameters
----------
layer : LayerInfo
the layer to instrument the mask
config : dict
the configuration for quantization
"""
assert 'quant_types' in config, 'must provide quant_types in config'
assert isinstance(config['quant_types'], list), 'quant_types must be list type'
assert 'quant_bits' in config, 'must provide quant_bits in config'
assert isinstance(config['quant_bits'], int) or isinstance(config['quant_bits'], dict), 'quant_bits must be dict type or int type'
if isinstance(config['quant_bits'], dict):
for quant_type in config['quant_types']:
assert quant_type in config['quant_bits'], 'bits length for %s must be specified in quant_bits dict' % quant_type
return QuantizerModuleWrapper(layer.module, layer.name, layer.type, config, self)
def step_with_optimizer(self):
pass
class QuantType:
"""
Enum class for quantization type.
"""
QUANT_INPUT = 0
QUANT_WEIGHT = 1
QUANT_OUTPUT = 2
class QuantGrad(torch.autograd.Function):
"""
Base class for overriding backward function of quantization operation.
"""
@staticmethod
def quant_backward(tensor, grad_output, quant_type):
"""
This method should be overrided by subclass to provide customized backward function,
default implementation is Straight-Through Estimator
Parameters
----------
tensor : Tensor
input of quantization operation
grad_output : Tensor
gradient of the output of quantization operation
quant_type : QuantType
the type of quantization, it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT`, `QuantType.QUANT_OUTPUT`,
you can define different behavior for different types.
Returns
-------
tensor
gradient of the input of quantization operation
"""
return grad_output
@staticmethod
def forward(ctx, tensor, quant_type, wrapper, **kwargs):
ctx.save_for_backward(tensor, torch.Tensor([quant_type]))
if quant_type == QuantType.QUANT_INPUT:
return wrapper.quantizer.quantize_input(tensor, wrapper, **kwargs)
elif quant_type == QuantType.QUANT_WEIGHT:
return wrapper.quantizer.quantize_weight(wrapper, **kwargs)
elif quant_type == QuantType.QUANT_OUTPUT:
return wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs)
else:
raise ValueError("unrecognized QuantType.")
@classmethod
def backward(cls, ctx, grad_output):
tensor, quant_type = ctx.saved_variables
output = cls.quant_backward(tensor, grad_output, quant_type)
return output, None, None, None
def _check_weight(module):
try:
return isinstance(module.weight.data, torch.Tensor)
except AttributeError:
return False
|