|
import operator |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
toq = torch.ops.quantized  # short alias for the quantized op namespace (toq.linear, toq.prelu, ...)
|
|
|
import torch.ao.nn.quantized as nnq |
|
import torch.ao.nn.quantized.dynamic as nnqd |
|
import torch.ao.nn.intrinsic.quantized as nniq |
|
import torch.ao.nn.intrinsic.quantized.dynamic as nniqd |
|
import torch.ao.nn.intrinsic.qat as nniqat |
|
import torch.ao.nn.intrinsic as nni |
|
import torch.ao.nn.qat as nnqat |
|
import torch.ao.nn.qat.dynamic as nnqatd |
|
from torch.ao.quantization.backend_config import get_native_backend_config |
|
import torch.ao.quantization.fx._lower_to_native_backend as \ |
|
_lower_to_native_backend |
|
import torch.ao.quantization.quantization_mappings as quantization_mappings |
|
|
|
from .ns_types import NSNodeTargetType |
|
|
|
from typing import Callable, Dict, List, Optional, Set, Tuple |
|
|
|
|
|
def _get_related_op_connections() -> List[Tuple[Callable, Optional[Callable]]]:
    """Collect ``(op, related_op)`` pairs from the quantization mappings.

    Pairs are produced in this order (order matters to the single-pass merge
    in ``get_base_name_to_sets_of_related_ops``):

    1. an explicit ``nn.Linear`` -> ``NonDynamicallyQuantizableLinear`` edge case;
    2. native backend_config patterns -> fused / QAT / reference modules;
    3. single-target module lowering maps;
    4. two-target lowering maps (fused modules, then functionals);
    5. binary-op and operator maps, then dynamic-quant module maps.

    The second element may be ``None`` if a lowering map stores ``None`` as a
    target; the caller is responsible for filtering such pairs.
    """
    connections: List[Tuple[Callable, Optional[Callable]]] = [
        # explicit edge case
        (nn.Linear, nn.modules.linear.NonDynamicallyQuantizableLinear),
    ]

    backend_config = get_native_backend_config()
    for pattern, config in backend_config._pattern_complex_format_to_config.items():
        # Follow the last element at every nesting level, e.g. (c, (b, a)) -> a.
        # NOTE(review): complex-format patterns are presumably written in
        # reverse execution order, making this the pattern's first op — confirm.
        first_element = pattern
        while isinstance(first_element, (list, tuple)):
            first_element = first_element[-1]

        # fused module, QAT module, reference quantized module (in that order)
        for derived_op in (
            config.fused_module,
            config.qat_module,
            config.reference_quantized_module,
        ):
            if derived_op is not None:
                connections.append((first_element, derived_op))

    # module swaps from the default lowering path: source -> target
    for source_to_target in (
        _lower_to_native_backend.STATIC_LOWER_MODULE_MAP,
        _lower_to_native_backend.DYNAMIC_LOWER_MODULE_MAP,
        _lower_to_native_backend.WEIGHT_ONLY_LOWER_MODULE_MAP,
        _lower_to_native_backend.SPECIAL_PATTERN_LOWER_MODULE_MAP,
    ):
        for source, target in source_to_target.items():
            connections.append((source, target))

    # swaps with two targets per source: fused-module maps, then functionals
    for source_to_double_target in (
        _lower_to_native_backend.STATIC_LOWER_FUSED_MODULE_MAP,
        _lower_to_native_backend.STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP,
        _lower_to_native_backend.DYNAMIC_LOWER_FUSED_MODULE_MAP,
        _lower_to_native_backend.STATIC_LOWER_FUNCTIONAL_MAP,
    ):
        for source, (target1, target2) in source_to_double_target.items():
            connections.append((source, target1))
            connections.append((source, target2))

    # binary-op / operator swaps, then dynamic-quant module swaps
    for source_to_target in (
        _lower_to_native_backend.QBIN_OP_MAPPING,
        _lower_to_native_backend.QBIN_RELU_OP_MAPPING,
        quantization_mappings.DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS,
        quantization_mappings.DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS,
    ):
        for source, target in source_to_target.items():
            connections.append((source, target))

    return connections


def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]:
    """Group node targets that should be treated as versions of the same op.

    Starts from hand-written seed sets (e.g. ``nn.PReLU`` with ``nnq.PReLU``)
    and extends them with ``(op, related_op)`` pairs from
    ``_get_related_op_connections``.

    Returns:
        A dict mapping base names ``"0"``, ``"1"``, ... (the position of each
        seed set, in seed order) to its set of related ops. Note that
        ``add_op_to_sets_of_related_ops`` mutates these sets in place.
    """
    # Seed order determines base names and, because merging is first-match,
    # which set a new op joins.
    sets_of_related_ops: List[Set[NSNodeTargetType]] = [
        # conv modules
        {nn.Conv1d},
        {nn.Conv2d},
        {nn.Conv3d},
        # conv functionals
        {F.conv1d},
        {F.conv2d},
        {F.conv3d},
        # linear module
        {nn.Linear},
        # linear functional
        {F.linear},
        # average pool
        {nn.AvgPool1d, torch.avg_pool1d},
        {nn.AvgPool2d, torch._C._nn.avg_pool2d},
        {nn.AvgPool3d, torch._C._nn.avg_pool3d},
        # adaptive average pool
        {nn.AdaptiveAvgPool1d, F.adaptive_avg_pool1d},
        {nn.AdaptiveAvgPool2d, F.adaptive_avg_pool2d},
        {nn.AdaptiveAvgPool3d, F.adaptive_avg_pool3d},
        # LSTM
        {nn.LSTM},
        # add
        {torch.add, operator.add},
        # cat
        {torch.cat},
        # mul
        {torch.mul, operator.mul},
        # relu (functional, module, method names, torch op)
        {F.relu, nn.ReLU, 'relu', 'relu_', torch.relu},
        # max pool
        {nn.MaxPool1d, F.max_pool1d},
        {nn.MaxPool2d, F.max_pool2d},
        {nn.MaxPool3d, F.max_pool3d},
        # sigmoid
        {torch.sigmoid, 'sigmoid', 'sigmoid_', nn.Sigmoid, F.sigmoid},
        # batch norm
        {nn.BatchNorm2d},
        {nn.BatchNorm3d},
        # transposed conv modules
        {nn.ConvTranspose1d},
        {nn.ConvTranspose2d},
        {nn.ConvTranspose3d},
        # transposed conv functionals
        {F.conv_transpose1d},
        {F.conv_transpose2d},
        {F.conv_transpose3d},
        # ELU
        {nn.ELU},
        # Embedding
        {nn.Embedding},
        # EmbeddingBag
        {nn.EmbeddingBag},
        # GroupNorm
        {nn.GroupNorm},
        # Hardswish
        {nn.Hardswish},
        # InstanceNorm
        {nn.InstanceNorm1d},
        {nn.InstanceNorm2d},
        {nn.InstanceNorm3d},
        # LayerNorm
        {nn.LayerNorm},
        # LeakyReLU
        {nn.LeakyReLU},
        # ReLU6
        {nn.ReLU6, F.relu6},
        # functional-only groups
        {F.elu},
        {F.hardswish},
        {F.group_norm},
        {F.instance_norm},
        {F.layer_norm},
        {F.leaky_relu},
        # SiLU
        {nn.SiLU, F.silu},
        # Mish
        {nn.Mish, F.mish},
        # tanh
        {nn.Tanh, F.tanh, torch.tanh, 'tanh_', 'tanh'},
        # hardsigmoid
        {'hardsigmoid_', 'hardsigmoid', F.hardsigmoid, nn.Hardsigmoid},
        # hardtanh
        {nn.Hardtanh, F.hardtanh, F.hardtanh_},
        # floordiv
        {operator.floordiv},
        # shape / reduction ops
        {torch.unsqueeze},
        {torch.stack},
        {torch.squeeze},
        {torch.sort},
        {torch.repeat_interleave},
        {torch.min},
        {torch.mean},
        {torch.max},
        {torch.transpose},
        {torch.flatten},
        {torch.clamp},
        {torch.chunk},
        # interpolate
        {torch.nn.functional.interpolate},
        # dropout
        {nn.Dropout},
        {F.dropout},
        # matmul
        {torch.matmul},
        # Softmax
        {nn.Softmax},
        # PReLU
        {nn.PReLU, nnq.PReLU},
        {F.prelu, toq.prelu},
        # pixel shuffle / unshuffle
        {nn.PixelShuffle},
        {F.pixel_shuffle},
        {nn.PixelUnshuffle},
        {F.pixel_unshuffle},
        # narrow
        {torch.narrow},
    ]

    for op, related_op in _get_related_op_connections():
        # None is not an op. If it entered a set, every later (x, None) pair
        # would "match" that set through None and x would be merged into it,
        # wrongly relating unrelated ops (e.g. all F.conv_transpose*d).
        if op is None or related_op is None:
            continue
        # Single pass, first match wins: extend the first set already holding
        # either endpoint. Pairs with no endpoint in any set yet are dropped.
        for set_of_related_ops in sets_of_related_ops:
            if op in set_of_related_ops or related_op in set_of_related_ops:
                set_of_related_ops.add(op)
                set_of_related_ops.add(related_op)
                break

    return {
        str(index): set_of_related_ops
        for index, set_of_related_ops in enumerate(sets_of_related_ops)
    }
|
|
|
|
|
def get_base_name_for_op(
    base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
    op: NSNodeTargetType,
) -> Optional[str]:
    """Look up the base name whose set of related ops contains ``op``.

    Sets are searched in the dict's iteration order and the first hit wins,
    so an op that appears in several sets resolves to the earliest one.

    Returns:
        The matching base name, or ``None`` when no set contains ``op``.
    """
    matching_names = (
        name
        for name, related_ops in base_name_to_sets_of_related_ops.items()
        if op in related_ops
    )
    return next(matching_names, None)
|
|
|
|
|
def add_op_to_sets_of_related_ops(
    base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
    op: NSNodeTargetType,
    related_op: Optional[NSNodeTargetType],
) -> None:
    """Register ``op`` in ``base_name_to_sets_of_related_ops``, in place.

    If ``related_op`` is given, ``op`` joins the first set (in dict iteration
    order) that already contains ``related_op``. If ``related_op`` is None,
    ``op`` becomes the only member of a new set. That set is keyed by the
    smallest non-negative integer, as a string, that is not yet a key.

    Raises:
        AssertionError: ``related_op`` is not None and no set contains it.
    """
    if related_op is None:
        # len(d) + 1 candidates cannot all be taken by len(d) keys, so a free
        # name always exists among them.
        free_name = next(
            str(candidate)
            for candidate in range(len(base_name_to_sets_of_related_ops) + 1)
            if str(candidate) not in base_name_to_sets_of_related_ops
        )
        base_name_to_sets_of_related_ops[free_name] = {op}
        return

    for related_ops in base_name_to_sets_of_related_ops.values():
        if related_op in related_ops:
            related_ops.add(op)
            return
    raise AssertionError(f"{related_op} was not found")
|
|
|
|
|
|
|
def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
    """Group known node targets by the dtype of their inputs and outputs.

    Keys have the form ``"<kind>_io_type_<dtype>"``:

    * kind ``funs``: callable targets (``F.*``, ``torch.*``, ``operator.*``,
      ``torch.ops.quantized`` ops);
    * kind ``mods``: module classes;
    * kind ``meths``: method-name strings.

    dtype is ``fp32``, ``fp16``, ``int8`` or ``fp32_or_int8``. The
    ``fp32_or_int8`` sets hold targets usable with either dtype.
    NOTE(review): for those targets the output dtype presumably follows the
    input dtype — confirm with callers.

    NOTE(review): ``F.dropout`` and ``torch.cat`` are in both the fp32 and
    the fp32_or_int8 function sets, and ``nn.Dropout`` is in both the fp32
    and the fp32_or_int8 module sets. Which classification applies depends
    on the order in which callers check these sets.
    """
    io_type_map: Dict[str, Set[NSNodeTargetType]] = {
        'funs_io_type_fp32': {
            F.linear,
            F.conv1d,
            F.conv2d,
            F.conv3d,
            torch.cat,
            F.elu,
            F.hardswish,
            F.instance_norm,
            F.layer_norm,
            F.leaky_relu,
            F.dropout,
            F.silu,
            F.mish,
            operator.add,
            torch.add,
            operator.mul,
            torch.mul,
            torch.sum,
            F.prelu,
        },
        'funs_io_type_fp16': set(),
        'funs_io_type_int8': {
            toq.linear,
            toq.linear_relu,
            toq.conv1d,
            toq.conv1d_relu,
            toq.conv2d,
            toq.conv2d_relu,
            toq.conv3d,
            toq.conv3d_relu,
            toq.cat,
            toq.elu,
            toq.hardswish,
            toq.instance_norm,
            toq.layer_norm,
            toq.leaky_relu,
            toq.dropout,
            toq.prelu,
        },
        'funs_io_type_fp32_or_int8': {
            F.relu,
            F.tanh,
            torch.tanh,
            F.sigmoid,
            torch.sigmoid,
            F.hardsigmoid,
            operator.floordiv,
            torch.adaptive_avg_pool1d,
            F.adaptive_avg_pool2d,
            F.adaptive_avg_pool3d,
            F.dropout,
            F.hardtanh,
            F.hardtanh_,
            F.interpolate,
            F.max_pool1d,
            F.max_pool2d,
            F.max_pool3d,
            F.relu6,
            F.pixel_shuffle,
            F.pixel_unshuffle,
            torch.avg_pool1d,
            torch._C._nn.avg_pool2d,
            torch._C._nn.avg_pool3d,
            torch.cat,
            torch.chunk,
            torch.clamp,
            torch.flatten,
            torch.transpose,
            torch.max,
            torch.mean,
            torch.min,
            torch.narrow,
            torch.repeat_interleave,
            torch.sort,
            torch.squeeze,
            torch.stack,
            torch.unsqueeze,
            operator.add,
        },
        'mods_io_type_fp32': {
            # NOTE(review): dynamic-quant modules (nnqd.*) are listed here as
            # fp32 I/O; callers presumably test this set before the int8 one
            # (nnqd.Linear subclasses nnq.Linear) — confirm.
            nn.Linear,
            nnqat.Linear,
            nnqatd.Linear,
            nnqd.Linear,
            torch.nn.modules.linear.NonDynamicallyQuantizableLinear,
            nn.Conv1d,
            nn.Conv2d,
            nn.Conv3d,
            nnqat.Conv1d,
            nnqat.Conv2d,
            nnqat.Conv3d,
            nnqat.Embedding,
            nnqat.EmbeddingBag,
            nn.LSTM,
            nnqd.LSTM,
            nn.BatchNorm2d,
            nn.BatchNorm3d,
            nn.Dropout,
            nn.ConvTranspose1d,
            nn.ConvTranspose2d,
            nn.ConvTranspose3d,
            nn.ELU,
            nn.GroupNorm,
            nn.InstanceNorm1d,
            nn.InstanceNorm2d,
            nn.InstanceNorm3d,
            nn.LayerNorm,
            nn.Hardswish,
            nn.LeakyReLU,
            nn.ReLU6,
            nn.SiLU,
            nn.Mish,
            nn.Softmax,
            nn.PReLU,
            nni.BNReLU2d,
            nni.BNReLU3d,
            nni.ConvReLU1d,
            nni.ConvReLU2d,
            nni.ConvReLU3d,
            nni.LinearReLU,
            nni.LinearBn1d,
            nni.ConvBn1d,
            nni.ConvBn2d,
            nni.ConvBn3d,
            nniqat.ConvBn1d,
            nniqat.ConvBn2d,
            nniqat.ConvBn3d,
            nniqat.ConvBnReLU1d,
            nniqat.ConvBnReLU2d,
            nniqat.ConvBnReLU3d,
            nniqat.ConvReLU1d,
            nniqat.ConvReLU2d,
            nniqat.ConvReLU3d,
            nniqat.LinearReLU,
            nniqat.LinearBn1d,
            nniqd.LinearReLU,
            nni.LinearLeakyReLU,
            nni.LinearTanh,
            nni.ConvAdd2d,
            nni.ConvAddReLU2d,
        },
        'mods_io_type_int8': {
            nnq.Linear,
            nnq.Conv1d,
            nnq.Conv2d,
            nnq.Conv3d,
            nnq.BatchNorm2d,
            nnq.BatchNorm3d,
            nnq.Dropout,
            nnq.ConvTranspose1d,
            nnq.ConvTranspose2d,
            nnq.ELU,
            nnq.InstanceNorm1d,
            nnq.InstanceNorm2d,
            nnq.InstanceNorm3d,
            nnq.LayerNorm,
            nnq.Hardswish,
            nnq.LeakyReLU,
            nnq.Embedding,
            nnq.EmbeddingBag,
            nnq.Softmax,
            nnq.PReLU,
            nniq.BNReLU2d,
            nniq.BNReLU3d,
            nniq.ConvReLU1d,
            nniq.ConvReLU2d,
            nniq.ConvReLU3d,
            nniq.LinearReLU,
            nniq.LinearLeakyReLU,
            nniq.LinearTanh,
            nniq.ConvAdd2d,
            nniq.ConvAddReLU2d,
        },
        'mods_io_type_fp32_or_int8': {
            nn.ReLU,
            nn.Tanh,
            nn.Sigmoid,
            nn.Hardsigmoid,
            nn.AdaptiveAvgPool1d,
            nn.AdaptiveAvgPool2d,
            nn.AdaptiveAvgPool3d,
            nn.AvgPool1d,
            nn.AvgPool2d,
            nn.AvgPool3d,
            nn.Dropout,
            nn.Hardtanh,
            nn.Identity,
            nn.MaxPool1d,
            nn.MaxPool2d,
            nn.MaxPool3d,
            nn.PixelShuffle,
            nn.PixelUnshuffle,
            nn.ReLU6,
        },
        'meths_io_type_fp32_or_int8': {
            'sigmoid_',
            'sigmoid',
            'tanh_',
            'tanh',
            'hardsigmoid_',
            'hardsigmoid',
            'relu_',
            'relu',
        },
    }
    return io_type_map
|
|
|
|
|
def get_unmatchable_types_map() -> Dict[str, Set[NSNodeTargetType]]:
    """Return node targets considered unmatchable, grouped by kind.

    Keys:
        ``funs_unmatchable``: callable targets;
        ``mods_unmatchable``: module classes;
        ``meths_unmatchable``: method-name strings.
    """
    return {
        'funs_unmatchable': {
            torch.quantize_per_tensor,
            operator.getitem,
        },
        'mods_unmatchable': {
            nn.Identity,
        },
        'meths_unmatchable': {
            'to',
            'dequantize',
            'reshape',
            'view',
            'unsqueeze_',
            'unsqueeze',
            'transpose',
            'squeeze_',
            'squeeze',
            'size',
            'shape',
            'resize_',
            'repeat_interleave',
            'repeat',
            'permute',
            'numel',
            'mean',
            'detach_',
            'detach',
            'contiguous',
            'clamp',
            'chunk',
        },
    }
|
|