import importlib
import inspect

from torch.onnx import symbolic_helper, symbolic_opset9 as opset9
from torch.onnx._internal import jit_utils, registration


def register_quantized_ops(domain: str, version: int):
    # Register all quantized ops
    module = importlib.import_module("torch.onnx.symbolic_caffe2")
    quant_version_ops = inspect.getmembers(module)
    aten_q_ops = {
        "relu",
        "_empty_affine_quantized",
        "dequantize",
        "quantize_per_tensor",
        "upsample_nearest2d",
        "avg_pool2d",
        "reshape",
        "slice",
        "cat",
        "max_pool2d",
        "sigmoid",
    }
    for op, func in quant_version_ops:
        name = f"{domain}::{op}"
        if inspect.isfunction(func) and not registration.registry.is_registered_op(
            name, version
        ):
            if op in aten_q_ops:
                # Override the builtin aten ops
                registration.registry.register(
                    f"aten::{op}", version, func, custom=True
                )
            registration.registry.register(name, version, func)
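
# A minimal usage sketch (illustrative, not executed here): callers are expected to
# register this domain once per opset before export, roughly
#
#     from torch.onnx import symbolic_caffe2
#     symbolic_caffe2.register_quantized_ops("caffe2", 9)
#
# which binds every public function in this module under "<domain>::<op>" and
# overrides the aten ops listed in `aten_q_ops` with their Int8 counterparts.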


def _permute_helper(g: jit_utils.GraphContext, input, axes):
    quant_args = {
        "axes_i": axes,
        "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"),
        "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"),
    }
    output = g.op("_caffe2::Int8Transpose", input, **quant_args)
    symbolic_helper._quantized_ops.add(output)
    return output


def nchw2nhwc(g: jit_utils.GraphContext, input):
    axes = [0, 2, 3, 1]
    return _permute_helper(g, input, axes)


def nhwc2nchw(g: jit_utils.GraphContext, input):
    axes = [0, 3, 1, 2]
    return _permute_helper(g, input, axes)
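
# For intuition: nchw2nhwc permutes an (N, C, H, W) activation to (N, H, W, C),
# e.g. a (1, 3, 224, 224) tensor becomes (1, 224, 224, 3), and nhwc2nchw applies
# the inverse permutation. The Int8 resize/pooling ops below expect NHWC order
# (see "order_s": "NHWC"), hence these round-trip helpers.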


def linear_prepack(g: jit_utils.GraphContext, weight, bias):
    # Mapping to a dummy caffe2 prepack node.
    # During the onnx -> c2 conversion we can look up original weight and bias
    # from this node
    output = g.op("_caffe2::WeightPrepack", weight, bias)
    symbolic_helper._quantized_ops.add(output)
    return output


@symbolic_helper.parse_args("v", "v", "v", "f", "i")
def linear(g: jit_utils.GraphContext, input, weight, bias, scale, zero_point):
    kwargs = {
        "Y_scale_f": scale,
        "Y_zero_point_i": zero_point,
    }
    output = g.op("_caffe2::Int8FC", input, weight, bias, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output


def conv_prepack(
    g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups
):
    # Mapping to a dummy caffe2 prepack node.
    # During the onnx -> c2 conversion we can look up original weight and bias
    # from this node
    output = g.op("_caffe2::WeightPrepack", input, weight, bias)
    symbolic_helper._quantized_ops.add(output)
    return output


@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "f", "i")
def conv2d(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    scale,
    zero_point,
):
    kernel_size = weight.node()["shape"][1:3]
    kwargs = {
        "strides_i": stride,
        "pads_i": padding + padding,
        "dilations_i": dilation,
        "group_i": groups,
        "kernels_i": kernel_size,
        "order_s": "NHWC",
        "Y_scale_f": scale,
        "Y_zero_point_i": zero_point,
    }
    output = g.op("_caffe2::Int8Conv", input, weight, bias, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output
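
# Attribute sketch for the conv ops: "pads_i" repeats the symmetric padding list,
# e.g. padding = [1, 1] becomes pads_i = [1, 1, 1, 1] (begin/end pads for H and W),
# and "kernels_i" takes the spatial kernel size from the prepacked weight's
# "shape" attribute (entries 1:3).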


@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "f", "i")
def conv2d_relu(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    scale,
    zero_point,
):
    kernel_size = weight.node()["shape"][1:3]
    kwargs = {
        "strides_i": stride,
        "pads_i": padding + padding,
        "dilations_i": dilation,
        "group_i": groups,
        "kernels_i": kernel_size,
        "order_s": "NHWC",
        "Y_scale_f": scale,
        "Y_zero_point_i": zero_point,
    }
    output = g.op("_caffe2::Int8ConvRelu", input, weight, bias, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output


@symbolic_helper.parse_args("v", "v", "f", "i")
def add(g: jit_utils.GraphContext, input_a, input_b, scale, zero_point):
    kwargs = {
        "Y_scale_f": scale,
        "Y_zero_point_i": zero_point,
    }
    output = g.op("_caffe2::Int8Add", input_a, input_b, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output


def relu(g: jit_utils.GraphContext, input):
    if input not in symbolic_helper._quantized_ops:
        return opset9.relu(g, input)

    kwargs = {
        "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"),
        "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"),
    }
    output = g.op("_caffe2::Int8Relu", input, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output


@symbolic_helper.parse_args("v", "f", "i", "t")
def quantize_per_tensor(g: jit_utils.GraphContext, input, scale, zero_point, dtype):
    kwargs = {
        "Y_scale_f": scale,
        "Y_zero_point_i": zero_point,
    }
    output = g.op("_caffe2::Int8Quantize", input, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output


def dequantize(g: jit_utils.GraphContext, input):
    return g.op("_caffe2::Int8Dequantize", input)


def _empty_affine_quantized(
    g: jit_utils.GraphContext,
    input,
    shape,
    scale,
    zero_point,
    dtype,
    pin_memory,
    memory_format,
    layout,
):
    return input


def upsample_nearest2d(
    g: jit_utils.GraphContext,
    input,
    output_size,
    align_corners=None,
    scales_h=None,
    scales_w=None,
):
    if input not in symbolic_helper._quantized_ops:
        return opset9.upsample_nearest2d(g, input, output_size, align_corners)  # type: ignore[attr-defined]

    output_size = symbolic_helper._parse_arg(output_size, "is")
    kwargs = {
        "output_size_i": output_size,
        "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"),
        "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"),
    }
    input = nchw2nhwc(g, input)
    output = g.op("_caffe2::Int8ResizeNearest", input, **kwargs)
    output = nhwc2nchw(g, output)
    symbolic_helper._quantized_ops.add(output)
    return output
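
# Int8ResizeNearest, like the Int8 pooling ops below, runs in NHWC order, so the
# quantized input is permuted to NHWC, resized, and permuted back to NCHW, while
# the input's Y_scale / Y_zero_point are propagated unchanged to the output.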


@symbolic_helper.parse_args("v", "is", "is", "is", "is", "i")
def max_pool2d(
    g: jit_utils.GraphContext,
    input,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode,
):
    if input not in symbolic_helper._quantized_ops:
        return opset9.max_pool2d(  # type: ignore[attr-defined]
            g, input, kernel_size, stride, padding, dilation, ceil_mode
        )

    kwargs = {
        "strides_i": stride,
        "pads_i": padding + padding,
        "kernel_i": kernel_size[0],
        "order_s": "NHWC",
        "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"),
        "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"),
    }
    input = nchw2nhwc(g, input)
    output = g.op("_caffe2::Int8MaxPool", input, **kwargs)
    output = nhwc2nchw(g, output)
    symbolic_helper._quantized_ops.add(output)
    return output


@symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none")
def avg_pool2d(
    g: jit_utils.GraphContext,
    input,
    kernel_size,
    stride,
    padding,
    ceil_mode,
    count_include_pad,
    divisor_override=None,
):
    if input not in symbolic_helper._quantized_ops:
        return opset9.avg_pool2d(  # type: ignore[attr-defined]
            g,
            input,
            kernel_size,
            stride,
            padding,
            ceil_mode,
            count_include_pad,
            divisor_override,
        )

    kwargs = {
        "strides_i": stride,
        "pads_i": padding + padding,
        "kernel_i": kernel_size[0],
        "order_s": "NHWC",
        "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"),
        "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"),
    }
    input = nchw2nhwc(g, input)
    output = g.op("_caffe2::Int8AveragePool", input, **kwargs)
    output = nhwc2nchw(g, output)
    symbolic_helper._quantized_ops.add(output)
    return output


def reshape(g: jit_utils.GraphContext, input, shape):
    if input not in symbolic_helper._quantized_ops:
        return opset9.reshape(g, input, shape)

    kwargs = {
        "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"),
        "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"),
    }
    output = g.op("_caffe2::Int8Reshape", input, shape, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output


@symbolic_helper.parse_args("v", "v", "v", "v", "i")
def slice(g: jit_utils.GraphContext, input, dim, start, end, step):
    if input not in symbolic_helper._quantized_ops:
        return opset9.slice(g, input, dim, start, end, step)

    if step != 1:
        raise RuntimeError("ONNX quantized slice export only works for step 1.")
    start = symbolic_helper._parse_arg(start, "i")
    end = symbolic_helper._parse_arg(end, "i")
    dim = symbolic_helper._parse_arg(dim, "i")
    kwargs = {
        "start_idx_i": start,
        "end_idx_i": end,
        "dim_i": dim,
        "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"),
        "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"),
    }
    output = g.op("_caffe2::Int8Slice", input, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output


def cat(g: jit_utils.GraphContext, tensor_list, dim, scale=None, zero_point=None):
    tensors = symbolic_helper._unpack_list(tensor_list)
    input = tensors[0]
    if input not in symbolic_helper._quantized_ops:
        return opset9.cat(g, tensor_list, dim)

    dim = symbolic_helper._parse_arg(dim, "i")
    kwargs = {
        "Y_scale_f": tensors[0].node()["Y_scale"],
        "Y_zero_point_i": tensors[0].node()["Y_zero_point"],
    }
    output = g.op("_caffe2::Int8Concat", *tensors, axis_i=dim, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output


def sigmoid(g: jit_utils.GraphContext, input):
    if input not in symbolic_helper._quantized_ops:
        return opset9.sigmoid(g, input)

    # Caffe2 expects the output scale to be 1/2^8
    # and output zero_point to be 0 (quint8 type)
    out_scale = 1.0 / 256
    zero_point = 0
    kwargs = {
        "Y_scale_f": out_scale,
        "Y_zero_point_i": zero_point,
    }
    output = g.op("_caffe2::Int8Sigmoid", input, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output
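
# Worked example of the fixed output quantization above: with Y_scale = 1/256 and
# Y_zero_point = 0, quint8 values 0..255 dequantize to 0.0 .. 255/256 (~0.9961),
# which matches sigmoid's output range (0, 1).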