"""This file exports ONNX ops for opset 16. | |
Note [ONNX Operators that are added/updated in opset 16] | |
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-16-of-the-default-onnx-operator-set | |
New operators: | |
GridSample https://github.com/onnx/onnx/pull/3557 | |
Updated operators: | |
Identity | |
If | |
LeakyRelu | |
Loop | |
PRelu | |
RoiAlign | |
Scan | |
ScatterElements | |
ScatterND | |
Where | |
GreaterOrEqual | |
LessOrEqual | |
""" | |
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md

import functools

import torch
from torch.nn.functional import (
    GRID_SAMPLE_INTERPOLATION_MODES,
    GRID_SAMPLE_PADDING_MODES,
)
from torch.onnx import _type_utils, errors, symbolic_helper, utils
from torch.onnx._internal import _beartype, jit_utils, registration

_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=16)


# note (mkozuki): Why `grid_sampler` instead of `grid_sample`?
# Because `torch.nn.functional.grid_sample` calls `torch.grid_sampler`.
@_onnx_symbolic("aten::grid_sampler")
@symbolic_helper.parse_args("v", "v", "i", "i", "b")
@_beartype.beartype
def grid_sampler(
    g: jit_utils.GraphContext,
    input,
    grid,
    mode_enum,
    padding_mode_enum,
    align_corners,
):
    # Check the input and grid tensor rank beforehand.
    if symbolic_helper._get_tensor_rank(input) == 5:
        return symbolic_helper._onnx_unsupported("GridSample with 5D volumetric input")
    mode_s = {v: k for k, v in GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum]  # type: ignore[call-arg]
    padding_mode_s = {v: k for k, v in GRID_SAMPLE_PADDING_MODES.items()}[padding_mode_enum]  # type: ignore[call-arg]
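    # In recent PyTorch releases (an assumption worth re-checking for your
    # version), GRID_SAMPLE_INTERPOLATION_MODES == {"bilinear": 0, "nearest": 1,
    # "bicubic": 2} and GRID_SAMPLE_PADDING_MODES == {"zeros": 0, "border": 1,
    # "reflection": 2}, so e.g. mode_enum == 0 maps back to the string "bilinear".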
    return g.op(
        "GridSample",
        input,
        grid,
        align_corners_i=int(align_corners),
        mode_s=mode_s,
        padding_mode_s=padding_mode_s,
    )
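

# A minimal sketch (an editorial addition, not part of the original file) of
# the eager code path that reaches `grid_sampler` above: `F.grid_sample` at
# opset 16 lowers to a single ONNX GridSample node. `_GridSampleDemo` and
# `_demo_grid_sample_export` are hypothetical names used only for illustration.
def _demo_grid_sample_export():
    import io

    import torch.nn.functional as F

    class _GridSampleDemo(torch.nn.Module):
        def forward(self, x, grid):
            return F.grid_sample(
                x, grid, mode="bilinear", padding_mode="zeros", align_corners=False
            )

    x = torch.rand(1, 3, 8, 8)  # 4D input; 5D volumetric input is rejected above
    grid = torch.rand(1, 4, 4, 2) * 2 - 1  # sampling locations in [-1, 1]
    torch.onnx.export(_GridSampleDemo(), (x, grid), io.BytesIO(), opset_version=16)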


@_onnx_symbolic("aten::scatter_add")
@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def scatter_add(g: jit_utils.GraphContext, self, dim, index, src):
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("scatter", self, dim, index, src, overload_name="src")

    src_type = _type_utils.JitScalarType.from_value(
        src, _type_utils.JitScalarType.UNDEFINED
    )
    src_sizes = symbolic_helper._get_tensor_sizes(src)
    index_sizes = symbolic_helper._get_tensor_sizes(index)

    if len(src_sizes) != len(index_sizes):
        return symbolic_helper._unimplemented(
            "scatter_add",
            f"`index` ({index_sizes}) should have the same dimensionality as `src` ({src_sizes})",
        )

    # PyTorch only allows `index` sizes <= `src` sizes, so we treat `index` as
    # selecting a subset of `src`, just as PyTorch does. When the sizes of `src`
    # and `index` do not match, or when there are dynamic axes, we slice `src`
    # down to the shape of `index`.
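    # For example (illustrative shapes only): with `src` of shape (2, 5) and
    # `index` of shape (1, 5), only src[:1, :] participates, so `src` is sliced
    # down to the index shape before feeding ScatterElements.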
    if src_sizes != index_sizes or None in index_sizes:
        adjusted_shape = g.op("Shape", index)
        starts = g.op("Constant", value_t=torch.tensor([0] * len(index_sizes)))
        src = g.op("Slice", src, starts, adjusted_shape)

    src = symbolic_helper._maybe_get_scalar(src)
    if symbolic_helper._is_value(src):
        return g.op("ScatterElements", self, index, src, axis_i=dim, reduction_s="add")
    else:
        # Check that the scalar `src` has the same type as `self` (PyTorch allows
        # a different type for a scalar `src`, but not for a tensor `src`).
        # If not, insert a Cast node.
        if _type_utils.JitScalarType.from_value(self) != src_type:
            src = g.op(
                "Cast",
                src,
                to_i=_type_utils.JitScalarType.from_value(self).onnx_type(),
            )

        return g.op(
            "ScatterElements",
            self,
            index,
            src,
            axis_i=dim,
            reduction_s="add",
        )
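

# Editorial sketch of the eager semantics the symbolic above must reproduce
# (adapted from the `Tensor.scatter_add_` documentation); `_demo_scatter_add`
# is a hypothetical name used only for illustration.
def _demo_scatter_add():
    src = torch.ones(2, 5)
    index = torch.tensor([[0, 1, 2, 0, 0]])  # smaller than src along dim 0
    # Row index[0][j] of the result accumulates src[0][j] for each column j.
    out = torch.zeros(3, 5, dtype=src.dtype).scatter_add(0, index, src)
    # Only src[:1, :] participates, mirroring the Slice-to-index-shape above:
    # out == [[1., 0., 0., 1., 1.],
    #         [0., 1., 0., 0., 0.],
    #         [0., 0., 1., 0., 0.]]
    return out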


@_onnx_symbolic("aten::scatter_reduce")
@symbolic_helper.parse_args("v", "i", "v", "v", "s", "b")
@_beartype.beartype
def scatter_reduce(
    g: jit_utils.GraphContext,
    self: torch._C.Value,
    dim: int,
    index: torch._C.Value,
    src: torch._C.Value,
    reduce: str,
    include_self: bool,
):
    if reduce == "mean":
        raise errors.OnnxExporterError(
            "ONNX does not support mean reduction for scatter_reduce"
        )
    if not include_self:
        raise errors.OnnxExporterError(
            "ONNX does not support include_self=False for scatter_reduce"
        )
    reduce_mode = {  # convert torch string name to onnx string name
        "mean": "none",  # 'mean' is not supported in the ONNX 1.14 definition
        "sum": "add",
        "prod": "mul",
        "amin": "min",
        "amax": "max",
    }
    onnx_reduce = reduce_mode[reduce]

    self_rank = g.op("Size", g.op("Shape", self))
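    # "Shape" yields the dimensions of `self`; "Size" counts them, so
    # `self_rank` is the rank of `self` as a scalar int64 tensor.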
    # if self_rank == 0:  # assert (index_rank == 0 and rank_src == 0)
    self_rank_is_zero = g.op(
        "Equal", self_rank, g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64))
    )
    if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks(
        g, "If", self_rank_is_zero, n_blocks=2, outputs=3
    )
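    # The then-branch flattens the rank-0 inputs to 1-D (ScatterElements
    # requires rank >= 1); the else-branch passes all three through unchanged.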
    neg_1 = if_context.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
    self_reshape = if_context.op("Reshape", self, neg_1)
    utils._add_output_to_block(if_context.block, self_reshape)
    index_reshape = if_context.op("Reshape", index, neg_1)
    utils._add_output_to_block(if_context.block, index_reshape)
    src_reshape = if_context.op("Reshape", src, neg_1)
    utils._add_output_to_block(if_context.block, src_reshape)

    self_identity = else_context.op("Identity", self)
    utils._add_output_to_block(else_context.block, self_identity)
    index_identity = else_context.op("Identity", index)
    utils._add_output_to_block(else_context.block, index_identity)
    src_identity = else_context.op("Identity", src)
    utils._add_output_to_block(else_context.block, src_identity)

    result = g.op("ScatterElements", *if_op, axis_i=dim, reduction_s=onnx_reduce)
    # if self_rank == 0:
    if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks(
        g, "If", self_rank_is_zero, n_blocks=2, outputs=1
    )
    result_squeezed = if_context.op("Squeeze", result)
    utils._add_output_to_block(if_context.block, result_squeezed)
    result_identity = else_context.op("Identity", result)
    utils._add_output_to_block(else_context.block, result_identity)
    result_final = if_op.node().output()

    return result_final
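

# Editorial sketch of the rank-0 corner case the two If blocks above handle;
# `_demo_scatter_reduce_rank0` is a hypothetical name used only for illustration.
def _demo_scatter_reduce_rank0():
    self_t = torch.tensor(1.0)  # rank-0 self, so the then-branches fire
    index = torch.tensor(0)
    src = torch.tensor(2.0)
    out = self_t.scatter_reduce(0, index, src, reduce="sum", include_self=True)
    # Expected: tensor(3.). ONNX ScatterElements needs tensors of rank >= 1,
    # hence the Reshape-to-[-1] in the first If and the Squeeze in the second.
    return out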