Spaces:
Running
Running
File size: 6,752 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 |
"""This file exports ONNX ops for opset 16.
Note [ONNX Operators that are added/updated in opset 16]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-16-of-the-default-onnx-operator-set
New operators:
GridSample https://github.com/onnx/onnx/pull/3557
Updated operators:
Identity
If
LeakyRelu
Loop
PRelu
RoiAlign
Scan
ScatterElements
ScatterND
Where
GreaterOrEqual
LessOrEqual
"""
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md
import functools
import torch
from torch.nn.functional import (
GRID_SAMPLE_INTERPOLATION_MODES,
GRID_SAMPLE_PADDING_MODES,
)
from torch.onnx import _type_utils, errors, symbolic_helper, utils
from torch.onnx._internal import _beartype, jit_utils, registration
# Registration decorator pre-bound to opset 16: `_onnx_symbolic("aten::<op>")`
# is `registration.onnx_symbolic("aten::<op>", opset=16)`.
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=16)
# note (mkozuki): Why `grid_sampler` instead of `grid_sample`?
# Because `torch.nn.functional.grid_sample` calls `torch.grid_sampler`.
@_onnx_symbolic("aten::grid_sampler")
@symbolic_helper.parse_args("v", "v", "i", "i", "b")
@_beartype.beartype
def grid_sampler(
    g: jit_utils.GraphContext,
    input,
    grid,
    mode_enum,
    padding_mode_enum,
    align_corners,
):
    """Emit an ONNX ``GridSample`` node for ``aten::grid_sampler``.

    Args:
        g: Graph context the node is added to.
        input: Graph value holding the tensor to sample from.
        grid: Graph value holding the sampling grid.
        mode_enum: Integer interpolation mode; translated back to its string
            name through ``GRID_SAMPLE_INTERPOLATION_MODES``.
        padding_mode_enum: Integer padding mode; translated back to its string
            name through ``GRID_SAMPLE_PADDING_MODES``.
        align_corners: Flag forwarded to the node as an integer attribute.

    Returns:
        The output of the ``GridSample`` node, or whatever
        ``symbolic_helper._onnx_unsupported`` yields for rank-5 input.
    """
    # Only the rank of `input` is inspected here; rank-5 (volumetric) input is
    # reported as unsupported before any node is created.
    input_rank = symbolic_helper._get_tensor_rank(input)
    if input_rank == 5:
        return symbolic_helper._onnx_unsupported("GridSample with 5D volumetric input")

    # The torch tables map name -> enum; the ONNX attributes need the name, so
    # invert each table and look the enum up.
    interpolation_names = {
        enum_value: name
        for name, enum_value in GRID_SAMPLE_INTERPOLATION_MODES.items()  # type: ignore[call-arg]
    }
    mode_s = interpolation_names[mode_enum]

    padding_names = {
        enum_value: name
        for name, enum_value in GRID_SAMPLE_PADDING_MODES.items()  # type: ignore[call-arg]
    }
    padding_mode_s = padding_names[padding_mode_enum]

    return g.op(
        "GridSample",
        input,
        grid,
        align_corners_i=int(align_corners),
        mode_s=mode_s,
        padding_mode_s=padding_mode_s,
    )
@_onnx_symbolic("aten::scatter_add")
@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def scatter_add(g: jit_utils.GraphContext, self, dim, index, src):
    """Lower ``aten::scatter_add`` to ONNX ``ScatterElements`` with ``reduction="add"``.

    Args:
        g: Graph context the nodes are added to.
        self: Graph value that is scattered into.
        dim: Integer axis passed to ``ScatterElements`` as ``axis``.
        index: Graph value holding the scatter indices.
        src: Graph value holding the values to add.

    Returns:
        The ``ScatterElements`` output; the ATen fallback op when the Caffe2
        ATen fallback is active; or the result of
        ``symbolic_helper._unimplemented`` when ``index`` and ``src`` differ
        in number of dimensions.
    """
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("scatter", self, dim, index, src, overload_name="src")

    # Capture src's scalar type now: src may be replaced by a Slice output or
    # unwrapped to a scalar further down.
    src_type = _type_utils.JitScalarType.from_value(
        src, _type_utils.JitScalarType.UNDEFINED
    )
    src_sizes = symbolic_helper._get_tensor_sizes(src)
    index_sizes = symbolic_helper._get_tensor_sizes(index)

    if len(src_sizes) != len(index_sizes):
        return symbolic_helper._unimplemented(
            "scatter_add",
            f"`index` ({index_sizes}) should have the same dimensionality as `src` ({src_sizes})",
        )

    # PyTorch only allows index shape <= src shape, so index is treated as the
    # sub-extent of src to use. Whenever the static sizes differ, or index has
    # a dynamic axis, src is sliced down to the runtime shape of index.
    needs_slice = src_sizes != index_sizes or None in index_sizes
    if needs_slice:
        index_shape = g.op("Shape", index)
        zero_starts = g.op("Constant", value_t=torch.tensor([0] * len(index_sizes)))
        src = g.op("Slice", src, zero_starts, index_shape)

    src = symbolic_helper._maybe_get_scalar(src)
    if not symbolic_helper._is_value(src):
        # src was unwrapped to a scalar. PyTorch tolerates a scalar src whose
        # type differs from self (not so for a tensor src), so insert a Cast
        # when the types disagree.
        self_type = _type_utils.JitScalarType.from_value(self)
        if self_type != src_type:
            src = g.op("Cast", src, to_i=self_type.onnx_type())

    return g.op("ScatterElements", self, index, src, axis_i=dim, reduction_s="add")
@_onnx_symbolic("aten::scatter_reduce")
@symbolic_helper.parse_args("v", "i", "v", "v", "s", "b")
@_beartype.beartype
def scatter_reduce(
    g: jit_utils.GraphContext,
    self: torch._C.Value,
    dim: int,
    index: torch._C.Value,
    src: torch._C.Value,
    reduce: str,
    include_self: bool,
):
    """Lower ``aten::scatter_reduce`` to ONNX ``ScatterElements`` with a reduction.

    The rank of ``self`` is tested inside the exported graph. When it is zero,
    ``self``, ``index`` and ``src`` are reshaped to ``[-1]`` before the scatter
    and the result is squeezed afterwards; otherwise all three pass through
    unchanged.

    Args:
        g: Graph context the nodes are added to.
        self: Graph value that is scattered into.
        dim: Integer axis passed to ``ScatterElements`` as ``axis``.
        index: Graph value holding the scatter indices.
        src: Graph value holding the values to reduce into ``self``.
        reduce: Torch reduction name; one of ``"sum"``, ``"prod"``, ``"amin"``,
            ``"amax"``.
        include_self: Must be true.

    Returns:
        The single output of the trailing ``If`` node.

    Raises:
        errors.OnnxExporterError: If ``reduce`` is ``"mean"`` or any other name
            without an ONNX counterpart here, or if ``include_self`` is false.
    """
    if reduce == "mean":
        raise errors.OnnxExporterError(
            "ONNX does not support mean reduction for scatter_reduce"
        )
    if not include_self:
        raise errors.OnnxExporterError(
            "ONNX does not support include_self=False for scatter_reduce"
        )

    # Torch reduction name -> value of the ONNX `reduction` attribute.
    # "mean" is intentionally absent: it is rejected above.
    # NOTE(review): the ONNX changelog appears to introduce the "min"/"max"
    # ScatterElements reductions in opset 18 -- confirm that emitting them from
    # this opset-16 symbolic is intended.
    reduce_mode = {
        "sum": "add",
        "prod": "mul",
        "amin": "min",
        "amax": "max",
    }
    onnx_reduce = reduce_mode.get(reduce)
    if onnx_reduce is None:
        # Report an unknown reduction as an exporter error instead of letting a
        # bare KeyError escape from the table lookup.
        raise errors.OnnxExporterError(
            f"ONNX does not support '{reduce}' reduction for scatter_reduce"
        )

    self_rank = g.op("Size", g.op("Shape", self))

    # if self_rank == 0:  # assert (index_rank == 0 and rank_src == 0)
    self_rank_is_zero = g.op(
        "Equal", self_rank, g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64))
    )
    if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks(
        g, "If", self_rank_is_zero, n_blocks=2, outputs=3
    )

    # Then-branch: reshape all three operands to [-1].
    neg_1 = if_context.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
    self_reshape = if_context.op("Reshape", self, neg_1)
    utils._add_output_to_block(if_context.block, self_reshape)
    index_reshape = if_context.op("Reshape", index, neg_1)
    utils._add_output_to_block(if_context.block, index_reshape)
    src_reshape = if_context.op("Reshape", src, neg_1)
    utils._add_output_to_block(if_context.block, src_reshape)

    # Else-branch: forward the operands untouched, in the same output order.
    self_identity = else_context.op("Identity", self)
    utils._add_output_to_block(else_context.block, self_identity)
    index_identity = else_context.op("Identity", index)
    utils._add_output_to_block(else_context.block, index_identity)
    src_identity = else_context.op("Identity", src)
    utils._add_output_to_block(else_context.block, src_identity)

    # The three If outputs are (self, index, src), in that order.
    result = g.op("ScatterElements", *if_op, axis_i=dim, reduction_s=onnx_reduce)

    # if self_rank == 0: squeeze the result, otherwise pass it through.
    squeeze_if, (squeeze_context, passthrough_context), _ = jit_utils.add_op_with_blocks(
        g, "If", self_rank_is_zero, n_blocks=2, outputs=1
    )
    result_squeezed = squeeze_context.op("Squeeze", result)
    utils._add_output_to_block(squeeze_context.block, result_squeezed)
    result_identity = passthrough_context.op("Identity", result)
    utils._add_output_to_block(passthrough_context.block, result_identity)

    result_final = squeeze_if.node().output()
    return result_final
|