"""This file exports ONNX ops for opset 9. | |
Opset 9 is supported by ONNX release 1.4.1 | |
release on 01/23/19 | |
""" | |
from __future__ import annotations | |
import builtins | |
import functools | |
import math | |
import sys | |
import warnings | |
from typing import Callable, List, Optional, Sequence, Tuple, Union | |
import torch | |
import torch._C._onnx as _C_onnx | |
import torch.nn.modules.utils | |
import torch.onnx | |
from torch import _C | |
# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics | |
from torch.onnx import _constants, _deprecation, _type_utils, errors, symbolic_helper | |
from torch.onnx._globals import GLOBALS | |
from torch.onnx._internal import _beartype, jit_utils, registration | |
from torch.types import Number | |
# EDITING THIS FILE? READ THIS FIRST! | |
# see Note [Edit Symbolic Files] in README.md | |
__all__ = [ | |
"abs", | |
"acos", | |
"add", | |
"addcmul", | |
"addmm", | |
"alias", | |
"amax", | |
"amin", | |
"aminmax", | |
"arange", | |
"argmax", | |
"argmin", | |
"as_strided", | |
"as_tensor", | |
"asin", | |
"atan", | |
"atan2", | |
"baddbmm", | |
"batch_norm", | |
"bernoulli", | |
"bitwise_not", | |
"bitwise_or", | |
"bmm", | |
"broadcast_tensors", | |
"broadcast_to", | |
"bucketize", | |
"cat", | |
"cdist", | |
"ceil", | |
"clamp_max", | |
"clamp_min", | |
"clamp", | |
"clone", | |
"constant_pad_nd", | |
"contiguous", | |
"conv_tbc", | |
"conv_transpose1d", | |
"conv_transpose2d", | |
"conv_transpose3d", | |
"conv1d", | |
"conv2d", | |
"conv3d", | |
"convert_element_type", | |
"convolution", | |
"cos", | |
"cosine_similarity", | |
"cross", | |
"cumsum", | |
"detach", | |
"dim", | |
"div", | |
"dot", | |
"dropout", | |
"elu", | |
"embedding_bag", | |
"embedding", | |
"empty_like", | |
"empty", | |
"eq", | |
"erf", | |
"exp", | |
"expand_as", | |
"expand", | |
"eye", | |
"fill", | |
"flatten", | |
"floor_divide", | |
"floor", | |
"floordiv", | |
"frobenius_norm", | |
"full_like", | |
"full", | |
"gather", | |
"ge", | |
"gelu", | |
"get_pool_ceil_padding", | |
"glu", | |
"group_norm", | |
"gt", | |
"hann_window", | |
"hardshrink", | |
"hardsigmoid", | |
"hardswish", | |
"hardtanh", | |
"index_add", | |
"index_copy", | |
"index_fill", | |
"index_put", | |
"index_select", | |
"index", | |
"instance_norm", | |
"is_floating_point", | |
"is_pinned", | |
"isnan", | |
"item", | |
"kl_div", | |
"layer_norm", | |
"le", | |
"leaky_relu", | |
"lerp", | |
"lift", | |
"linalg_cross", | |
"linalg_matrix_norm", | |
"linalg_norm", | |
"linalg_vector_norm", | |
"linear", | |
"linspace", | |
"log_sigmoid", | |
"log_softmax", | |
"log", | |
"log10", | |
"log1p", | |
"log2", | |
"logical_and", | |
"logical_not", | |
"logical_or", | |
"logical_xor", | |
"logit", | |
"logsumexp", | |
"lstm_cell", | |
"lstm", | |
"lt", | |
"masked_fill", | |
"masked_fill_", | |
"matmul", | |
"max_pool1d_with_indices", | |
"max_pool2d_with_indices", | |
"max_pool3d_with_indices", | |
"max", | |
"maximum", | |
"meshgrid", | |
"min", | |
"minimum", | |
"mish", | |
"mm", | |
"movedim", | |
"mse_loss", | |
"mul", | |
"multinomial", | |
"mv", | |
"narrow", | |
"native_layer_norm", | |
"ne", | |
"neg", | |
"new_empty", | |
"new_full", | |
"new_ones", | |
"new_zeros", | |
"nonzero_numpy", | |
"nonzero", | |
"norm", | |
"numel", | |
"numpy_T", | |
"one_hot", | |
"ones_like", | |
"ones", | |
"onnx_placeholder", | |
"overload_by_arg_count", | |
"pad", | |
"pairwise_distance", | |
"permute", | |
"pixel_shuffle", | |
"pixel_unshuffle", | |
"pow", | |
"prelu", | |
"prim_constant_chunk", | |
"prim_constant_split", | |
"prim_constant", | |
"prim_data", | |
"prim_device", | |
"prim_dtype", | |
"prim_if", | |
"prim_layout", | |
"prim_list_construct", | |
"prim_list_unpack", | |
"prim_loop", | |
"prim_max", | |
"prim_min", | |
"prim_shape", | |
"prim_tolist", | |
"prim_tuple_construct", | |
"prim_type", | |
"prim_unchecked_cast", | |
"prim_uninitialized", | |
"rand_like", | |
"rand", | |
"randint_like", | |
"randint", | |
"randn_like", | |
"randn", | |
"reciprocal", | |
"reflection_pad", | |
"relu", | |
"relu6", | |
"remainder", | |
"repeat_interleave", | |
"repeat", | |
"replication_pad", | |
"reshape_as", | |
"reshape", | |
"roll", | |
"rrelu", | |
"rsqrt", | |
"rsub", | |
"scalar_tensor", | |
"scatter_add", | |
"scatter", | |
"select", | |
"selu", | |
"sigmoid", | |
"sign", | |
"silu", | |
"sin", | |
"size", | |
"slice", | |
"softmax", | |
"softplus", | |
"softshrink", | |
"sort", | |
"split_with_sizes", | |
"split", | |
"sqrt", | |
"square", | |
"squeeze", | |
"stack", | |
"std_mean", | |
"std", | |
"sub", | |
"t", | |
"take", | |
"tan", | |
"tanh", | |
"tanhshrink", | |
"tensor", | |
"threshold", | |
"to", | |
"topk", | |
"transpose", | |
"true_divide", | |
"type_as", | |
"unbind", | |
"unfold", | |
"unsafe_chunk", | |
"unsafe_split_with_sizes", | |
"unsafe_split", | |
"unsqueeze", | |
"unsupported_complex_operators", | |
"noop_complex_operators", | |
"unused", | |
"var_mean", | |
"var", | |
"view_as", | |
"view", | |
"where", | |
"wrap_logical_op_with_cast_to", | |
"wrap_logical_op_with_negation", | |
"zeros_like", | |
"zeros", | |
"zero", | |
] | |
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=9) | |
def _apply_params(*args, **kwargs): | |
"""Returns a decorator that calls the decorated (higher-order) function with the given parameters.""" | |
def _apply(fn): | |
return fn(*args, **kwargs) | |
return _apply | |
def _export(name: str): | |
"""Exports the function in the current global namespace.""" | |
def wrapper(func): | |
globals()[name] = func | |
__all__.append(name) | |
return func | |
return wrapper | |
def unused(g): | |
"""Represents "missing" optional inputs.""" | |
n = g.op("prim::Constant") | |
n.setType(_C.OptionalType.ofTensor()) | |
return n | |
def _shape_as_tensor(g: jit_utils.GraphContext, input): | |
return g.op("Shape", input) | |
def _reshape_from_tensor(g: jit_utils.GraphContext, input, shape): | |
if isinstance(shape, list): | |
shape = g.op("Concat", *shape, axis_i=0) | |
return reshape(g, input, shape) | |
def reshape(g: jit_utils.GraphContext, self, shape): | |
return symbolic_helper._reshape_helper(g, self, shape) | |
def reshape_as(g: jit_utils.GraphContext, self, other): | |
shape = g.op("Shape", other) | |
return reshape(g, self, shape) | |
def add(g: jit_utils.GraphContext, self, other, alpha=None): | |
if symbolic_helper._is_value(self) and symbolic_helper._is_tensor_list(self): | |
return symbolic_helper._onnx_opset_unsupported_detailed( | |
"Add", 9, 11, "Add between list of tensors not supported", self | |
) | |
if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1: | |
other = g.op("Mul", other, alpha) | |
return g.op("Add", self, other) | |
def sub(g: jit_utils.GraphContext, self, other, alpha=None): | |
if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1: | |
other = g.op("Mul", other, alpha) | |
return g.op("Sub", self, other) | |
def rsub(g: jit_utils.GraphContext, self, other, alpha=None): | |
return sub(g, other, self, alpha=alpha) | |
def mul(g: jit_utils.GraphContext, self, other): | |
if symbolic_helper._is_bool(self) and symbolic_helper._is_bool(other): | |
# ONNX Mul doesn't support Boolean, so use And as an equivalent operator. | |
return g.op("And", self, other) | |
else: | |
return g.op("Mul", self, other) | |
def div(g: jit_utils.GraphContext, self, other, *args): | |
if len(args) == 0: | |
return true_divide(g, self, other) | |
else: | |
return _div_rounding_mode(g, self, other, *args) | |
def addcmul(g: jit_utils.GraphContext, self, tensor1, tensor2, value=1.0): | |
value_tens = g.op("Constant", value_t=torch.tensor([value])) | |
return add(g, self, mul(g, mul(g, tensor1, tensor2), value_tens)) | |
def _div_rounding_mode(g: jit_utils.GraphContext, self, other, rounding_mode): | |
if rounding_mode is None: | |
return true_divide(g, self, other) | |
elif rounding_mode == "floor": | |
return _floor_divide(g, self, other) | |
elif rounding_mode == "trunc": | |
return _trunc_divide(g, self, other) | |
else: | |
raise errors.SymbolicValueError( | |
f'Unsupported rounding mode: "{rounding_mode}". Expected None, "floor" or "trunc"', | |
self, | |
) | |
def _trunc_divide(g: jit_utils.GraphContext, self, other): | |
out = g.op("Div", self, other) | |
    # The correct operation is truncation, which is not supported in ONNX;
    # we cannot call Floor since it behaves differently for negative numbers
    # (e.g. -0.1 should become -0).
    # - If scalar_type information is not available, assume that
    #   the output should be treated as float.
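    # Illustrative example (added, not from the original comments): for
    # self = -7.0 and other = 2.0, Div yields -3.5 and the Cast to INT64
    # truncates toward zero, giving -3, which matches
    # torch.div(..., rounding_mode="trunc"); a Floor would have produced -4.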
out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.INT64) | |
# Matching PyTorch's behavior: | |
# - if self is fp the output's type is self's type | |
# - if self is not fp and other is fp, the output is of type JitScalarType.FLOAT | |
# - self is not fp and other is not fp, the output's type is self's output type | |
# - the output type defaults to Float | |
scalar_type = _type_utils.JitScalarType.from_value( | |
self, _type_utils.JitScalarType.UNDEFINED | |
) | |
if scalar_type != _type_utils.JitScalarType.UNDEFINED: | |
if not symbolic_helper._is_fp(self) and symbolic_helper._is_fp(other): | |
out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT) | |
else: | |
out = g.op( | |
"Cast", | |
out, | |
to_i=scalar_type.onnx_type(), | |
) | |
else: | |
out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT) | |
return out | |
def _floor_divide(g: jit_utils.GraphContext, self, other): | |
if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other): | |
out = true_divide(g, self, other) | |
return g.op("Floor", out) | |
else: | |
        # Integer division does truncation rounding
div = g.op("Div", self, other) | |
# Division is negative if: self < 0 != other < 0 | |
zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64)) | |
negative = g.op( | |
"Xor", | |
symbolic_helper._lt_helper(g, self, zero), | |
symbolic_helper._lt_helper(g, other, zero), | |
) | |
# For negative numbers with self % other != 0, subtract 1 to round down instead of up | |
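        # Worked example (illustrative, assumed values): for self = -7 and
        # other = 2, Div gives -3 (truncation), the signs differ and
        # mod = -7 - (-3 * 2) = -1 != 0, so the fixup subtracts 1 and the
        # result is -4, matching Python/PyTorch floor division.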
mod = g.op("Sub", self, g.op("Mul", div, other)) | |
fixup_mask = g.op("And", negative, g.op("Not", g.op("Equal", mod, zero))) | |
one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) | |
fixup = g.op("Mul", fixup_mask, one) | |
return g.op("Sub", div, fixup) | |
def floor_divide(g: jit_utils.GraphContext, self, other): | |
# Deprecated behavior, floor_divide actually truncates | |
return _trunc_divide(g, self, other) | |
def floordiv(g: jit_utils.GraphContext, self, other): | |
return floor_divide(g, self, other) | |
def true_divide(g: jit_utils.GraphContext, self, other): | |
"""Division where both inputs are cast to floating types | |
If both inputs are floating, performs div as usual | |
If only one input is a floating type, the other input is cast to its type | |
If neither input is a floating type, both inputs are cast to the default scalar type | |
""" | |
# Case 1: either values are floating | |
# Performs div as usual. | |
# Implicit casting will be handled in scalar type analysis pass. | |
if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other): | |
return g.op("Div", self, other) | |
# Case 2: neither is floating | |
# Casts both inputs to the default scalar type | |
scalar_type = torch.get_default_dtype() | |
onnx_scalar_type = _C_onnx.TensorProtoDataType.FLOAT | |
assert scalar_type is torch.float or scalar_type is torch.double | |
if torch.get_default_dtype() is torch.double: | |
onnx_scalar_type = _C_onnx.TensorProtoDataType.DOUBLE | |
self = g.op("Cast", self, to_i=onnx_scalar_type) | |
other = g.op("Cast", other, to_i=onnx_scalar_type) | |
return g.op("Div", self, other) | |
def reciprocal(g: jit_utils.GraphContext, self): | |
# torch.reciprocal implicitly casts to float, so we do the same. | |
if not symbolic_helper._is_fp(self): | |
self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT) | |
return g.op("Reciprocal", self) | |
def cat(g: jit_utils.GraphContext, tensor_list, dim): | |
tensors = symbolic_helper._unpack_list(tensor_list) | |
    # torch.cat ignores empty tensors such as `torch.Tensor([])`.
    # These need to be removed as inputs to ONNX's Concat too, otherwise shape
    # inference will likely fail due to inputs with different ranks
    # (0 for an empty tensor, > 0 for anything else).
nonempty_tensors = [] | |
for t in tensors: | |
if symbolic_helper._is_constant(t) and not symbolic_helper._get_tensor_dim_size( | |
t, 0 | |
): | |
continue | |
nonempty_tensors.append(t) | |
assert len(nonempty_tensors) > 0 | |
assert all( | |
symbolic_helper._get_tensor_rank(nonempty_tensors[0]) is None | |
or symbolic_helper._get_tensor_rank(t) is None | |
or symbolic_helper._get_tensor_rank(t) | |
== symbolic_helper._get_tensor_rank(nonempty_tensors[0]) | |
for t in nonempty_tensors | |
) | |
tensor_list.node().removeAllInputs() | |
for t in nonempty_tensors: | |
tensor_list.node().addInput(t) | |
tensors = symbolic_helper._unpack_list(tensor_list) | |
return g.op("Concat", *tensors, axis_i=dim) | |
def stack(g: jit_utils.GraphContext, tensor_list, dim): | |
unsqueezed = [ | |
symbolic_helper._unsqueeze_helper(g, t, [dim]) | |
for t in symbolic_helper._unpack_list(tensor_list) | |
] | |
return g.op("Concat", *unsqueezed, axis_i=dim) | |
def _list(g: jit_utils.GraphContext, self): | |
return self | |
def mm(g: jit_utils.GraphContext, self, other): | |
    # Create a dummy C tensor. Only needed for API purposes; the value is
    # not used since beta = 0.
C = g.op("Constant", value_t=torch.tensor([1])) | |
return g.op("Gemm", self, other, C, beta_f=0.0, alpha_f=1.0) | |
def bmm(g: jit_utils.GraphContext, self, other): | |
return g.op("MatMul", self, other) | |
def matmul(g: jit_utils.GraphContext, self, other): | |
return g.op("MatMul", self, other) | |
def addmm(g: jit_utils.GraphContext, self, mat1, mat2, beta, alpha): | |
scalar_type = None | |
self_scalar_type = symbolic_helper._try_get_scalar_type(self) | |
mat1_scalar_type = symbolic_helper._try_get_scalar_type(mat1) | |
mat2_scalar_type = symbolic_helper._try_get_scalar_type(mat2) | |
if self_scalar_type is not None: | |
scalar_type = self_scalar_type | |
elif mat1_scalar_type is not None: | |
scalar_type = mat1_scalar_type | |
elif mat2_scalar_type is not None: | |
scalar_type = mat2_scalar_type | |
mat1_rank = symbolic_helper._get_tensor_rank(mat1) | |
mat2_rank = symbolic_helper._get_tensor_rank(mat2) | |
def is_not_none_nor(v, u): | |
return v is not None and v != u | |
if scalar_type is not None and ( | |
is_not_none_nor(mat1_rank, 2) or is_not_none_nor(mat2_rank, 2) | |
): | |
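        # Note (added): ONNX Gemm only accepts 2-D matrices, so when either input
        # has a different (known) rank we fall back to an explicit
        # alpha * (mat1 @ mat2) + beta * self built from MatMul/Mul/Add,
        # which is equivalent to the aten::addmm definition.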
res1 = g.op("MatMul", mat1, mat2) | |
res2 = self | |
alpha = symbolic_helper._scalar(alpha) | |
beta = symbolic_helper._scalar(beta) | |
if alpha != 1: | |
alpha = g.op( | |
"Constant", value_t=torch.tensor(alpha, dtype=scalar_type.dtype()) | |
) | |
res1 = g.op("Mul", res1, alpha) | |
if beta != 1: | |
beta = g.op( | |
"Constant", | |
value_t=torch.tensor( | |
symbolic_helper._scalar(beta), dtype=scalar_type.dtype() | |
), | |
) | |
res2 = g.op("Mul", res2, beta) | |
return g.op("Add", res1, res2) | |
return g.op( | |
"Gemm", | |
mat1, | |
mat2, | |
self, | |
beta_f=symbolic_helper._scalar(beta), | |
alpha_f=symbolic_helper._scalar(alpha), | |
) | |
def neg(g: jit_utils.GraphContext, self): | |
return g.op("Neg", self) | |
def sqrt(g: jit_utils.GraphContext, self): | |
if _type_utils.JitScalarType.from_value( | |
self, _type_utils.JitScalarType.UNDEFINED | |
) in { | |
_type_utils.JitScalarType.UINT8, | |
_type_utils.JitScalarType.INT8, | |
_type_utils.JitScalarType.INT16, | |
_type_utils.JitScalarType.INT, | |
_type_utils.JitScalarType.INT64, | |
}: | |
        # torch converts integer inputs of sqrt to float
self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT) | |
return g.op("Sqrt", self) | |
def rsqrt(g: jit_utils.GraphContext, self): | |
return g.op( | |
"Div", symbolic_helper._if_scalar_type_as(torch.ones(1), self), sqrt(g, self) | |
) | |
# Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qtanh.cpp | |
def tanh(g: jit_utils.GraphContext, self): | |
return g.op("Tanh", self) | |
def sin(g: jit_utils.GraphContext, self): | |
return g.op("Sin", self) | |
def cos(g: jit_utils.GraphContext, self): | |
return g.op("Cos", self) | |
def tan(g: jit_utils.GraphContext, self): | |
return g.op("Tan", self) | |
def asin(g: jit_utils.GraphContext, self): | |
return g.op("Asin", self) | |
def acos(g: jit_utils.GraphContext, self): | |
return g.op("Acos", self) | |
def atan(g: jit_utils.GraphContext, self): | |
return g.op("Atan", self) | |
def atan2(g: jit_utils.GraphContext, self, other): | |
    # self is y and other is x in the coordinate plane
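    # Worked example (illustrative, not from the original comments): for
    # self = 1 and other = -1, atan(-1) = -pi/4; since other < 0 the quadrant
    # correction applies, and because self > 0 we add pi, giving 3*pi/4,
    # which matches torch.atan2(1, -1).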
slope = g.op("Div", self, other) | |
atan = g.op("Atan", slope) | |
const_zero = g.op("Constant", value_t=torch.tensor(0)) | |
const_pi = g.op("Constant", value_t=torch.tensor(math.pi)) | |
condition_second_or_third_quadrant = g.op("Greater", self, const_zero) | |
second_third_quadrant = g.op( | |
"Where", | |
condition_second_or_third_quadrant, | |
g.op("Add", atan, const_pi), | |
g.op("Sub", atan, const_pi), | |
) | |
condition_14_or_23_quadrant = g.op("Less", other, const_zero) | |
result = g.op("Where", condition_14_or_23_quadrant, second_third_quadrant, atan) | |
return result | |
# Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qsigmoid.cpp | |
def sigmoid(g: jit_utils.GraphContext, self): | |
return g.op("Sigmoid", self) | |
def sign(g: jit_utils.GraphContext, self): | |
return g.op("Sign", self) | |
def _slice(g: jit_utils.GraphContext, input, axes, starts, ends): | |
assert len(starts) == len(ends) | |
if len(starts) == 1 and starts[0] == 0 and ends[0] == _constants.INT64_MAX: | |
return input | |
return g.op("Slice", input, axes_i=axes, starts_i=starts, ends_i=ends) | |
def _maybe_cast_reduce_op_input(g: jit_utils.GraphContext, self): | |
scalar_type = _type_utils.JitScalarType.from_value( | |
self, _type_utils.JitScalarType.UNDEFINED | |
) | |
if scalar_type != _type_utils.JitScalarType.UNDEFINED: | |
# This check only covers traced modules where dtype is present | |
# pytorch reduce-ops cast all other integral types to int64 | |
if ( | |
not symbolic_helper._is_fp(self) | |
and scalar_type != _type_utils.JitScalarType.INT64 | |
): | |
self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.INT64) | |
return self | |
def _reduce_op_symbolic(onnx_op_name, allow_multi_dim_support=True): | |
def symbolic(g, self, dim=None, keepdim=None): | |
self = _maybe_cast_reduce_op_input(g, self) | |
if dim is None or dim == tuple(): | |
# Dim can be 0, which will cause (not dim) == True. So we don't want to do | |
# (not dim) | |
# all-reduce path | |
return symbolic_helper._handle_reduce_dim_none(g, self, onnx_op_name) | |
else: | |
# dim-reduce path | |
desc = "is" if allow_multi_dim_support else "i" | |
dim, keepdim = symbolic_helper._get_const( | |
dim, desc, "dim" | |
), symbolic_helper._get_const(keepdim, "i", "keepdim") | |
dim_list = dim if allow_multi_dim_support else [dim] | |
return g.op(onnx_op_name, self, axes_i=dim_list, keepdims_i=keepdim) | |
return symbolic | |
def overload_by_arg_count(fn): | |
def wrapper(g, *args): | |
overloads = fn(g, *args) | |
for overload in overloads: | |
arg_descriptors = overload._arg_descriptors | |
if len(arg_descriptors) == len(args): | |
return overload(g, *args) | |
return symbolic_helper._unimplemented( | |
f"aten::{fn.__name__}", f"with {len(args)} arguments" | |
) | |
return wrapper | |
# torch.prod does not support multidimensional "dim" | |
def _reduce_with_dtype(onnx_op: str, name: str, allow_multi_dim_support: bool = True): | |
symbolic = _reduce_op_symbolic( | |
onnx_op, allow_multi_dim_support=allow_multi_dim_support | |
) | |
def reduce(g, *args, **kwargs): | |
def reduce_nodim(g, self, dtype): | |
dtype_onnx = None | |
if dtype.node().kind() == "onnx::Constant": | |
dtype = symbolic_helper._get_const(dtype, "i", "dtype") | |
dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type() | |
self = g.op("Cast", self, to_i=dtype_onnx) | |
elif dtype.node().kind() != "prim::Constant": | |
return symbolic_helper._unimplemented(name, "dtype", dtype) | |
result = symbolic(g, self) | |
if dtype_onnx is not None: | |
result_dtype_onnx = _type_utils.JitScalarType.from_value( | |
result | |
).onnx_type() | |
if result_dtype_onnx != dtype_onnx: | |
result = g.op("Cast", result, to_i=dtype_onnx) | |
return result | |
dim_desc = "is" if allow_multi_dim_support else "i" | |
# type: ignore[arg-type] | |
def reduce_dim(g, self, dim, keepdim, dtype): | |
dtype_onnx = None | |
if dtype.node().kind() == "onnx::Constant": | |
dtype = symbolic_helper._get_const(dtype, "i", "dtype") | |
dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type() | |
self = g.op("Cast", self, to_i=dtype_onnx) | |
elif dtype.node().kind() != "prim::Constant": | |
return symbolic_helper._unimplemented(name, "dtype", dtype) | |
result = symbolic(g, self, dim, keepdim) | |
if dtype_onnx is not None: | |
result_dtype_onnx = _type_utils.JitScalarType.from_value( | |
result | |
).onnx_type() | |
if result_dtype_onnx != dtype_onnx: | |
result = g.op("Cast", result, to_i=dtype_onnx) | |
return result | |
return reduce_nodim, reduce_dim | |
return reduce | |
def cumsum(g: jit_utils.GraphContext, input, dim, dtype): | |
if symbolic_helper.is_caffe2_aten_fallback(): | |
if dtype.node().kind() != "prim::Constant": | |
return symbolic_helper._unimplemented("cumsum", "dtype", dtype) | |
return g.at("cumsum", input, dim_i=dim) | |
symbolic_helper._onnx_opset_unsupported("cumsum", 9, 11, input) | |
def _sample_dirichlet(g: jit_utils.GraphContext, self, generator): | |
if symbolic_helper.is_caffe2_aten_fallback(): | |
if not symbolic_helper._is_none(generator): | |
return symbolic_helper._unimplemented( | |
"_sample_dirichlet", "We are not able to export generator", self | |
) | |
return g.at("_sample_dirichlet", self) | |
return symbolic_helper._onnx_unsupported("_sample_dirichlet", self) | |
def _standard_gamma(g: jit_utils.GraphContext, self, generator): | |
if symbolic_helper.is_caffe2_aten_fallback(): | |
if not symbolic_helper._is_none(generator): | |
return symbolic_helper._unimplemented( | |
"_standard_gamma", "not able to export generator", self | |
) | |
return g.at("_standard_gamma", self) | |
return symbolic_helper._onnx_unsupported("_standard_gamma", self) | |
def t(g: jit_utils.GraphContext, self): | |
rank = symbolic_helper._get_tensor_rank(self) | |
if rank is None or rank < 2: | |
# The transpose of a 1d or 0d tensor is itself. ONNX does not define the behavior | |
# clearly and onnxruntime fails on these cases. So we add an Identity node to | |
# mirror the behavior of eager mode. | |
return g.op("Identity", self) | |
return g.op("Transpose", self, perm_i=(1, 0)) | |
def numpy_T(g: jit_utils.GraphContext, input): | |
ndim = symbolic_helper._get_tensor_rank(input) | |
assert ndim is not None | |
perm = list(reversed(range(0, ndim))) | |
return g.op("Transpose", input, perm_i=perm) | |
def expand(g: jit_utils.GraphContext, self, size, implicit): | |
size = symbolic_helper._maybe_get_const(size, "is") | |
if not symbolic_helper._is_value(size): | |
size = g.op("Constant", value_t=torch.LongTensor(size)) | |
elif symbolic_helper._is_packed_list(size): | |
# Expand with -1 dim value means dim is unchanged. | |
# Since onnx::expand supports two-way broadcasting, | |
# -1 dim value can be exported to onnx as 1 | |
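        # Illustrative example (added, assumed shapes): for size = [-1, 4] the
        # -1 entries below become 1, so a (3, 1) input expands to (3, 4)
        # because onnx::Expand broadcasts size-1 dimensions.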
size = symbolic_helper._reshape_helper( | |
g, stack(g, size, 0), g.op("Constant", value_t=torch.tensor([-1])) | |
) | |
dtype = _type_utils.JitScalarType.INT64 | |
ones = ones_like(g, size, dtype) | |
neg_ones = mul(g, ones, g.op("Constant", value_t=torch.tensor(-1))) | |
size = where(g, g.op("Equal", size, neg_ones), ones, size) | |
return g.op("Expand", self, size) | |
def broadcast_to(g: jit_utils.GraphContext, self, size): | |
size = symbolic_helper._maybe_get_const(size, "is") | |
if not symbolic_helper._is_value(size): | |
size = g.op("Constant", value_t=torch.LongTensor(size)) | |
elif symbolic_helper._is_packed_list(size): | |
# Expand with -1 dim value means dim is unchanged. | |
# Since onnx::expand supports two-way broadcasting, | |
# -1 dim value can be exported to onnx as 1 | |
size = symbolic_helper._reshape_helper( | |
g, stack(g, size, 0), g.op("Constant", value_t=torch.tensor([-1])) | |
) | |
dtype = _type_utils.JitScalarType.INT64 | |
ones = ones_like(g, size, dtype) | |
neg_ones = mul(g, ones, g.op("Constant", value_t=torch.tensor(-1))) | |
size = where(g, g.op("Equal", size, neg_ones), ones, size) | |
return g.op("Expand", self, size) | |
def expand_as(g: jit_utils.GraphContext, self, other): | |
self_t = symbolic_helper._maybe_get_const(self, "t") | |
if isinstance(self_t, torch.Tensor): | |
orig_type = self_t.dtype | |
self_t = self_t.to(torch.double) | |
dims = [] | |
for d in range(self_t.dim()): | |
if torch.equal(self_t.mean(d).unsqueeze(d).expand_as(self_t), self_t): | |
dims.append(d) | |
self = g.op( | |
"Constant", value_t=self_t.mean(dims, keepdim=True).to(orig_type) | |
) | |
shape = g.op("Shape", other) | |
return g.op("Expand", self, shape) | |
def embedding( | |
g: jit_utils.GraphContext, | |
weight, | |
indices, | |
padding_idx, | |
scale_grad_by_freq, | |
sparse, | |
): | |
if scale_grad_by_freq and GLOBALS.export_training: | |
raise errors.SymbolicValueError( | |
"Unsupported: ONNX export of embedding with scale_grad_by_freq=True " | |
"for training mode. ONNX does not support scaling the gradients.", | |
weight, | |
) | |
if padding_idx >= 0 and GLOBALS.export_training: | |
warnings.warn( | |
"Warning: ONNX export of embedding with padding_idx >= 0 " | |
"for training mode. " | |
"ONNX does not support not updating the embedding vector at padding_idx during training." | |
) | |
return g.op("Gather", weight, indices) | |
def embedding_bag( | |
g: jit_utils.GraphContext, | |
embedding_matrix, | |
indices, | |
offsets, | |
scale_grad_by_freq, | |
mode, | |
sparse, | |
per_sample_weights, | |
include_last_offset, | |
padding_idx, | |
): | |
if not symbolic_helper._is_none(per_sample_weights): | |
return symbolic_helper._onnx_unsupported( | |
"embedding_bag with per_sample_weights" | |
) | |
if symbolic_helper.is_caffe2_aten_fallback(): | |
return g.at( | |
"embedding_bag", | |
embedding_matrix, | |
indices, | |
offsets, | |
outputs=4, | |
scale_grad_by_freq_i=scale_grad_by_freq, | |
mode_i=mode, | |
sparse_i=sparse, | |
include_last_offset_i=include_last_offset, | |
padding_idx_i=padding_idx, | |
) | |
return symbolic_helper._onnx_unsupported("embedding_bag", embedding_matrix) | |
def size(g: jit_utils.GraphContext, self, dim=None): | |
if dim is None: | |
return g.op("Shape", self) | |
if symbolic_helper._maybe_get_const(dim, "i") < 0: | |
rank = symbolic_helper._get_tensor_rank(self) | |
if rank is not None: | |
dim = symbolic_helper._maybe_get_const(dim, "i") + rank | |
dim = g.op("Constant", value_t=torch.tensor(dim)) | |
return symbolic_helper._size_helper(g, self, dim) | |
def transpose(g: jit_utils.GraphContext, self, dim0, dim1): | |
if dim0 == dim1: # micro-optimization | |
return self | |
# NB: Transpose in ONNX is actually a Permute | |
rank = symbolic_helper._get_tensor_rank(self) | |
if rank is not None: | |
axes = list(range(rank)) | |
axes[dim0], axes[dim1] = axes[dim1], axes[dim0] | |
return g.op("Transpose", self, perm_i=axes) | |
elif symbolic_helper.is_caffe2_aten_fallback(): | |
# if we don't have dim information we cannot | |
# output a permute so use ATen instead | |
return g.at("transpose", self, overload_name="int", dim0_i=dim0, dim1_i=dim1) | |
else: | |
raise errors.SymbolicValueError( | |
"Unsupported: ONNX export of transpose for tensor of unknown rank.", | |
self, | |
) | |
def permute(g: jit_utils.GraphContext, self, dims): | |
if dims == list(range(0, len(dims))): | |
return self | |
return g.op("Transpose", self, perm_i=dims) | |
def view(g: jit_utils.GraphContext, self, size): | |
return reshape(g, self, size) | |
def view_as(g: jit_utils.GraphContext, self, other): | |
shape = g.op("Shape", other) | |
return reshape(g, self, shape) | |
def unsafe_chunk(g: jit_utils.GraphContext, self, chunks, dim, _outputs=None): | |
if _outputs is None: | |
return symbolic_helper._onnx_opset_unsupported_detailed( | |
"unsafe_chunk", 9, 11, "Dynamic number of outputs not supported", self | |
) | |
size = symbolic_helper._get_tensor_dim_size(self, dim) | |
if size is None: | |
return symbolic_helper._unimplemented( | |
"unsafe_chunk", "unknown dimension size", self | |
) | |
split_size = (size + chunks - 1) // chunks | |
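    # Worked example (illustrative): size = 7 and chunks = 3 give split_size = 3,
    # splits = [3, 3] plus leftover = 1, so the Split below emits sizes [3, 3, 1],
    # matching torch.chunk semantics.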
splits = [split_size] * (size // split_size) | |
leftover = size % split_size | |
if leftover: | |
splits.append(leftover) | |
return g.op("Split", self, split_i=splits, axis_i=dim, outputs=_outputs) | |
def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None): | |
if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs): | |
return symbolic_helper._onnx_opset_unsupported_detailed( | |
"split", 9, 11, "Dynamic number of outputs not supported", self | |
) | |
split_val = symbolic_helper._node_get(split_size_or_sizes.node(), "value") | |
if split_val.dim() > 0: | |
return split_with_sizes(g, self, split_size_or_sizes, dim, _outputs) | |
split_size = symbolic_helper._get_const(split_size_or_sizes, "i", "split_size") | |
size = symbolic_helper._get_tensor_dim_size(self, dim) | |
if size is None: | |
if _outputs is not None: | |
size = split_size * _outputs | |
else: | |
return symbolic_helper._onnx_opset_unsupported_detailed( | |
"split", 9, 11, "Unknown dimension size not supported", self | |
) | |
splits = [split_size] * (size // split_size) | |
leftover = size % split_size | |
if leftover: | |
splits.append(leftover) | |
return g.op("Split", self, split_i=splits, axis_i=dim, outputs=_outputs) | |
def unsafe_split( | |
g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None | |
): | |
return split(g, self, split_size_or_sizes, dim, _outputs) | |
def split_with_sizes(g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None): | |
if not symbolic_helper._is_split_static(split_sizes, _outputs): | |
return symbolic_helper._onnx_opset_unsupported_detailed( | |
"split_with_sizes", 9, 11, "Dynamic number of outputs not supported", self | |
) | |
return g.op("Split", self, split_i=split_sizes, axis_i=dim, outputs=_outputs) | |
def unsafe_split_with_sizes( | |
g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None | |
): | |
return split_with_sizes(g, self, split_sizes, dim, _outputs) | |
def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None): | |
if _outputs is None: | |
return symbolic_helper._onnx_opset_unsupported_detailed( | |
"unbind", 9, 11, "Dynamic number of outputs not supported", self | |
) | |
outputs = g.op("Split", self, split_i=[1] * _outputs, axis_i=dim, outputs=_outputs) | |
outputs = [outputs] if _outputs == 1 else outputs | |
squeezed_outputs = [ | |
symbolic_helper._squeeze_helper(g, out, [dim]) for out in outputs | |
] | |
return squeezed_outputs | |
def select(g: jit_utils.GraphContext, self, dim, index): | |
index = symbolic_helper._maybe_get_scalar(index) | |
if (not symbolic_helper._is_value(index)) and (index < 0): | |
if index == -1: | |
end_index = _constants.INT64_MAX | |
else: | |
end_index = index + 1 | |
slice_node = symbolic_helper._slice_helper( | |
g, self, axes=[dim], starts=[index], ends=[end_index] | |
) | |
return symbolic_helper._squeeze_helper(g, slice_node, [dim]) | |
else: | |
# FIXME(justinchuby): can index be an int and not a value? | |
return g.op("Gather", self, index, axis_i=dim) | |
def square(g: jit_utils.GraphContext, self): | |
return g.op("Mul", self, self) | |
def squeeze(g: jit_utils.GraphContext, self, dim=None): | |
if dim is None: | |
return g.op("Squeeze", self) | |
squeeze_dim = symbolic_helper._get_const(dim, "i", "dim") | |
# Handle negative dims | |
if squeeze_dim < 0: | |
rank = symbolic_helper._get_tensor_rank(self) | |
if rank is not None: | |
warnings.warn( | |
"ONNX export squeeze with negative axis " | |
+ str(squeeze_dim) | |
+ " might cause the onnx model to be incorrect. " | |
+ "Negative axis is not supported in ONNX. " | |
+ "Axis is converted to " | |
+ str(squeeze_dim + rank) | |
+ " based on input shape at export time. " | |
+ "Passing an tensor of different rank in execution will be incorrect." | |
) | |
squeeze_dim += rank | |
else: | |
return symbolic_helper._unimplemented( | |
"squeeze", "negative axis with unknown input rank", self | |
) | |
dim_size = symbolic_helper._get_tensor_dim_size(self, squeeze_dim) | |
if dim_size is None: | |
warnings.warn( | |
"This model contains a squeeze operation on dimension " | |
+ str(squeeze_dim) | |
+ " on an input " | |
+ "with unknown shape. Note that if the size of dimension " | |
+ str(squeeze_dim) | |
+ " of the input " | |
+ "is not 1, the ONNX model will return an error. Opset version 11 supports squeezing on " | |
+ "non-singleton dimensions, it is recommended to export this model using opset " | |
+ "version 11 or higher." | |
) | |
return symbolic_helper._squeeze_helper(g, self, axes_i=[squeeze_dim]) | |
if dim_size > 1: | |
warnings.warn( | |
"This model contains a squeeze operation on dimension " | |
+ str(squeeze_dim) | |
+ ". The size of " | |
+ "this dimension in the given input is " | |
+ str(dim_size) | |
+ ". The model will " | |
+ "be exported without the squeeze node. If the model is intended to be used with dynamic " | |
+ "input shapes, please use opset version 11 to " | |
+ "export the model." | |
) | |
return self | |
warnings.warn( | |
"This model contains a squeeze operation on dimension " | |
+ str(squeeze_dim) | |
+ ". If the model is " | |
+ "intended to be used with dynamic input shapes, please use opset version 11 to export the model." | |
) | |
return symbolic_helper._squeeze_helper(g, self, axes_i=[squeeze_dim]) | |
def prelu(g: jit_utils.GraphContext, self, weight): | |
self_rank = symbolic_helper._get_tensor_rank(self) | |
weight_sizes = symbolic_helper._get_tensor_sizes(weight) | |
weight_rank = len(weight_sizes) | |
if self_rank is not None: | |
if self_rank > 2: | |
# make weight unidirectional broadcastable | |
weight = symbolic_helper._unsqueeze_helper( | |
g, weight, list(range(1, self_rank - 1)) | |
) | |
elif self_rank == 0 and weight_sizes == [1]: | |
# self and weight are both scalar but weight has rank == 1, squeeze weight. | |
weight = symbolic_helper._squeeze_helper(g, weight, [0]) | |
weight_rank = 0 | |
if self_rank is not None and weight_rank is not None: | |
assert ( | |
self_rank >= weight_rank | |
), f"rank(x) should be >= rank(slope) but got {self_rank} < {weight_rank}" | |
return g.op("PRelu", self, weight) | |
def silu(g: jit_utils.GraphContext, input): | |
return g.op("Mul", input, g.op("Sigmoid", input)) | |
def mish(g: jit_utils.GraphContext, input): | |
return g.op("Mul", input, g.op("Tanh", g.op("Softplus", input))) | |
def _op_with_optional_float_cast(g: jit_utils.GraphContext, op_name, *args, **kwargs): | |
"""Some PyTorch operators (e.g., Clip/Min/ReLU/Pad) are super set of ONNX in terms of data types. | |
This function maximizes the exportability of PyTorch-ONNX by allowing ONNX-unsupported PyTorch | |
operator data type. For example, `Cast<int>(Clip<float>(Cast<float>(INPUT)))` can be used to mimic | |
`Clip<int>(INPUT)` (opset version < 12). | |
Args: | |
g (torch._C.Graph): graph to write the ONNX representation into. | |
op_name (str): operator name in ONNX. | |
*args (tuple): operands to the operator. | |
**kwargs (dict): attributes to the operator along with "opset_before" (optional, None by default) | |
indicating the smallest opset version to trigger such casting behavior and "target_float_t" | |
(optional, torch.onnx.JitScalarType.FLOAT by default) indicating the data type of internal operator. | |
Returns: | |
Optional[torch._C.Value, Tuple[torch._C.Value, ...]]: output(s) of the operator. | |
""" | |
opset_before = kwargs.pop("opset_before", None) | |
target_float_t = kwargs.pop("target_float_t", _type_utils.JitScalarType.FLOAT) | |
inputs = list(args) | |
dtype_0 = _type_utils.JitScalarType.from_value(inputs[0]) | |
require_cast = not symbolic_helper._is_fp(inputs[0]) and ( | |
opset_before is None or GLOBALS.export_onnx_opset_version < opset_before | |
) | |
if require_cast: | |
for input in inputs: | |
if input.isCompleteTensor(): | |
input_scalar_type = _type_utils.JitScalarType.from_value(input) | |
if input_scalar_type != dtype_0: | |
raise errors.SymbolicValueError( | |
f"Inputs of {op_name} must have same dtype." | |
f"Got {dtype_0.scalar_name()} and {input_scalar_type.scalar_name()}", | |
input, | |
) | |
for i, input in enumerate(inputs): | |
if input.isCompleteTensor() and not symbolic_helper._is_fp(input): | |
inputs[i] = g.op( | |
"Cast", | |
input, | |
to_i=target_float_t.onnx_type(), | |
) | |
self = g.op(op_name, *inputs, **kwargs) | |
if require_cast: | |
self = g.op("Cast", self, to_i=dtype_0.onnx_type()) | |
return self | |
def relu(g: jit_utils.GraphContext, input): | |
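    # Note (added): onnx::Relu accepts only floating-point inputs before opset 14
    # (integer types were added in Relu-14), which is presumably why the optional
    # cast is requested with opset_before=14.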
return _op_with_optional_float_cast(g, "Relu", input, opset_before=14) | |
def relu6(g: jit_utils.GraphContext, input): | |
return clamp(g, input, 0, 6) | |
def ceil(g: jit_utils.GraphContext, input): | |
return g.op("Ceil", input) | |
def floor(g: jit_utils.GraphContext, input): | |
return g.op("Floor", input) | |
def _len(g: jit_utils.GraphContext, self): | |
sz_0 = size(g, self, g.op("Constant", value_t=torch.LongTensor([0]))) | |
return symbolic_helper._squeeze_helper(g, sz_0, [0]) | |
def threshold(g: jit_utils.GraphContext, self, threshold, value): | |
# See Note [Export inplace] | |
if symbolic_helper._scalar(threshold) != 0: | |
return symbolic_helper._unimplemented("threshold", "non-zero threshold", self) | |
if symbolic_helper._scalar(value) != 0: | |
return symbolic_helper._unimplemented("threshold", "non-zero value", self) | |
return g.op("Relu", self) | |
def leaky_relu( | |
g: jit_utils.GraphContext, | |
input: _C.Value, | |
negative_slope: float, | |
inplace: bool = False, | |
): | |
# See Note [Export inplace] | |
return g.op("LeakyRelu", input, alpha_f=negative_slope) | |
def glu(g: jit_utils.GraphContext, input, dim): | |
dim_size = symbolic_helper._get_tensor_dim_size(input, dim) | |
if dim_size is not None: | |
assert dim_size % 2 == 0 | |
first, second = g.op("Split", input, axis_i=dim, outputs=2) | |
return g.op("Mul", first, g.op("Sigmoid", second)) | |
def softmax(g: jit_utils.GraphContext, input, dim, dtype=None): | |
# Softmax does normalization at vector level. | |
# PyTorch and ONNX use different strategies to split the input tensor into vectors. | |
# Thus dim and axis have different meanings. | |
# PyTorch slices the input tensor into vectors along the `dim`-th dimension. | |
# ONNX reshapes the input into a 2-D tensor, and `axis` indicates where the input is coerced. | |
# If input is a 2 x 3 tensor: | |
# input = [[1.0, 1.0, 1.0], | |
    #          [1.0, 1.0, 1.0]]
# with dim = 0, the result is: | |
# result = [[0.5, 0.5, 0.5], | |
# [0.5, 0.5, 0.5]] | |
# with axis = 0, the result is: | |
# result = [[0.167, 0.167, 0.167], | |
# [0.167, 0.167, 0.167]] | |
# So only when dim and axis both equal to ndim - 1 (the last dimension), | |
# their semantics are equivalent. | |
# So use softmax when dim and axis both equal to ndim - 1, | |
# otherwise transpose the input to put the vectors to be normalized to the last dimension. | |
# When input rank is not known at export time we compute softmax using a subgraph | |
# with other operators | |
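    # Illustrative example (added, not from the original comments): for a rank-3
    # input with dim = 0, the permutation below is [2, 1, 0], Softmax then runs
    # with axis = 2 (the last dimension), and applying the same permutation again
    # restores the original layout since a swap is its own inverse.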
input_dim = symbolic_helper._get_tensor_rank(input) | |
if input_dim is not None: | |
# TODO: remove this as onnx opset 11 spec allows negative axes | |
if dim < 0: | |
dim = input_dim + dim | |
is_transpose_required = input_dim != dim + 1 | |
if is_transpose_required: | |
axes = list(range(input_dim)) | |
axes[dim], axes[-1] = axes[-1], axes[dim] | |
input = g.op("Transpose", input, perm_i=axes) | |
dim = input_dim - 1 | |
softmax = g.op("Softmax", input, axis_i=dim) | |
if dtype and dtype.node().kind() != "prim::Constant": | |
parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype") | |
softmax = g.op( | |
"Cast", | |
softmax, | |
to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type(), | |
) | |
if is_transpose_required: | |
softmax = g.op("Transpose", softmax, perm_i=axes) # type: ignore[possibly-undefined] | |
return softmax | |
# Apply max normalization. | |
input = g.op("Sub", input, g.op("ReduceMax", input, axes_i=[dim], keepdims_i=1)) | |
exp = g.op("Exp", input) | |
sum = symbolic_helper._reducesum_helper(g, exp, axes_i=[dim]) | |
softmax = g.op("Div", exp, sum) | |
if dtype and dtype.node().kind() != "prim::Constant": | |
parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype") | |
softmax = g.op( | |
"Cast", softmax, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type() | |
) | |
return softmax | |
def softplus(g: jit_utils.GraphContext, self, beta, threshold): | |
beta_const = symbolic_helper._maybe_get_const(beta, "f") | |
if beta_const != 1: | |
return g.op("Div", g.op("Softplus", g.op("Mul", self, beta)), beta) | |
return g.op("Softplus", self) | |
def get_pool_ceil_padding(input, kernel_size, stride, padding): | |
# TODO(justinchuby): Looks like this op is deprecated in torch | |
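    # Worked example (illustrative, assumed sizes): for an input dimension of 5
    # with kernel_size 2, stride 2 and padding 0, ceil_mode produces
    # ceil((5 - 2) / 2) + 1 = 3 output positions, so one extra pad of 1 is added
    # at the end of that dimension to make the last window start inside the input.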
sizes = symbolic_helper._get_tensor_sizes(input) | |
dim = sizes[-len(padding) :] if sizes is not None else None | |
if dim is None or any(i is None for i in dim): | |
return symbolic_helper._unimplemented( | |
"get_pool_ceil_padding", "input size not accessible", input | |
) | |
ceiled_output_dim = [ | |
int(math.ceil((dim[i] + 2 * padding[i] - kernel_size[i]) / float(stride[i]))) | |
+ 1 | |
for i in range(0, len(padding)) | |
] | |
# ensure last pooling starts inside | |
ceiled_output_dim = [ | |
ceiled_output_dim[i] - 1 | |
if (((ceiled_output_dim[i] - 1) * stride[i]) >= (dim[i] + padding[i])) | |
else ceiled_output_dim[i] | |
for i in range(0, len(ceiled_output_dim)) | |
] | |
padding_ceil = [ | |
0 | |
if (stride[i] == 1) | |
else ( | |
kernel_size[i] | |
- (dim[i] + 2 * padding[i] - ((ceiled_output_dim[i] - 1) * stride[i] + 1)) | |
) | |
for i in range(0, len(padding)) | |
] | |
# ensure padding is not > kernel_size | |
padding_ceil = [ | |
( | |
int(padding_ceil[i]) | |
if padding_ceil[i] < kernel_size[i] - 1 | |
else int(kernel_size[i] - 1) | |
) | |
if ((padding_ceil[i] + 2 * padding[i]) >= (kernel_size[i])) | |
else int(padding_ceil[i]) | |
for i in range(0, len(padding_ceil)) | |
] | |
return padding_ceil | |
def _max_pool(name, tuple_fn, ndims, return_indices): | |
def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode): | |
if set(tuple_fn(dilation)) != {1}: | |
return symbolic_helper._unimplemented(name, "dilation", input) | |
if not stride: | |
stride = kernel_size | |
padding = tuple(tuple_fn(padding)) | |
if ceil_mode: | |
padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding) | |
padding = padding + tuple(a + b for (a, b) in zip(padding_ceil, padding)) | |
else: | |
padding = padding * 2 | |
kwargs = { | |
"kernel_shape_i": tuple_fn(kernel_size), | |
"pads_i": padding, | |
"strides_i": tuple_fn(stride), | |
} | |
        # Easy but hacky way to get flattened indices values
        # to be used to convert the indices values to non-flattened.
        # In ONNX the indices are computed as a flattened 1-D tensor,
        # so the values in indices are in [0, N x C x D1 x ... x Dn).
        # To convert the indices to the same format used by PyTorch,
        # we first execute a maxpool with a kernel and stride of 1 on the same input.
        # This will result in a tensor of indices in which each index will have its own value.
        # Using this tensor as a reference, we extract the first index of each axis and subtract
        # it from each index of this axis in the indices to convert.
        # This step will result in a tensor where each dimension has values of indices within
        # the dimension it is in.
        # For more information:
        # https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407
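        # Illustrative sketch (added): for an input of shape [N, C, H, W] the ONNX
        # indices lie in [0, N*C*H*W); the kernel-1/stride-1 MaxPool below returns
        # every element's own flattened index, so slicing out the first index of
        # each (n, c) plane and subtracting it leaves indices in [0, H*W),
        # which is the per-plane format PyTorch uses.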
if return_indices: | |
r, indices = g.op("MaxPool", input, outputs=2, **kwargs) | |
_, flattened_indices = g.op( | |
"MaxPool", | |
input, | |
outputs=2, | |
kernel_shape_i=[1 for _ in range(ndims)], | |
strides_i=[1 for _ in range(ndims)], | |
) | |
# convert indices to have non-flattened indices values | |
s = symbolic_helper._slice_helper( | |
g, | |
flattened_indices, | |
axes=[2 + i for i in range(ndims)], | |
starts=list(tuple_fn(0)), | |
ends=list(tuple_fn(1)), | |
) | |
indices = sub(g, indices, s) | |
return r, indices | |
else: | |
r = g.op("MaxPool", input, outputs=1, **kwargs) | |
return r | |
return symbolic_fn | |
max_pool1d_with_indices = _onnx_symbolic("aten::max_pool1d_with_indices")( | |
_max_pool( | |
"max_pool1d_with_indices", | |
torch.nn.modules.utils._single, | |
1, | |
return_indices=True, | |
) | |
) | |
max_pool2d_with_indices = _onnx_symbolic("aten::max_pool2d_with_indices")( | |
_max_pool( | |
"max_pool2d_with_indices", | |
torch.nn.modules.utils._pair, | |
2, | |
return_indices=True, | |
) | |
) | |
max_pool3d_with_indices = _onnx_symbolic("aten::max_pool3d_with_indices")( | |
_max_pool( | |
"max_pool3d_with_indices", | |
torch.nn.modules.utils._triple, | |
3, | |
return_indices=True, | |
) | |
) | |
def _avg_pool(name, tuple_fn): | |
def symbolic_fn( | |
g, | |
input: _C.Value, | |
kernel_size: Sequence[int], | |
stride: Sequence[int], | |
padding: Union[int, Sequence[int]], | |
ceil_mode: int, | |
count_include_pad: int, | |
divisor_override=None, | |
): | |
if not stride: | |
stride = kernel_size | |
padding = symbolic_helper._avgpool_helper( | |
tuple_fn, padding, kernel_size, stride, divisor_override, name | |
) | |
assert isinstance(padding, tuple) | |
adjusted_padding = padding | |
        # Although onnx::AveragePool provides count_include_pad,
        # the corner case of average pooling with ceil_mode on
        # PyTorch allows the sliding window to go off-bound, which leads to
        # this accommodation.
        # More detail at https://github.com/pytorch/pytorch/issues/57178
if count_include_pad: | |
input = _op_with_optional_float_cast( | |
g, | |
"Pad", | |
input, | |
pads_i=((0,) * 2 + padding) * 2, | |
mode_s="constant", | |
value_f=0.0, | |
opset_before=11, | |
) | |
adjusted_padding = (0,) * len(padding) | |
if ceil_mode: | |
padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding) | |
adjusted_padding = adjusted_padding + tuple( | |
a + b for (a, b) in zip(padding_ceil, adjusted_padding) | |
) | |
else: | |
adjusted_padding = adjusted_padding * 2 | |
output = g.op( | |
"AveragePool", | |
input, | |
kernel_shape_i=tuple_fn(kernel_size), | |
strides_i=tuple_fn(stride), | |
pads_i=adjusted_padding, | |
) | |
return output | |
return symbolic_fn | |
def _adaptive_pool(name, type, tuple_fn, fn=None): | |
def symbolic_fn(g, input, output_size): | |
        # _adaptive_pool is supported for cases where output_size is 1 for all dimensions,
        # by executing a GlobalPool.
        # It is also supported for cases where the output size is a factor of the input size.
        # For these cases the stride and kernel size are uniform along all the indices of
        # the same dimension, which makes it possible to export it to ONNX.
        # For MaxPool, GlobalMaxPool does not return indices,
        # so we try using max_poolxd_with_indices, and if that is not possible
        # (the input is not a complete tensor or the output size is not a factor of the input size)
        # we call GlobalMaxPool and return None for the indices.
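        # Worked example (illustrative, assumed sizes): adaptive pooling from a
        # spatial size of (8, 8) to output_size (2, 2) uses kernel = stride = (4, 4)
        # below, which is only possible because 8 is divisible by 2.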
output_size_value = output_size | |
try: | |
output_size = symbolic_helper._parse_arg(output_size, "is") | |
except Exception: | |
# FIXME(justinchuby): Avoid catching Exception. | |
# Catch a more specific exception instead. | |
return symbolic_helper._onnx_unsupported( | |
"adaptive pooling, since output_size is not constant.", input | |
) | |
if output_size == [1] * len(output_size) and type == "AveragePool": | |
return g.op("GlobalAveragePool", input) | |
sizes = symbolic_helper._get_tensor_sizes(input) | |
try: | |
dim = sizes[2:] | |
except Exception: | |
# FIXME(justinchuby): Avoid catching Exception. | |
# Catch a more specific exception instead. | |
dim = None | |
if dim is None or any(i is None for i in dim): | |
if output_size == [1] * len(output_size): | |
return g.op("GlobalMaxPool", input), None | |
return symbolic_helper._unimplemented( | |
name, "input size not accessible", input | |
) | |
        # verify that input size % output size == 0 for every spatial dim
mod = [dim[i] % output_size[i] for i in range(0, len(dim))] | |
if mod != [0] * len(mod): | |
if output_size == [1] * len(output_size): | |
return g.op("GlobalMaxPool", input), None | |
return symbolic_helper._unimplemented( | |
name, "output size that are not factor of input size", output_size_value | |
) | |
k = [int(dim[i] / output_size[i]) for i in range(0, len(dim))] | |
# call max_poolxd_with_indices to get indices in the output | |
if type == "MaxPool": | |
return fn(g, input, k, k, (0,) * len(dim), (1,) * len(dim), False) | |
output = g.op(type, input, kernel_shape_i=tuple_fn(k), strides_i=tuple_fn(k)) | |
return output | |
return symbolic_fn | |
def _prepare_onnx_paddings(dim: int, pad): | |
"""Generate paddings in ONNX order based on pad in pytorch. | |
Args: | |
dim: the dimension of the tensor. | |
pad: the paddings in pytorch. | |
The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ... | |
""" | |
# The desired order of paddings is | |
# dim_0_begin, dim_1_begin, ... , dim_0_end, ..., dim_n_end. | |
# n is the dimension of input. | |
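    # Worked example (illustrative): dim = 4 and pad = [1, 2, 3, 4] (i.e. last
    # dimension padded by (1, 2), second-to-last by (3, 4)) is first extended to
    # [1, 2, 3, 4, 0, 0, 0, 0] and then reordered to [0, 0, 3, 1, 0, 0, 4, 2],
    # i.e. all begins first, then all ends, in ONNX dimension order.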
# assume zero-dimensions in the beginning | |
paddings = list(pad[:]) + [0] * (dim * 2 - len(pad)) | |
# reverse order and collate first beginnings and then ends | |
paddings = paddings[-2::-2] + paddings[-1::-2] | |
return paddings | |
def _convert_padding_node(input): | |
padding = symbolic_helper._maybe_get_const(input, "is") | |
if symbolic_helper._is_value(padding) and symbolic_helper._is_packed_list(padding): | |
input_list = symbolic_helper._unpack_list(padding) | |
try: | |
padding = [ | |
symbolic_helper._get_const(v, "i", "padding") for v in input_list | |
] | |
except Exception: | |
# FIXME(justinchuby): Avoid catching Exception. | |
# Catch a more specific exception instead. | |
return symbolic_helper._onnx_opset_unsupported_detailed( | |
"Pad", 9, 11, "The sizes of the padding must be constant", input | |
) | |
return padding | |
def constant_pad_nd(g: jit_utils.GraphContext, input, padding, value): | |
mode = "constant" | |
try: | |
value = symbolic_helper._get_const(value, "f", "value") | |
except Exception: | |
# FIXME(justinchuby): Avoid catching Exception. | |
# Catch a more specific exception instead. | |
return symbolic_helper._onnx_opset_unsupported_detailed( | |
"Pad", 9, 11, "The value for the padding must be constant", value | |
) | |
padding = _convert_padding_node(padding) | |
paddings = _prepare_onnx_paddings(symbolic_helper._get_tensor_rank(input), padding) | |
return _op_with_optional_float_cast( | |
g, "Pad", input, pads_i=paddings, mode_s=mode, value_f=value, opset_before=11 | |
) | |
def _pad_circular(g: jit_utils.GraphContext, input: _C.Value, pad: _C.Value): | |
padding = _convert_padding_node(pad) | |
assert len(padding) % 2 == 0 | |
ndim = len(padding) // 2 | |
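    # Illustrative example (added, assumed values): a circular pad with
    # (pad_left, pad_right) = (2, 1) on a last dimension [a, b, c, d] concatenates
    # the last two elements, the full input and the first element, giving
    # [c, d, a, b, c, d, a], as torch.nn.functional.pad(..., mode="circular") would.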
cur = input | |
for idx in range(ndim): | |
pad_r = padding[-(2 * idx + 1)] | |
pad_l = padding[-(2 * idx + 2)] | |
tensors = [] | |
if pad_l > 0: | |
left = symbolic_helper._slice_helper( | |
g, cur, axes=[2 + idx], starts=[-(pad_l)], ends=[_constants.INT64_MAX] | |
) | |
tensors.append(left) | |
if pad_l < 0 or pad_r < 0: | |
start = builtins.max(0, -pad_l) | |
end = -(builtins.max(0, -pad_r)) | |
middle = symbolic_helper._slice_helper( | |
g, | |
cur, | |
axes=[2 + idx], | |
starts=[start], | |
ends=[end], | |
) | |
tensors.append(middle) | |
else: | |
tensors.append(cur) | |
if pad_r > 0: | |
right = symbolic_helper._slice_helper( | |
g, cur, axes=[2 + idx], starts=[0], ends=[pad_r] | |
) | |
tensors.append(right) | |
cur = g.op("Concat", *tensors, axis_i=(2 + idx)) | |
return cur | |
def reflection_pad(g: jit_utils.GraphContext, input, padding): | |
mode = "reflect" | |
padding = _convert_padding_node(padding) | |
paddings = _prepare_onnx_paddings(symbolic_helper._get_tensor_rank(input), padding) | |
return _op_with_optional_float_cast( | |
g, "Pad", input, pads_i=paddings, mode_s=mode, opset_before=11 | |
) | |
def replication_pad(g: jit_utils.GraphContext, input, padding): | |
mode = "edge" | |
padding = _convert_padding_node(padding) | |
paddings = _prepare_onnx_paddings(symbolic_helper._get_tensor_rank(input), padding) | |
return _op_with_optional_float_cast( | |
g, "Pad", input, pads_i=paddings, mode_s=mode, opset_before=11 | |
) | |
def pad( | |
g: jit_utils.GraphContext, | |
input: _C.Value, | |
pad: _C.Value, | |
mode: _C.Value, | |
value: _C.Value, | |
): | |
mode = symbolic_helper._parse_arg(mode, "s") | |
if mode == "replicate": | |
return replication_pad(g, input, pad) | |
elif mode == "reflect": | |
return reflection_pad(g, input, pad) | |
elif mode == "constant": | |
return constant_pad_nd(g, input, pad, value) | |
elif mode == "circular": | |
return _pad_circular(g, input, pad) | |
else: | |
raise errors.SymbolicValueError(f"Unrecognized padding mode {mode}", input) | |
def _interpolate(name: str, dim: int, interpolate_mode: str): | |
def symbolic_fn(g, input, output_size, *args): | |
scales, align_corners = symbolic_helper._get_interpolate_attributes( | |
g, interpolate_mode, args | |
) | |
symbolic_helper._interpolate_warning(interpolate_mode) | |
align_corners = symbolic_helper._maybe_get_scalar(align_corners) | |
if align_corners: | |
return symbolic_helper._unimplemented(name, "align_corners == True", input) | |
if scales is None: | |
scales = symbolic_helper._interpolate_size_to_scales( | |
g, input, output_size, dim | |
) | |
return g.op("Upsample", input, scales, mode_s=interpolate_mode) | |
return symbolic_fn | |
def __interpolate( | |
g: jit_utils.GraphContext, | |
input, | |
size, | |
scale_factor, | |
mode, | |
align_corners, | |
recompute_scale_factor, | |
antialias, | |
): | |
scales, mode = symbolic_helper._interpolate_get_scales_and_mode( | |
g, input, size, scale_factor, mode, align_corners | |
) | |
return g.op("Upsample", input, scales, mode_s=mode) | |
def bitwise_not(g: jit_utils.GraphContext, input): | |
if not symbolic_helper._is_bool(input): | |
raise errors.SymbolicValueError( | |
"ONNX export does NOT support exporting bitwise Not " | |
"for non-boolean input values", | |
input, | |
) | |
return g.op("Not", input) | |
def bitwise_or(g, self, other): | |
if not symbolic_helper._is_bool(self): | |
raise errors.SymbolicValueError( | |
"ONNX export does NOT support exporting bitwise OR " | |
"for non-boolean input values. self: ", | |
self, | |
) | |
if not symbolic_helper._is_bool(other): | |
raise errors.SymbolicValueError( | |
"ONNX export does NOT support exporting bitwise OR " | |
"for non-boolean input values. other: ", | |
other, | |
) | |
return g.op("Or", self, other) | |
def wrap_logical_op_with_cast_to(to_type): | |
def decorator(fn): | |
def wrap_with_cast(g, input, other): | |
to_cast_func = globals()[f"_cast_{to_type}"] | |
return fn(g, to_cast_func(g, input, False), to_cast_func(g, other, False)) | |
return wrap_with_cast | |
return decorator | |
def wrap_logical_op_with_negation(func: Callable) -> Callable: | |
def wrap_with_not(g, input, other): | |
return g.op("Not", func(g, input, other)) | |
return wrap_with_not | |
def __not_(g: jit_utils.GraphContext, self): | |
if not symbolic_helper._is_bool(self): | |
raise errors.SymbolicValueError( | |
"ONNX export does NOT support exporting bitwise Not " | |
"for non-boolean input values", | |
self, | |
) | |
return g.op("Not", self) | |
def eq(g: jit_utils.GraphContext, self, other): | |
if isinstance(self.type(), _C.DeviceObjType) and isinstance( | |
other.type(), _C.DeviceObjType | |
): | |
# ONNX doesn't have devices, so consider them all to be equal. | |
# The no-op check for equality will get constant-folded. | |
return g.op("Constant", value_t=torch.tensor(True, dtype=torch.bool)) | |
self_node = self.node() | |
other_node = other.node() | |
if self_node.kind() == other_node.kind() == "onnx::Constant": | |
if self_node.kindOf("value") == other_node.kindOf("value") == "s": | |
# Exporting strings to ONNX is not supported. | |
# If both strings are constant, we can compare them directly. | |
# The no-op check for equality will get constant-folded. | |
return g.op( | |
"Constant", | |
value_t=torch.tensor( | |
self_node.s("value") == other_node.s("value"), | |
dtype=torch.bool, | |
), | |
) | |
return g.op("Equal", self, other) | |
@wrap_logical_op_with_negation
def ne(g: jit_utils.GraphContext, self, other):
    # aten::ne is exported as Not(Equal(self, other)).
    return eq(g, self, other)
def gt(g: jit_utils.GraphContext, input, other): | |
return _gt_impl(g, input, other) | |
def _gt_impl(g: jit_utils.GraphContext, input, other): | |
if symbolic_helper._is_bool(input) and symbolic_helper._is_bool(other): | |
input = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT32) | |
other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.INT32) | |
return g.op("Greater", input, other) | |
def lt(g: jit_utils.GraphContext, input, other): | |
return _lt_impl(g, input, other) | |
def _lt_impl(g: jit_utils.GraphContext, input, other): | |
if symbolic_helper._is_bool(input) and symbolic_helper._is_bool(other): | |
input = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT32) | |
other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.INT32) | |
return g.op("Less", input, other) | |
@wrap_logical_op_with_negation
def ge(g: jit_utils.GraphContext, input, other):
    # aten::ge is exported as Not(Less(input, other)).
    return _lt_impl(g, input, other)
@wrap_logical_op_with_negation
def le(g: jit_utils.GraphContext, input, other):
    # aten::le is exported as Not(Greater(input, other)).
    return _gt_impl(g, input, other)
def __and_(g: jit_utils.GraphContext, input, other): | |
if not symbolic_helper._is_bool(input): | |
raise errors.SymbolicValueError( | |
"ONNX export does NOT support exporting bitwise AND " | |
"for non-boolean input values", | |
input, | |
) | |
if not symbolic_helper._is_bool(other): | |
raise errors.SymbolicValueError( | |
"ONNX export does NOT support exporting bitwise AND " | |
"for non-boolean input values", | |
other, | |
) | |
return g.op("And", input, other) | |
def __or_(g: jit_utils.GraphContext, input, other): | |
if not symbolic_helper._is_bool(input): | |
raise errors.SymbolicValueError( | |
"ONNX export does NOT support exporting bitwise OR " | |
"for non-boolean input values", | |
input, | |
) | |
if not symbolic_helper._is_bool(other): | |
raise errors.SymbolicValueError( | |
"ONNX export does NOT support exporting bitwise OR " | |
"for non-boolean input values", | |
other, | |
) | |
return g.op("Or", input, other) | |
def __xor_(g: jit_utils.GraphContext, input, other): | |
if not symbolic_helper._is_bool(input): | |
raise errors.SymbolicValueError( | |
"ONNX export does NOT support exporting bitwise XOR " | |
"for non-boolean input values", | |
input, | |
) | |
if not symbolic_helper._is_bool(other): | |
raise errors.SymbolicValueError( | |
"ONNX export does NOT support exporting bitwise XOR " | |
"for non-boolean input values", | |
other, | |
) | |
return g.op("Xor", input, other) | |
def logical_and(g: jit_utils.GraphContext, input, other): | |
return g.op("And", input, other) | |
def logical_or(g: jit_utils.GraphContext, input, other): | |
return g.op("Or", input, other) | |
def logical_xor(g: jit_utils.GraphContext, input, other): | |
return g.op("Xor", input, other) | |
def logical_not(g: jit_utils.GraphContext, input): | |
return g.op("Not", g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.BOOL)) | |
def __rshift_(g: jit_utils.GraphContext, self, other): | |
# make sure to cast other to self's type | |
# (when self is long, make sure that other is not float) | |
self_scalar_type = _type_utils.JitScalarType.from_value(self) | |
if ( | |
_type_utils.JitScalarType.from_value(other, _type_utils.JitScalarType.UNDEFINED) | |
!= self_scalar_type | |
): | |
other = g.op( | |
"Cast", | |
other, | |
to_i=self_scalar_type.onnx_type(), | |
) | |
two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32)) | |
# exponent (same type as self) has to be float or double in onnx::Pow | |
if not symbolic_helper._is_fp(self): | |
other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.FLOAT) | |
two_pow = g.op("Pow", two, other) | |
two_pow = g.op( | |
"Cast", | |
two_pow, | |
to_i=self_scalar_type.onnx_type(), | |
) | |
rshift = g.op("Div", self, two_pow) | |
return rshift | |
def __lshift_(g: jit_utils.GraphContext, self, other): | |
# make sure to cast other to self's type | |
# (when self is long, make sure that other is not float) | |
self_scalar_type = _type_utils.JitScalarType.from_value(self) | |
if ( | |
_type_utils.JitScalarType.from_value(other, _type_utils.JitScalarType.UNDEFINED) | |
!= self_scalar_type | |
): | |
other = g.op( | |
"Cast", | |
other, | |
to_i=self_scalar_type.onnx_type(), | |
) | |
two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32)) | |
# exponent (same type as self) has to be float or double in onnx::Pow | |
if not symbolic_helper._is_fp(self): | |
other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.FLOAT) | |
two_pow = g.op("Pow", two, other) | |
two_pow = g.op( | |
"Cast", | |
two_pow, | |
to_i=self_scalar_type.onnx_type(), | |
) | |
lshift = g.op("Mul", self, two_pow) | |
return lshift | |
def where(g: jit_utils.GraphContext, condition, self=None, other=None, _outputs=None): | |
# Assumes that torch.where's first argument takes only Bool and Byte tensors. | |
if not symbolic_helper._is_bool(condition): | |
condition = g.op("Cast", condition, to_i=_C_onnx.TensorProtoDataType.BOOL) | |
if self is None: | |
condition = nonzero(g, condition) | |
return symbolic_helper._unbind_helper( | |
g, condition, g.op("Constant", value_t=torch.tensor(1)), _outputs | |
) | |
return g.op("Where", condition, self, other) | |
def log_softmax(g: jit_utils.GraphContext, input, dim, dtype=None): | |
# PyTorch dim and ONNX axis have different meanings. | |
# See Softmax comment for details. | |
# TODO: remove this as onnx opset 11 spec allows negative axes | |
input_dim = symbolic_helper._get_tensor_rank(input) | |
if input_dim is None: | |
return symbolic_helper._unimplemented( | |
"dim", | |
"ONNX and PyTorch use different strategies to split the input. " | |
"Input rank must be known at export time.", | |
) | |
if dim < 0: | |
dim = input_dim + dim | |
is_transpose_required = input_dim != dim + 1 | |
# ONNX only supports log_softmax with dim = -1. Transpose must be added before and after log_softmax to support other cases. | |
if is_transpose_required: | |
axes = list(range(input_dim)) | |
axes[dim], axes[-1] = axes[-1], axes[dim] | |
input = g.op("Transpose", input, perm_i=axes) | |
dim = input_dim - 1 | |
return_op = g.op("LogSoftmax", input, axis_i=dim) | |
if dtype and dtype.node().kind() != "prim::Constant": | |
parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype") | |
return_op = g.op( | |
"Cast", return_op, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type() | |
) | |
if is_transpose_required: | |
return_op = g.op("Transpose", return_op, perm_i=axes) # type: ignore[possibly-undefined] | |
return return_op | |
def _log_softmax(g: jit_utils.GraphContext, input, dim, half_to_float): | |
if ( | |
half_to_float | |
and _type_utils.JitScalarType.from_value( | |
input, _type_utils.JitScalarType.UNDEFINED | |
) | |
== _type_utils.JitScalarType.HALF | |
): | |
input = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT) | |
return log_softmax(g, input, dim) | |
def _convolution( | |
g: jit_utils.GraphContext, | |
input, | |
weight, | |
bias, | |
stride, | |
padding, | |
dilation, | |
transposed, | |
output_padding, | |
groups, | |
benchmark, | |
deterministic, | |
cudnn_enabled, | |
allow_tf32=None, | |
): | |
weight_size = symbolic_helper._get_tensor_sizes(weight) | |
try: | |
kernel_shape = weight_size[2:] | |
except Exception: | |
# FIXME(justinchuby): Avoid catching Exception. | |
# Catch a more specific exception instead. | |
kernel_shape = None | |
if kernel_shape is None or any(i is None for i in kernel_shape): | |
raise errors.SymbolicValueError( | |
"Unsupported: ONNX export of convolution for kernel of unknown shape.", | |
input, | |
) | |
args = [input, weight] | |
# ONNX only supports 1D bias | |
if ( | |
not symbolic_helper._is_none(bias) | |
and symbolic_helper._get_tensor_rank(bias) == 1 | |
): | |
args.append(bias) | |
kwargs = { | |
"kernel_shape_i": weight_size[2:], | |
"strides_i": stride, | |
# NB: ONNX supports asymmetric padding, whereas PyTorch supports only | |
# symmetric padding | |
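        # e.g. PyTorch padding=[1, 2] maps to ONNX pads=[1, 2, 1, 2]
        # (all begin pads followed by the matching end pads).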
"pads_i": padding + padding, | |
"dilations_i": dilation, | |
"group_i": groups, | |
} | |
if any(o != 0 for o in output_padding): | |
        # ONNX supports both output_shape and output_padding; they are equally expressive.
        # output_padding is more straightforward, so we use it here.
# output_shape = stride * (input_shape - 1) + output_padding + kernel_shape - padding * 2 | |
assert transposed | |
assert len(stride) == len(output_padding) | |
kwargs["output_padding_i"] = output_padding | |
n = g.op("ConvTranspose" if transposed else "Conv", *args, **kwargs) | |
if ( | |
not symbolic_helper._is_none(bias) | |
and symbolic_helper._get_tensor_rank(bias) != 1 | |
): | |
return g.op("Add", n, bias) | |
else: | |
return n | |
def _convolution_mode( | |
g: jit_utils.GraphContext, | |
input, | |
weight, | |
bias, | |
stride, | |
padding, | |
dilation, | |
groups, | |
): | |
weight_size = symbolic_helper._get_tensor_sizes(weight) | |
try: | |
kernel_shape = weight_size[2:] | |
except Exception: | |
# FIXME(justinchuby): Avoid catching Exception. | |
# Catch a more specific exception instead. | |
kernel_shape = None | |
if kernel_shape is None or any(i is None for i in kernel_shape): | |
raise errors.SymbolicValueError( | |
"Unsupported: ONNX export of convolution for kernel of unknown shape.", | |
input, | |
) | |
args = [input, weight] | |
# ONNX only supports 1D bias | |
if ( | |
not symbolic_helper._is_none(bias) | |
and symbolic_helper._get_tensor_rank(bias) == 1 | |
): | |
args.append(bias) | |
if padding == "valid": | |
padding = "VALID" | |
elif padding == "same": | |
padding = "SAME_UPPER" | |
kwargs = { | |
"kernel_shape_i": weight_size[2:], | |
"strides_i": stride, | |
"auto_pad_s": padding, | |
"dilations_i": dilation, | |
"group_i": groups, | |
} | |
n = g.op("Conv", *args, **kwargs) | |
if ( | |
not symbolic_helper._is_none(bias) | |
and symbolic_helper._get_tensor_rank(bias) != 1 | |
): | |
return g.op("Add", n, bias) | |
else: | |
return n | |
def convolution( | |
g: jit_utils.GraphContext, | |
input, | |
weight, | |
bias, | |
stride, | |
padding, | |
dilation, | |
transposed, | |
output_padding, | |
groups, | |
): | |
return _convolution( | |
g, | |
input, | |
weight, | |
bias, | |
stride, | |
padding, | |
dilation, | |
transposed, | |
output_padding, | |
groups, | |
None, | |
None, | |
None, | |
None, | |
) | |
def conv1d( | |
g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups | |
): | |
str_padding = symbolic_helper._parse_arg(padding, "s") | |
if str_padding in ["valid", "same"]: | |
return _convolution_mode( | |
g, | |
input, | |
weight, | |
bias, | |
stride, | |
str_padding, | |
dilation, | |
groups, | |
) | |
else: | |
padding = symbolic_helper._parse_arg(padding, "is") | |
return _convolution( | |
g, | |
input, | |
weight, | |
bias, | |
stride, | |
padding, | |
dilation, | |
False, | |
(), | |
groups, | |
None, | |
None, | |
None, | |
None, | |
) | |
def conv2d( | |
g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups | |
): | |
str_padding = symbolic_helper._parse_arg(padding, "s") | |
if str_padding in ["valid", "same"]: | |
return _convolution_mode( | |
g, | |
input, | |
weight, | |
bias, | |
stride, | |
str_padding, | |
dilation, | |
groups, | |
) | |
else: | |
padding = symbolic_helper._parse_arg(padding, "is") | |
return _convolution( | |
g, | |
input, | |
weight, | |
bias, | |
stride, | |
padding, | |
dilation, | |
False, | |
(), | |
groups, | |
None, | |
None, | |
None, | |
None, | |
) | |
def conv3d( | |
g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups | |
): | |
str_padding = symbolic_helper._parse_arg(padding, "s") | |
if str_padding in ["valid", "same"]: | |
return _convolution_mode( | |
g, | |
input, | |
weight, | |
bias, | |
stride, | |
str_padding, | |
dilation, | |
groups, | |
) | |
else: | |
padding = symbolic_helper._parse_arg(padding, "is") | |
return _convolution( | |
g, | |
input, | |
weight, | |
bias, | |
stride, | |
padding, | |
dilation, | |
False, | |
(), | |
groups, | |
None, | |
None, | |
None, | |
None, | |
) | |
def conv_transpose1d( | |
g: jit_utils.GraphContext, | |
input, | |
weight, | |
bias, | |
stride, | |
padding, | |
output_padding, | |
groups, | |
dilation, | |
): | |
return _convolution( | |
g, | |
input, | |
weight, | |
bias, | |
stride, | |
padding, | |
dilation, | |
True, | |
output_padding, | |
groups, | |
None, | |
None, | |
None, | |
None, | |
) | |
def conv_transpose2d( | |
g: jit_utils.GraphContext, | |
input, | |
weight, | |
bias, | |
stride, | |
padding, | |
output_padding, | |
groups, | |
dilation, | |
): | |
return _convolution( | |
g, | |
input, | |
weight, | |
bias, | |
stride, | |
padding, | |
dilation, | |
True, | |
output_padding, | |
groups, | |
None, | |
None, | |
None, | |
None, | |
) | |
def conv_transpose3d( | |
g: jit_utils.GraphContext, | |
input, | |
weight, | |
bias, | |
stride, | |
padding, | |
output_padding, | |
groups, | |
dilation, | |
): | |
return _convolution( | |
g, | |
input, | |
weight, | |
bias, | |
stride, | |
padding, | |
dilation, | |
True, | |
output_padding, | |
groups, | |
None, | |
None, | |
None, | |
None, | |
) | |
def batch_norm( | |
g: jit_utils.GraphContext, | |
input, | |
weight, | |
bias, | |
running_mean, | |
running_var, | |
training, | |
momentum, | |
eps, | |
cudnn_enabled, | |
): | |
symbolic_helper.check_training_mode(training, "batch_norm") | |
if ( | |
torch.is_autocast_enabled() | |
and not symbolic_helper.args_have_same_dtype( | |
[input, weight, bias, running_mean, running_var] | |
) | |
and GLOBALS.export_onnx_opset_version < 15 | |
): | |
return symbolic_helper._onnx_opset_unsupported_detailed( | |
"BatchNormalization", | |
9, | |
15, | |
"All input tensors must have the same `dtype`." | |
" Turn off Autocast or export using opset version 15.", | |
input, | |
) | |
weight, bias, running_mean, running_var = symbolic_helper._batchnorm_helper( | |
g, input, weight, bias, running_mean, running_var | |
) | |
out = g.op( | |
"BatchNormalization", | |
input, | |
weight, | |
bias, | |
running_mean, | |
running_var, | |
epsilon_f=eps, | |
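        # PyTorch and ONNX define BatchNorm momentum complementarily: PyTorch
        # weights the new batch statistic by `momentum`, while ONNX weights the
        # old running statistic by `momentum`, hence `1 - momentum` below.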
momentum_f=1 - momentum, | |
outputs=1 if not training else 5, | |
) | |
if not training: | |
return out | |
else: | |
res, new_running_mean, new_running_var, saved_mean, saved_var = out | |
new_running_mean.setType(running_mean.type()) | |
new_running_var.setType(running_var.type()) | |
saved_mean.setDebugName("batch_norm_dead_output-" + saved_mean.debugName()) | |
saved_var.setDebugName("batch_norm_dead_output-" + saved_var.debugName()) | |
return res | |
def native_layer_norm( | |
g: jit_utils.GraphContext, | |
input: _C.Value, | |
normalized_shape: Sequence[int], | |
weight: _C.Value, | |
bias: _C.Value, | |
eps: float, | |
) -> Tuple[_C.Value, _C.Value, _C.Value]: | |
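    # layer_norm: y = (x - E[x]) / sqrt(Var[x] + eps), computed over the last
    # len(normalized_shape) dimensions, then optionally scaled by `weight` and
    # shifted by `bias`. Also returns the mean and 1 / sqrt(Var[x] + eps).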
axes = [-i for i in range(len(normalized_shape), 0, -1)] | |
two_cst = symbolic_helper._generate_wrapped_number(g, 2.0) | |
eps_cst = symbolic_helper._generate_wrapped_number(g, eps) | |
mean = g.op("ReduceMean", input, axes_i=axes) | |
numerator = sub(g, input, mean) | |
# Cast it to eps dtype to avoid precision loss | |
is_type_half = ( | |
_type_utils.JitScalarType.from_value(numerator) | |
== _type_utils.JitScalarType.HALF | |
) | |
if is_type_half: | |
eps_dtype = _type_utils.JitScalarType.from_value(eps_cst) | |
numerator = g.op( | |
"Cast", numerator, to_i=_type_utils.JitScalarType(eps_dtype).onnx_type() | |
) | |
# variance = e((x - e(x))^2), and (x - e(x)) is the numerator in the layer_norm formula | |
variance = g.op("ReduceMean", pow(g, numerator, two_cst), axes_i=axes) | |
denominator = sqrt(g, g.op("Add", variance, eps_cst)) | |
normalized = g.op("Div", numerator, denominator) | |
# Cast back to input type as eps related ops are all done | |
if is_type_half: | |
input_dtype = _type_utils.JitScalarType.from_value(input) | |
normalized = g.op( | |
"Cast", normalized, to_i=_type_utils.JitScalarType(input_dtype).onnx_type() | |
) | |
if not (weight is None or symbolic_helper._is_none(weight)): | |
normalized = mul(g, normalized, weight) | |
if not (bias is None or symbolic_helper._is_none(bias)): | |
normalized = add(g, normalized, bias) | |
# rdenominator := 1 / sqrt(variance + eps) | |
# According to aten::native_layer_norm, rdenominator should have the same dtype as input, | |
# mean and normalized, so we need to Cast it back | |
if is_type_half: | |
denominator = g.op( | |
"Cast", denominator, to_i=_type_utils.JitScalarType(input_dtype).onnx_type() # type: ignore[possibly-undefined] | |
) | |
rdenominator = g.op("Reciprocal", denominator) | |
else: | |
rdenominator = reciprocal(g, denominator) | |
return normalized, mean, rdenominator | |
def layer_norm( | |
g: jit_utils.GraphContext, | |
input: _C.Value, | |
normalized_shape: Sequence[int], | |
weight: _C.Value, | |
bias: _C.Value, | |
eps: float, | |
cudnn_enable: bool, | |
) -> _C.Value: | |
if symbolic_helper.is_caffe2_aten_fallback(): | |
return g.at( | |
"layer_norm", | |
input, | |
weight, | |
bias, | |
normalized_shape_i=normalized_shape, | |
eps_f=eps, | |
cudnn_enable_i=cudnn_enable, | |
) | |
normalized, _, _ = native_layer_norm(g, input, normalized_shape, weight, bias, eps) | |
return normalized | |
def instance_norm( | |
g: jit_utils.GraphContext, | |
input, | |
weight, | |
bias, | |
running_mean, | |
running_var, | |
use_input_stats: bool, | |
momentum: Number, | |
eps: Number, | |
cudnn_enabled: bool, | |
): | |
symbolic_helper.check_training_mode(use_input_stats, "instance_norm") | |
channel_size = symbolic_helper._get_tensor_dim_size(input, 1) | |
if weight is None or symbolic_helper._is_none(weight): | |
if channel_size is None: | |
raise errors.SymbolicValueError( | |
"Unsupported: ONNX export of instance_norm for unknown channel size.", | |
input, | |
) | |
weight_value = torch.tensor( | |
[1.0] * channel_size, | |
dtype=_type_utils.JitScalarType.from_value(input).dtype(), | |
) | |
weight = g.op("Constant", value_t=weight_value) | |
if bias is None or symbolic_helper._is_none(bias): | |
if channel_size is None: | |
raise errors.SymbolicValueError( | |
"Unsupported: ONNX export of instance_norm for unknown channel size.", | |
input, | |
) | |
bias_value = torch.tensor( | |
[0.0] * channel_size, | |
dtype=_type_utils.JitScalarType.from_value(input).dtype(), | |
) | |
bias = g.op("Constant", value_t=bias_value) | |
if ( | |
running_mean is None | |
or symbolic_helper._is_none(running_mean) | |
or running_var is None | |
or symbolic_helper._is_none(running_var) | |
): | |
return g.op("InstanceNormalization", input, weight, bias, epsilon_f=eps) | |
else: | |
input_size = symbolic_helper._get_tensor_sizes(input) | |
# If input shape is [N, C, H, W], reshape to [1, N * C, H, W] and call batch_norm. | |
# For more information instance_norm(): | |
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Normalization.cpp#L542 | |
input_size_reshape = input_size.copy() | |
n = input_size[0] | |
if n is None: | |
raise errors.SymbolicValueError( | |
"Unsupported: ONNX export of instance_norm training for unknown " | |
"batch size.", | |
input, | |
) | |
c = input_size[1] | |
input_size_reshape[0] = 1 | |
input_size_reshape[1] = n * c | |
weight_ = repeat( | |
g, weight, g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)) | |
) | |
bias_ = repeat( | |
g, bias, g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)) | |
) | |
running_mean_ = repeat( | |
g, | |
running_mean, | |
g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)), | |
) | |
running_var_ = repeat( | |
g, | |
running_var, | |
g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)), | |
) | |
input_reshaped = g.op( | |
"Reshape", | |
input, | |
g.op("Constant", value_t=torch.LongTensor(input_size_reshape)), | |
) | |
out = batch_norm( | |
g, | |
input_reshaped, | |
weight_, | |
bias_, | |
running_mean_, | |
running_var_, | |
use_input_stats, | |
momentum, | |
eps, | |
cudnn_enabled, | |
) | |
return view(g, out, g.op("Constant", value_t=torch.tensor(input_size))) | |
def unfold(g: jit_utils.GraphContext, input, dimension, size, step): | |
if symbolic_helper.is_caffe2_aten_fallback(): | |
return g.at("unfold", input, dimension_i=dimension, size_i=size, step_i=step) | |
sizes = symbolic_helper._get_tensor_sizes(input) | |
# FIXME(justinchuby): Get rid of the try catch here to improve readability | |
try: | |
sizedim = sizes[dimension] | |
except Exception: | |
# FIXME(justinchuby): Avoid catching Exception. | |
# Catch a more specific exception instead. | |
sizedim = None | |
if sizedim is not None: | |
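        # Emulate unfold by slicing each window [low, low + size) along `dimension`
        # with stride `step`, moving that dimension to the end of each slice, and
        # concatenating the per-window results along `dimension`.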
low_indices = range(0, sizedim, step) | |
hi_indices = range(size, sizedim + 1, step) | |
stack = [ | |
symbolic_helper._slice_helper( | |
g, input, axes=[dimension], starts=[low], ends=[hi] | |
) | |
for low, hi in zip(low_indices, hi_indices) | |
] | |
ndim = len(sizes) | |
perm = list(range(0, ndim)) | |
perm.append(perm.pop(dimension)) | |
unsqueeze = [ | |
symbolic_helper._unsqueeze_helper( | |
g, g.op("Transpose", t, perm_i=perm), [dimension] | |
) | |
for t in stack | |
] | |
return g.op("Concat", *unsqueeze, axis_i=dimension) | |
else: | |
return symbolic_helper._unimplemented( | |
"Unfold", "input size not accessible", input | |
) | |
def elu(g: jit_utils.GraphContext, input, alpha, scale, input_scale): | |
if scale and scale != 1.0: | |
return symbolic_helper._unimplemented( | |
"scale", "does not support scale in Elu", scale | |
) | |
if input_scale and input_scale != 1.0: | |
return symbolic_helper._unimplemented( | |
"input_scale", "does not support input_scale in Elu", input_scale | |
) | |
# See Note [Export inplace] | |
return g.op("Elu", input, alpha_f=symbolic_helper._scalar(alpha)) | |
def selu(g: jit_utils.GraphContext, input): | |
return g.op("Selu", input) | |
def index_select(g: jit_utils.GraphContext, self, dim, index): | |
# In case of a scalar index, index_select returns a tensor with the same rank as the input. | |
# To match this behavior in ONNX, we make index a 1D tensor so that the following gather | |
# also produces a tensor with the same rank as the input. | |
return symbolic_helper._select_helper(g, self, dim, index) | |
def index_put(g: jit_utils.GraphContext, self, indices_list_value, values, accumulate): | |
if symbolic_helper._is_packed_list(indices_list_value): | |
indices_list = symbolic_helper._unpack_list(indices_list_value) | |
else: | |
indices_list = [indices_list_value] | |
if symbolic_helper.is_caffe2_aten_fallback(): | |
args = [self] + indices_list + [values, accumulate] | |
return g.at("index_put", *args) | |
accumulate = symbolic_helper._parse_arg(accumulate, "b") | |
if len(indices_list) == 0: | |
if accumulate: | |
return add(g, self, values) | |
return values | |
symbolic_helper._onnx_opset_unsupported("index_put", 9, 11, self) | |
def index_fill(g: jit_utils.GraphContext, self, dim, index, value): | |
dim_value = symbolic_helper._parse_arg(dim, "i") | |
if symbolic_helper.is_caffe2_aten_fallback(): | |
return g.at( | |
"index_fill", | |
self, | |
index, | |
value, | |
overload_name="int_Scalar", | |
dim_i=dim_value, | |
) | |
expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper( | |
g, self, dim, index | |
) | |
value = symbolic_helper._maybe_get_scalar(value) | |
value = symbolic_helper._if_scalar_type_as(value, self) | |
expanded_value = expand(g, value, expanded_index_shape, None) | |
return scatter(g, self, dim, expanded_index, expanded_value) | |
def index_copy(g: jit_utils.GraphContext, self, dim, index, source): | |
dim_value = symbolic_helper._parse_arg(dim, "i") | |
if symbolic_helper.is_caffe2_aten_fallback(): | |
return g.at("index_copy", self, index, source, dim_i=dim_value) | |
expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper( | |
g, self, dim, index | |
) | |
return scatter(g, self, dim, expanded_index, source) | |
def bucketize( | |
g: jit_utils.GraphContext, self, boundaries, out_int32=False, right=False | |
): | |
out_type = _C_onnx.TensorProtoDataType.INT64 | |
if out_int32: | |
out_type = _C_onnx.TensorProtoDataType.INT32 | |
# A tensor expanded_boundaries is created such that it | |
# contains a copy of boundaries for each element of self. | |
new_shape = g.op("Concat", g.op("Shape", boundaries), g.op("Shape", self), axis_i=0) | |
# Unsqueeze step is performed to respect ONNX's numpy style broadcasting for comparison ops | |
# https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md | |
tensor_rank = symbolic_helper._get_tensor_rank(self) | |
assert tensor_rank is not None | |
unsqueeze_axes = list(range(1, tensor_rank + 1)) | |
expanded_boundaries = expand( | |
g, | |
symbolic_helper._unsqueeze_helper(g, boundaries, unsqueeze_axes), | |
new_shape, | |
None, | |
) | |
# Compare each element of self to boundaries to get a tensor | |
# with leading 1s and trailing 0s. | |
# e.g., 4 > [1, 3, 4] = [1, 1, 0] | |
# The index of the last 1 is the bucket where the element should go. | |
if right: | |
cond = ge(g, self, expanded_boundaries) | |
else: | |
cond = gt(g, self, expanded_boundaries) | |
cond_out = g.op("Cast", cond, to_i=out_type) | |
# Sum to get the number of 1s corresponding to each element, | |
# which is the same as the bucket index. | |
# e.g., sum(4 > [1, 3, 4]) = sum([1, 1, 0]) = 2 | |
return symbolic_helper._reducesum_helper(g, cond_out, axes_i=[0], keepdims_i=0) | |
def type_as(g: jit_utils.GraphContext, self, other): | |
self_dtype = symbolic_helper._try_get_scalar_type(self) | |
other_dtype = symbolic_helper._try_get_scalar_type(other) | |
if self_dtype == other_dtype and self_dtype is not None: | |
return self | |
if other_dtype is not None: | |
return g.op( | |
"Cast", | |
self, | |
to_i=other_dtype.onnx_type(), | |
) | |
if symbolic_helper.is_caffe2_aten_fallback(): | |
# We don't know the type of other, bail by emitting ATen | |
return g.at("type_as", self, other) | |
raise errors.SymbolicValueError( | |
"Unsupported: ONNX export of type_as for tensor " | |
"of unknown dtype. Please check if the dtype of the " | |
"parameter passed to the type_as function is correct.", | |
other, | |
) | |
def cosine_similarity(g: jit_utils.GraphContext, x1, x2, dim, eps): | |
if symbolic_helper.is_caffe2_aten_fallback(): | |
return g.at("cosine_similarity", x1, x2, dim_i=dim, eps_f=eps) | |
cross = symbolic_helper._reducesum_helper( | |
g, mul(g, x1, x2), axes_i=[dim], keepdims_i=0 | |
) | |
x1_l2 = symbolic_helper._reducesum_helper( | |
g, mul(g, x1, x1), axes_i=[dim], keepdims_i=0 | |
) | |
x2_l2 = symbolic_helper._reducesum_helper( | |
g, mul(g, x2, x2), axes_i=[dim], keepdims_i=0 | |
) | |
div_tens = max( | |
g, sqrt(g, mul(g, x1_l2, x2_l2)), g.op("Constant", value_t=torch.tensor([eps])) | |
) | |
return div(g, cross, div_tens) | |
def pairwise_distance(g: jit_utils.GraphContext, input1, input2, p, eps, keepdim): | |
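    # Computes (sum((input1 - input2) ** p, dim=-1)) ** (1 / (p + eps));
    # adding eps to p keeps the reciprocal exponent finite when p is zero.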
if not symbolic_helper._is_value(eps): | |
eps = g.op("Constant", value_t=torch.tensor([eps])) | |
inv_p = div( | |
g, | |
g.op("Constant", value_t=torch.tensor([1], dtype=torch.float)), | |
add(g, p, eps), | |
) | |
summation = symbolic_helper._reducesum_helper( | |
g, | |
pow(g, sub(g, input1, input2), p), | |
axes_i=[-1], | |
keepdims_i=symbolic_helper._parse_arg(keepdim, "i"), | |
) | |
return pow(g, summation, inv_p) | |
# ignore clone operators that are inserted by PyTorch autograd | |
def clone(g: jit_utils.GraphContext, input, unused_memory_format): | |
return input | |
def abs(g: jit_utils.GraphContext, self): | |
return g.op("Abs", self) | |
def log(g: jit_utils.GraphContext, self): | |
return g.op("Log", self) | |
def log1p(g: jit_utils.GraphContext, self): | |
return log(g, add(g, symbolic_helper._if_scalar_type_as(torch.ones(1), self), self)) | |
def log10(g: jit_utils.GraphContext, self): | |
_ln10 = 2.30258509299404568401 | |
return g.op("Div", log(g, self), g.op("Constant", value_t=torch.tensor([_ln10]))) | |
def pow(g: jit_utils.GraphContext, self, exponent): | |
f_dtype = _type_utils.JitScalarType.from_value(self) | |
if not symbolic_helper._is_fp(self): | |
f_dtype = _type_utils.JitScalarType.FLOAT | |
self = g.op("Cast", self, to_i=f_dtype.onnx_type()) | |
if not symbolic_helper._is_fp(exponent): | |
exponent = g.op( | |
"Cast", | |
exponent, | |
to_i=f_dtype.onnx_type(), | |
) | |
pow = g.op("Pow", self, exponent) | |
return pow | |
def clamp(g: jit_utils.GraphContext, self, min, max): | |
    # min or max may be None, in which case we need to dispatch to
    # clamp_max / clamp_min separately, as ONNX Clip has no None syntax.
if symbolic_helper._is_none(min): | |
return clamp_max(g, self, max) | |
elif symbolic_helper._is_none(max): | |
return clamp_min(g, self, min) | |
else: | |
if symbolic_helper._is_constant(min) and symbolic_helper._is_constant(max): | |
return _op_with_optional_float_cast( | |
g, | |
"Clip", | |
self, | |
min_f=symbolic_helper._parse_arg(min, "f"), | |
max_f=symbolic_helper._parse_arg(max, "f"), | |
opset_before=12, | |
) | |
else: | |
return clamp_max(g, clamp_min(g, self, min), max) | |
def clamp_min(g: jit_utils.GraphContext, self, min): | |
if symbolic_helper._is_constant(min): | |
return _op_with_optional_float_cast( | |
g, "Clip", self, min_f=symbolic_helper._parse_arg(min, "f"), opset_before=12 | |
) | |
else: | |
dtype = _type_utils.JitScalarType.from_value(self) | |
min = g.op("Cast", min, to_i=dtype.onnx_type()) | |
return _op_with_optional_float_cast(g, "Max", self, min, opset_before=12) | |
def clamp_max(g: jit_utils.GraphContext, self, max): | |
if symbolic_helper._is_constant(max): | |
return _op_with_optional_float_cast( | |
g, "Clip", self, max_f=symbolic_helper._parse_arg(max, "f"), opset_before=12 | |
) | |
else: | |
dtype = _type_utils.JitScalarType.from_value(self) | |
max = g.op("Cast", max, to_i=dtype.onnx_type()) | |
return _op_with_optional_float_cast(g, "Min", self, max, opset_before=12) | |
# torch.max (same for torch.min) actually has two interfaces smashed together: | |
# torch.max(x, dim, keepdim) and torch.max(x, y) | |
# TODO(justinchuby): Support multiple quantized args in output | |
def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): | |
# torch.max(input) | |
if dim_or_y is None and keepdim is None: | |
return g.op("ReduceMax", self, keepdims_i=0) | |
# torch.max(input, other) | |
if keepdim is None: | |
return _op_with_optional_float_cast(g, "Max", self, dim_or_y, opset_before=12) | |
# torch.max(input, dim, keepdim) | |
else: | |
dim = symbolic_helper._get_const(dim_or_y, "i", "dim") | |
keepdim = symbolic_helper._get_const(keepdim, "i", "keepdim") | |
max = g.op("ReduceMax", self, axes_i=[dim], keepdims_i=keepdim) | |
indices = g.op("ArgMax", self, axis_i=dim, keepdims_i=keepdim) | |
return max, indices | |
def maximum(g: jit_utils.GraphContext, input, other): | |
return max(g, input, dim_or_y=other) | |
# TODO(justinchuby): Support multiple quantized args in output | |
def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): | |
# torch.min(input) | |
if dim_or_y is None and keepdim is None: | |
return g.op("ReduceMin", self, keepdims_i=0) | |
# torch.min(input, other) | |
if keepdim is None: | |
return _op_with_optional_float_cast(g, "Min", self, dim_or_y, opset_before=12) | |
# torch.min(input, dim, keepdim) | |
else: | |
dim = symbolic_helper._get_const(dim_or_y, "i", "dim") | |
keepdim = symbolic_helper._get_const(keepdim, "i", "keepdim") | |
min = g.op("ReduceMin", self, axes_i=[dim], keepdims_i=keepdim) | |
indices = g.op("ArgMin", self, axis_i=dim, keepdims_i=keepdim) | |
return min, indices | |
def minimum(g: jit_utils.GraphContext, input, other): | |
return min(g, input, dim_or_y=other) | |
def amax(g: jit_utils.GraphContext, self, dim, keepdim): | |
return g.op("ReduceMax", self, axes_i=dim, keepdims_i=keepdim) | |
def amin(g: jit_utils.GraphContext, self, dim, keepdim): | |
return g.op("ReduceMin", self, axes_i=dim, keepdims_i=keepdim) | |
def aminmax(g: jit_utils.GraphContext, self, dim, keepdim): | |
reduce_kwargs = {"keepdims_i": keepdim} | |
if not symbolic_helper._is_none(dim): | |
dim = symbolic_helper._get_const(dim, "i", "dim") | |
reduce_kwargs["axes_i"] = [dim] | |
return g.op("ReduceMin", self, **reduce_kwargs), g.op( | |
"ReduceMax", self, **reduce_kwargs | |
) | |
def exp(g: jit_utils.GraphContext, self): | |
return g.op("Exp", self) | |
def dropout(g: jit_utils.GraphContext, input, p, train): | |
symbolic_helper.check_training_mode(train, "dropout") | |
    # if train is False, dropout is a no-op
if not train: | |
return input | |
r, _ = g.op("Dropout", input, ratio_f=p, outputs=2) | |
return r | |
# See Note [Export inplace] | |
def _unsupported_dropout(name: str): | |
def feature_dropout(g, input, p, train): | |
# NB: In inference mode, FeatureDropout is exported as an identity op. | |
if train: | |
return symbolic_helper._unimplemented(name, "training mode", input) | |
return input | |
return feature_dropout | |
def norm(g: jit_utils.GraphContext, self, p, dim, keepdim, dtype=None): | |
if p == 1: | |
f = _reduce_op_symbolic("ReduceL1") | |
elif p == 2: | |
f = _reduce_op_symbolic("ReduceL2") | |
else: | |
        raise errors.SymbolicValueError(
            "ONNX export only supports p-norms with p of 1 or 2", self
        )
result = f(g, self, dim=dim, keepdim=keepdim) | |
if dtype is not None: | |
dtype = symbolic_helper._get_const(dtype, "i", "dtype") | |
result = g.op("Cast", result, to_i=_type_utils.JitScalarType(dtype).onnx_type()) | |
return result | |
def conv_tbc(g: jit_utils.GraphContext, input, weight, bias, pad): | |
if symbolic_helper.is_caffe2_aten_fallback(): | |
return g.at("conv_tbc", input, weight, bias, pad_i=pad) | |
else: | |
# input must have 3 dimensions, see: | |
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/ConvolutionTBC.cpp#L8-L10 | |
# input = (time, batch, in_channels) | |
# weight = (kernel_width, in_channels, out_channels) | |
# bias = (out_channels,) | |
input = g.op("Transpose", input, perm_i=[1, 2, 0]) | |
weight = g.op("Transpose", weight, perm_i=[2, 1, 0]) | |
conv = conv1d(g, input, weight, bias, [1], [pad], [1], 1) | |
return g.op("Transpose", conv, perm_i=[2, 0, 1]) | |
def _unique(g: jit_utils.GraphContext, input, sorted, return_inverse): | |
if symbolic_helper.is_caffe2_aten_fallback(): | |
return g.at( | |
"_unique", | |
input, | |
sorted_i=sorted, | |
return_inverse_i=return_inverse, | |
outputs=2, | |
) | |
else: | |
return symbolic_helper._onnx_unsupported("_unique", input) | |
def _unique2(g: jit_utils.GraphContext, input, sorted, return_inverse, return_counts): | |
if symbolic_helper.is_caffe2_aten_fallback(): | |
return g.at( | |
"_unique2", | |
input, | |
sorted_i=sorted, | |
return_inverse_i=return_inverse, | |
return_counts_i=return_counts, | |
outputs=3, | |
) | |
symbolic_helper._onnx_opset_unsupported("_unique2", 9, 11, input) | |
def _cast_Byte(g: jit_utils.GraphContext, input, non_blocking): | |
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.UINT8) | |
def _cast_Char(g: jit_utils.GraphContext, input, non_blocking): | |
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT8) | |
def _cast_Short(g: jit_utils.GraphContext, input, non_blocking): | |
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT16) | |
def _cast_Int(g: jit_utils.GraphContext, input, non_blocking): | |
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT32) | |
def _cast_Long(g: jit_utils.GraphContext, input, non_blocking): | |
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT64) | |
def _cast_Half(g: jit_utils.GraphContext, input, non_blocking): | |
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT16) | |
def _cast_Float(g: jit_utils.GraphContext, input, non_blocking): | |
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT) | |
def _cast_Double(g: jit_utils.GraphContext, input, non_blocking): | |
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.DOUBLE) | |
def _cast_Bool(g: jit_utils.GraphContext, input, non_blocking): | |
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.BOOL) | |
def empty( | |
g: jit_utils.GraphContext, | |
sizes, | |
dtype, | |
layout, | |
device, | |
pin_memory=False, | |
memory_format=None, | |
): | |
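    # ONNX has no equivalent of uninitialized memory, so aten::empty is exported
    # as a zero-filled tensor of the requested shape and dtype.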
return zeros(g, sizes, dtype, layout, device, pin_memory) | |
def empty_like( | |
g: jit_utils.GraphContext, | |
input, | |
dtype=None, | |
layout=None, | |
device=None, | |
pin_memory=False, | |
memory_format=None, | |
): | |
return zeros_like(g, input, dtype, layout, device, pin_memory) | |
def new_empty( | |
g: jit_utils.GraphContext, self, sizes, dtype, layout, device, pin_memory=False | |
): | |
self_dtype = symbolic_helper._try_get_scalar_type(self) | |
if symbolic_helper._is_none(dtype) and self_dtype is not None: | |
dtype = self_dtype | |
return empty(g, sizes, dtype, layout, device, pin_memory) | |
def scalar_tensor(g: jit_utils.GraphContext, scalar, dtype, *options): | |
dtype = symbolic_helper._get_const(dtype, "i", "dtype") | |
if dtype is None: | |
dtype = _type_utils.JitScalarType.FLOAT | |
scalar = g.op("Cast", scalar, to_i=_type_utils.JitScalarType(dtype).onnx_type()) | |
return scalar | |
def tensor( | |
g: jit_utils.GraphContext, data, dtype=None, device=None, requires_grad=False | |
): | |
dtype = symbolic_helper._get_const(dtype, "i", "dtype") | |
if symbolic_helper._is_packed_list(data): | |
if dtype is None: | |
dtype = _type_utils.JitScalarType.from_value( | |
symbolic_helper._unpack_list(data)[0] | |
) | |
input_list = list() | |
for t in symbolic_helper._unpack_list(data): | |
shape_reference = g.op("Constant", value_t=torch.LongTensor([1])) | |
t = symbolic_helper._reshape_helper(g, t, shape_reference) | |
t = g.op("Cast", t, to_i=_type_utils.JitScalarType(dtype).onnx_type()) | |
input_list.append(t) | |
return g.op("Concat", *input_list, axis_i=0) | |
else: | |
if dtype is None: | |
dtype = _type_utils.JitScalarType.from_value(data) | |
if symbolic_helper._is_list(data) and ( | |
symbolic_helper._is_tensor_list(data) | |
or symbolic_helper._is_scalar_list(data) | |
): | |
data = g.op("ConcatFromSequence", data, axis_i=0, new_axis_i=1) | |
return g.op("Cast", data, to_i=_type_utils.JitScalarType(dtype).onnx_type()) | |
def as_tensor(g: jit_utils.GraphContext, data, dtype=None, device=None): | |
return tensor(g, data, dtype, device) | |
def zeros(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False): | |
    # NOTE: there is no way to set device, layout and pin_memory in ONNX, so we ignore them
if dtype is None: | |
scalar_type = _type_utils.JitScalarType.FLOAT | |
else: | |
scalar_type = _type_utils.JitScalarType(dtype) | |
sizes_ = symbolic_helper._maybe_get_const(sizes, "is") | |
if isinstance(sizes_, list) and len(sizes_) == 0: | |
sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64)) | |
return g.op( | |
"ConstantOfShape", | |
sizes, | |
value_t=torch.tensor([0], dtype=scalar_type.dtype()), | |
) | |
def zeros_like( | |
g: jit_utils.GraphContext, | |
input, | |
dtype=None, | |
layout=None, | |
device=None, | |
pin_memory=False, | |
memory_format=None, | |
): | |
shape = g.op("Shape", input) | |
if symbolic_helper._is_none(dtype): | |
scalar_type = _type_utils.JitScalarType.from_value( | |
input, _type_utils.JitScalarType.FLOAT | |
) | |
else: | |
scalar_type = _type_utils.JitScalarType(dtype) | |
return g.op( | |
"ConstantOfShape", | |
shape, | |
value_t=torch.tensor([0], dtype=scalar_type.dtype()), | |
) | |
def new_zeros( | |
g: jit_utils.GraphContext, self, sizes, dtype, layout, device, pin_memory=False | |
): | |
self_dtype = symbolic_helper._try_get_scalar_type(self) | |
if symbolic_helper._is_none(dtype) and self_dtype is not None: | |
dtype = self_dtype | |
return zeros(g, sizes, dtype, layout, device, pin_memory) | |
def zero(g: jit_utils.GraphContext, self): | |
self_dtype = symbolic_helper._try_get_scalar_type(self) | |
return zeros_like(g, self, self_dtype) | |
def ones(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False): | |
if dtype is None: | |
scalar_type = _type_utils.JitScalarType.FLOAT | |
else: | |
scalar_type = _type_utils.JitScalarType(dtype) | |
sizes_ = symbolic_helper._maybe_get_const(sizes, "is") | |
if isinstance(sizes_, list) and len(sizes_) == 0: | |
sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64)) | |
return g.op( | |
"ConstantOfShape", | |
sizes, | |
value_t=torch.tensor([1], dtype=scalar_type.dtype()), | |
) | |
def ones_like( | |
g: jit_utils.GraphContext, | |
input, | |
dtype=None, | |
layout=None, | |
device=None, | |
pin_memory=False, | |
memory_format=None, | |
): | |
shape = g.op("Shape", input) | |
if symbolic_helper._is_none(dtype): | |
scalar_type = _type_utils.JitScalarType.from_value( | |
input, _type_utils.JitScalarType.FLOAT | |
) | |
else: | |
scalar_type = _type_utils.JitScalarType(dtype) | |
return g.op( | |
"ConstantOfShape", | |
shape, | |
value_t=torch.tensor([1], dtype=scalar_type.dtype()), | |
) | |
def new_ones( | |
g: jit_utils.GraphContext, self, sizes, dtype, layout, device, pin_memory=False | |
): | |
self_dtype = symbolic_helper._try_get_scalar_type(self) | |
if symbolic_helper._is_none(dtype) and self_dtype is not None: | |
dtype = self_dtype | |
return ones(g, sizes, dtype, layout, device, pin_memory) | |
def full( | |
g: jit_utils.GraphContext, sizes, value, dtype, layout, device, pin_memory=False | |
): | |
const_value = symbolic_helper._maybe_get_const(value, "t") | |
if symbolic_helper._is_value(const_value): | |
dtype = _type_utils.JitScalarType.FLOAT if dtype is None else dtype | |
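        # The fill value is not a compile-time constant: build zeros of the
        # requested shape and add the runtime scalar, relying on broadcasting.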
tmp = zeros(g, sizes, dtype, layout, device) | |
return add(g, tmp, value, g.op("Constant", value_t=torch.tensor(1))) | |
else: | |
dtype = symbolic_helper._get_const(dtype, "i", "dtype") | |
if dtype is None: | |
scalar_type = _type_utils.JitScalarType.FLOAT | |
else: | |
scalar_type = _type_utils.JitScalarType(dtype) | |
sizes_ = symbolic_helper._maybe_get_const(sizes, "is") | |
if isinstance(sizes_, list) and len(sizes_) == 0: | |
sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64)) | |
return g.op( | |
"ConstantOfShape", | |
sizes, | |
value_t=const_value.view(1).to(scalar_type.dtype()), | |
) | |
def full_like( | |
g: jit_utils.GraphContext, | |
input, | |
fill_value, | |
dtype=None, | |
layout=None, | |
device=None, | |
pin_memory=False, | |
memory_format=None, | |
): | |
fill_value = symbolic_helper._maybe_get_const(fill_value, "f") | |
dtype = symbolic_helper._get_const(dtype, "i", "dtype") | |
if dtype is None: | |
scalar_type = _type_utils.JitScalarType.from_value( | |
input, _type_utils.JitScalarType.FLOAT | |
) | |
else: | |
scalar_type = _type_utils.JitScalarType(dtype) | |
if symbolic_helper._is_value(fill_value): | |
tmp = zeros_like(g, input, dtype, layout, device) | |
fill_value = g.op("Cast", fill_value, to_i=scalar_type.onnx_type()) | |
return add(g, tmp, fill_value, g.op("Constant", value_t=torch.tensor(1))) | |
else: | |
shape = g.op("Shape", input) | |
return g.op( | |
"ConstantOfShape", | |
shape, | |
value_t=torch.tensor([fill_value], dtype=scalar_type.dtype()), | |
) | |
def new_full( | |
g: jit_utils.GraphContext, | |
self, | |
size, | |
fill_value, | |
dtype, | |
layout, | |
device, | |
pin_memory=False, | |
): | |
self_dtype = symbolic_helper._try_get_scalar_type(self) | |
if symbolic_helper._is_none(dtype) and self_dtype is not None: | |
dtype = self_dtype | |
return full(g, size, fill_value, dtype, layout, device, pin_memory) | |
def eye(g: jit_utils.GraphContext, *args): | |
if len(args) == 5: | |
# aten::eye(n, dtype, layout, device, pin_memory) | |
n, dtype, layout, device, pin_memory = args | |
dim_size = symbolic_helper._unsqueeze_helper(g, n, [0]) | |
shape = g.op("Concat", dim_size, dim_size, axis_i=0) | |
tensor = zeros(g, shape, dtype, layout, device) | |
return g.op("EyeLike", tensor) | |
if len(args) == 6: | |
# aten::eye(n, m, dtype, layout, device, pin_memory) | |
n, m, dtype, layout, device, pin_memory = args | |
shape = g.op( | |
"Concat", | |
symbolic_helper._unsqueeze_helper(g, n, [0]), | |
symbolic_helper._unsqueeze_helper(g, m, [0]), | |
axis_i=0, | |
) | |
tensor = zeros(g, shape, dtype, layout, device) | |
return g.op("EyeLike", tensor) | |
return symbolic_helper._unimplemented("aten::eye", f"with {len(args)} arguments") | |
def slice(g: jit_utils.GraphContext, self, *args): | |
if len(args) == 4: | |
# aten::slice(Tensor self, int dim, int start, int end, int step) -> Tensor | |
dim, start, end, step = args | |
step = symbolic_helper._parse_arg(step, "i") | |
if step != 1: | |
raise errors.SymbolicValueError("step!=1 is currently not supported", self) | |
is_start_none = start.node().kind() == "prim::Constant" and isinstance( | |
start.type(), _C.NoneType | |
) | |
is_end_none = end.node().kind() == "prim::Constant" and isinstance( | |
end.type(), _C.NoneType | |
) | |
is_start_onnx_const = start.node().kind() == "onnx::Constant" | |
is_end_onnx_const = end.node().kind() == "onnx::Constant" | |
if ( | |
((not is_start_none) and (not is_start_onnx_const)) | |
or ((not is_end_none) and (not is_end_onnx_const)) | |
or dim.node().kind() != "onnx::Constant" | |
): | |
if GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX: | |
raise errors.SymbolicValueError( | |
"Unsupported: ONNX export of Slice with dynamic inputs. DynamicSlice " | |
"is a deprecated experimental op. Please use statically allocated " | |
"variables or export to a higher opset version.", | |
self, | |
) | |
else: | |
start_unsqueezed = symbolic_helper._unsqueeze_helper(g, start, [0]) | |
end_unsqueezed = symbolic_helper._unsqueeze_helper(g, end, [0]) | |
dim_unsqueezed = symbolic_helper._unsqueeze_helper(g, dim, [0]) | |
return g.op( | |
"DynamicSlice", | |
self, | |
start_unsqueezed, | |
end_unsqueezed, | |
dim_unsqueezed, | |
) | |
else: | |
start = 0 if is_start_none else symbolic_helper._parse_arg(start, "i") | |
end = ( | |
_constants.INT64_MAX | |
if is_end_none | |
else symbolic_helper._parse_arg(end, "i") | |
) | |
dim = symbolic_helper._parse_arg(dim, "i") | |
return symbolic_helper._slice_helper( | |
g, self, axes=[dim], starts=[start], ends=[end] | |
) | |
elif len(args) == 3: | |
# aten::slice(t[] l, int start, int end, int step) -> t[] | |
start, end, step = args | |
dim = 0 | |
is_start_none = start.node().kind() == "prim::Constant" and isinstance( | |
start.type(), _C.NoneType | |
) | |
is_end_none = end.node().kind() == "prim::Constant" and isinstance( | |
end.type(), _C.NoneType | |
) | |
start = 0 if is_start_none else symbolic_helper._parse_arg(start, "i") | |
end = ( | |
_constants.INT64_MAX | |
if is_end_none | |
else symbolic_helper._parse_arg(end, "i") | |
) | |
return symbolic_helper._slice_helper( | |
g, self, axes=[dim], starts=[start], ends=[end] | |
) | |
return symbolic_helper._unimplemented("aten::slice", f"with {len(args)} arguments") | |
def hardtanh(g: jit_utils.GraphContext, self: _C.Value, min_val: float, max_val: float): | |
return _op_with_optional_float_cast( | |
g, "Clip", self, min_f=min_val, max_f=max_val, opset_before=12 | |
) | |
def hardswish(g: jit_utils.GraphContext, self): | |
hs = hardsigmoid(g, self) | |
return g.op("Mul", self, hs) | |
# Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qhardsigmoid.cpp | |
def hardsigmoid(g: jit_utils.GraphContext, self): | |
# Set alpha_f to 1 / 6 to make op equivalent to PyTorch's definition of Hardsigmoid. | |
# See https://pytorch.org/docs/stable/generated/torch.nn.Hardsigmoid.html | |
return g.op("HardSigmoid", self, alpha_f=1 / 6) | |
def tanhshrink(g: jit_utils.GraphContext, self): | |
return g.op("Sub", self, tanh(g, self)) | |
def hardshrink(g: jit_utils.GraphContext, self, lambd): | |
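    # hardshrink(x) = x if x > lambd or x < -lambd, else 0; expressed as a Where
    # over the combined comparison mask.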
scalar_type = _type_utils.JitScalarType.from_value( | |
self, _type_utils.JitScalarType.FLOAT | |
) | |
lambd_op = g.op( | |
"Constant", | |
value_t=torch.tensor(lambd, dtype=scalar_type.dtype()), | |
) | |
cond = logical_or(g, gt(g, self, lambd_op), lt(g, self, neg(g, lambd_op))) | |
return g.op( | |
"Where", | |
cond, | |
self, | |
g.op( | |
"Constant", | |
value_t=torch.tensor(0, dtype=scalar_type.dtype()), | |
), | |
) | |
def softshrink(g: jit_utils.GraphContext, self, lambd): | |
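    # softshrink(x) = x - lambd if x > lambd, x + lambd if x < -lambd, else 0.
    # Built from two Where branches whose disjoint non-zero regions are summed.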
scalar_type = _type_utils.JitScalarType.from_value( | |
self, _type_utils.JitScalarType.FLOAT | |
) | |
lambd_op = g.op( | |
"Constant", | |
value_t=torch.tensor(lambd, dtype=scalar_type.dtype()), | |
) | |
gt_cond = gt(g, self, lambd_op) | |
gt_out = g.op( | |
"Where", | |
gt_cond, | |
sub(g, self, lambd_op), | |
g.op( | |
"Constant", | |
value_t=torch.tensor(0, dtype=scalar_type.dtype()), | |
), | |
) | |
lt_cond = lt(g, self, neg(g, lambd_op)) | |
lt_out = g.op( | |
"Where", | |
lt_cond, | |
add(g, self, lambd_op), | |
g.op( | |
"Constant", | |
value_t=torch.tensor(0, dtype=scalar_type.dtype()), | |
), | |
) | |
return add(g, gt_out, lt_out) | |
def alias(g: jit_utils.GraphContext, self): | |
return self | |
def unsqueeze(g: jit_utils.GraphContext, self, dim): | |
# Handle negative dim | |
if dim < 0: | |
rank = symbolic_helper._get_tensor_rank(self) | |
if rank is not None: | |
warnings.warn( | |
"ONNX export unsqueeze with negative axis " | |
+ str(dim) | |
+ " might cause the onnx model to be incorrect. " | |
+ "Negative axis is not supported in ONNX. " | |
+ "Axis is converted to " | |
+ str(dim + rank + 1) | |
+ " based on input shape at export time. " | |
+ "Passing an tensor of different rank in execution will be incorrect." | |
) | |
dim = dim + rank + 1 | |
else: | |
return symbolic_helper._unimplemented( | |
"unsqueeze", "negative axis with unknown input rank", self | |
) | |
return symbolic_helper._unsqueeze_helper(g, self, axes_i=[dim]) | |
# TODO(justinchuby): Support multiple quantized args in output | |
def sort(g: jit_utils.GraphContext, self, dim, descending, out=None):
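    # Opset 9 has no dedicated Sort op; a full-length TopK along `dim` yields the
    # values in descending order together with their indices (the `descending`
    # flag is not inspected here).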
if out is not None: | |
symbolic_helper._unimplemented( | |
"Sort", "Out parameter is not supported for sort", self | |
) | |
self_sizes = symbolic_helper._get_tensor_sizes(self) | |
try: | |
dim_size = self_sizes[dim] | |
except Exception: | |
# FIXME(justinchuby): Avoid catching Exception. | |
# Catch a more specific exception instead. | |
dim_size = None | |
if dim_size is None: | |
return symbolic_helper._unimplemented("Sort", "input size not accessible", self) | |
return g.op("TopK", self, k_i=dim_size, axis_i=dim, outputs=2) | |
def numel(g: jit_utils.GraphContext, self): | |
shape = g.op("Shape", self) | |
return g.op("ReduceProd", shape, keepdims_i=0) | |
# TODO(justinchuby): Support multiple quantized args in output | |
def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None): | |
if out is not None: | |
symbolic_helper._unimplemented( | |
"TopK", "Out parameter is not supported for topk", self | |
) | |
if not largest: | |
symbolic_helper._unimplemented("TopK", "Ascending TopK is not supported", self) | |
return g.op("TopK", self, k_i=k, axis_i=dim, outputs=2) | |
def convert_element_type(g: jit_utils.GraphContext, self, *args): | |
dtype = symbolic_helper._get_const(args[0], "i", "dtype") | |
return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type()) | |
def to(g: jit_utils.GraphContext, self, *args): | |
def is_aten_to_device_only(args): | |
if len(args) == 4: | |
# aten::to(Tensor, Device, bool, bool, memory_format) | |
return ( | |
args[0].node().kind() == "prim::device" | |
or args[0].type().isSubtypeOf(_C.ListType.ofInts()) | |
or isinstance(args[0].type(), _C.DeviceObjType) | |
) | |
elif len(args) == 5: | |
# aten::to(Tensor, Device, ScalarType, bool, bool, memory_format) | |
            # When dtype is None, this is an aten::to(device) call
dtype = symbolic_helper._get_const(args[1], "i", "dtype") | |
return dtype is None | |
elif len(args) in (6, 7): | |
# aten::to(Tensor, ScalarType, Layout, Device, bool, bool, memory_format) -> Tensor | |
# aten::to(Tensor, ScalarType, Layout, Device, bool, bool, bool, memory_format) -> Tensor | |
            # When dtype is None, this is an aten::to(device) call
dtype = symbolic_helper._get_const(args[0], "i", "dtype") | |
return dtype is None | |
return False | |
# ONNX doesn't have a concept of a device, so we ignore device-only casts | |
if is_aten_to_device_only(args): | |
return self | |
if len(args) == 4: | |
# TestONNXRuntime::test_ones_bool shows args[0] of aten::to() can be onnx::Constant[value=<Tensor>]() | |
# In this case, the constant value is a tensor not int, | |
# so symbolic_helper._maybe_get_const(args[0], 'i') would not work. | |
dtype = args[0] | |
if ( | |
symbolic_helper._is_value(args[0]) | |
and args[0].node().kind() == "onnx::Constant" | |
): | |
tval = symbolic_helper._node_get(args[0].node(), "value") | |
if isinstance(tval, torch.Tensor): | |
if len(tval.shape) == 0: | |
tval = tval.item() | |
dtype = int(tval) | |
else: | |
dtype = tval | |
if symbolic_helper._is_value(dtype) or isinstance(dtype, torch.Tensor): | |
# aten::to(Tensor, Tensor, bool, bool, memory_format) | |
dtype = _type_utils.JitScalarType.from_value(args[0]) | |
return g.op( | |
"Cast", | |
self, | |
to_i=dtype.onnx_type(), | |
) | |
else: | |
# aten::to(Tensor, ScalarType, bool, bool, memory_format) | |
# memory_format is ignored | |
return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type()) | |
elif len(args) == 5: | |
# aten::to(Tensor, Device, ScalarType, bool, bool, memory_format) | |
dtype = symbolic_helper._get_const(args[1], "i", "dtype") | |
# memory_format is ignored | |
return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type()) | |
elif len(args) == 6: | |
# aten::to(Tensor, ScalarType, Layout, Device, bool, bool, memory_format) -> Tensor | |
dtype = symbolic_helper._get_const(args[0], "i", "dtype") | |
# Layout, device and memory_format are ignored | |
return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type()) | |
elif len(args) == 7: | |
# aten::to(Tensor, ScalarType, Layout, Device, bool, bool, bool, memory_format) -> Tensor | |
dtype = symbolic_helper._get_const(args[0], "i", "dtype") | |
# Layout, device and memory_format are ignored | |
return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type()) | |
return symbolic_helper._onnx_unsupported("Unknown aten::to signature", self) | |
def repeat(g: jit_utils.GraphContext, self, repeats): | |
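    # ONNX Tile requires len(repeats) to equal the input rank, so the input is
    # first broadcast via Expand against an all-ones shape vector of the same
    # length as `repeats`, padding its rank up to len(repeats).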
dtype = _type_utils.JitScalarType.INT64 | |
shape_ = ones_like(g, repeats, dtype) | |
self = g.op("Expand", self, shape_) | |
return g.op("Tile", self, repeats) | |
def repeat_interleave( | |
g: jit_utils.GraphContext, self, repeats, dim=None, output_size=None | |
): | |
repeats_dim = symbolic_helper._get_tensor_rank(repeats) | |
repeats_sizes = symbolic_helper._get_tensor_sizes(repeats) | |
input_sizes = symbolic_helper._get_tensor_sizes(self) | |
if repeats_dim is None: | |
raise errors.SymbolicValueError( | |
"Unsupported: ONNX export of repeat_interleave for unknown repeats rank.", | |
self, | |
) | |
if repeats_sizes is None: | |
raise errors.SymbolicValueError( | |
"Unsupported: ONNX export of repeat_interleave for unknown repeats size.", | |
self, | |
) | |
if input_sizes is None: | |
raise errors.SymbolicValueError( | |
"Unsupported: ONNX export of repeat_interleave for unknown input size.", | |
self, | |
) | |
# if dim is None flatten | |
# By default, use the flattened input array, and return a flat output array | |
if symbolic_helper._is_none(dim): | |
self = symbolic_helper._reshape_helper( | |
g, self, g.op("Constant", value_t=torch.tensor([-1])) | |
) | |
dim = torch.tensor(0, dtype=torch.int64) | |
else: | |
dim = symbolic_helper._maybe_get_scalar(dim) | |
# Handle cases where dim is negative | |
if dim < 0: | |
dim += len(input_sizes) | |
input_sizes_temp = input_sizes.copy() | |
for idx, input_size in enumerate(input_sizes): | |
if input_size is None: | |
input_sizes[idx], input_sizes_temp[idx] = 0, -1 | |
# Cases where repeats is an int or single value tensor | |
if repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1): | |
if input_sizes[dim] == 0: | |
return symbolic_helper._onnx_opset_unsupported_detailed( | |
"repeat_interleave", | |
9, | |
13, | |
"Unsupported along dimension with unknown input size", | |
self, | |
) | |
return symbolic_helper._repeat_interleave_single_value_repeat_helper( | |
g, self, repeats, dim | |
) | |
# Cases where repeats is a 1 dim Tensor | |
elif repeats_dim == 1: | |
if input_sizes[dim] == 0: | |
return symbolic_helper._onnx_opset_unsupported_detailed( | |
"repeat_interleave", | |
9, | |
13, | |
"Unsupported along dimension with unknown input size", | |
self, | |
) | |
if repeats_sizes[0] is None: | |
return symbolic_helper._onnx_opset_unsupported_detailed( | |
"repeat_interleave", | |
9, | |
13, | |
"Unsupported for cases with dynamic repeats", | |
self, | |
) | |
assert ( | |
repeats_sizes[0] == input_sizes[dim] | |
), "repeats must have the same size as input along dim" | |
reps = repeats_sizes[0] | |
else: | |
raise errors.SymbolicValueError("repeats must be 0-dim or 1-dim tensor", self) | |
    final_splits = []
r_splits = symbolic_helper._repeat_interleave_split_helper(g, repeats, reps, 0) | |
i_splits = symbolic_helper._repeat_interleave_split_helper(g, self, reps, dim) | |
input_sizes[dim], input_sizes_temp[dim] = -1, 1 | |
for idx, r_split in enumerate(r_splits): | |
i_split = unsqueeze(g, i_splits[idx], dim + 1) | |
r_concat = [ | |
g.op("Constant", value_t=torch.LongTensor(input_sizes_temp[: dim + 1])), | |
r_split, | |
g.op("Constant", value_t=torch.LongTensor(input_sizes_temp[dim + 1 :])), | |
] | |
r_concat = g.op("Concat", *r_concat, axis_i=0) | |
i_split = expand(g, i_split, r_concat, None) | |
i_split = symbolic_helper._reshape_helper( | |
g, | |
i_split, | |
g.op("Constant", value_t=torch.LongTensor(input_sizes)), | |
allowzero=0, | |
) | |
final_splits.append(i_split) | |
return g.op("Concat", *final_splits, axis_i=dim) | |
def pixel_shuffle(g: jit_utils.GraphContext, self, upscale_factor): | |
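    # torch.pixel_shuffle maps (N, C * r * r, H, W) -> (N, C, H * r, W * r) with r = upscale_factor;
    # the reshape/transpose sequence below emulates that rearrangement in ONNX.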
dims = symbolic_helper._get_tensor_sizes(self) | |
if len(dims) != 4: | |
return symbolic_helper._unimplemented( | |
"pixel_shuffle", "only support 4d input", self | |
) | |
if any(i is None for i in dims[1:]): | |
after_view = symbolic_helper._reshape_helper( | |
g, | |
symbolic_helper._unsqueeze_helper(g, self, [2, 3]), | |
g.op( | |
"Constant", | |
value_t=torch.tensor([0, -1, upscale_factor, upscale_factor, 0, 0]), | |
), | |
allowzero=0, | |
) | |
after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3]) | |
# For dynamic input shapes, two reshapes are performed | |
reshape_h = symbolic_helper._reshape_helper( | |
g, | |
after_transpose, | |
g.op("Constant", value_t=torch.tensor([0, 0, -1, 1, 0, 0])), | |
allowzero=0, | |
) | |
reshape_w = symbolic_helper._reshape_helper( | |
g, | |
reshape_h, | |
g.op("Constant", value_t=torch.tensor([0, 0, 0, 0, -1, 1])), | |
allowzero=0, | |
) | |
return symbolic_helper._squeeze_helper(g, reshape_w, [3, 5]) | |
else: | |
output_channel = dims[1] // upscale_factor // upscale_factor | |
after_view = symbolic_helper._reshape_helper( | |
g, | |
self, | |
g.op( | |
"Constant", | |
value_t=torch.tensor( | |
[ | |
-1, | |
output_channel, | |
upscale_factor, | |
upscale_factor, | |
dims[2], | |
dims[3], | |
] | |
), | |
), | |
allowzero=0, | |
) | |
after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3]) | |
return symbolic_helper._reshape_helper( | |
g, | |
after_transpose, | |
g.op( | |
"Constant", | |
value_t=torch.tensor( | |
[ | |
-1, | |
output_channel, | |
dims[2] * upscale_factor, | |
dims[3] * upscale_factor, | |
] | |
), | |
), | |
allowzero=0, | |
) | |
def pixel_unshuffle(g: jit_utils.GraphContext, self, downscale_factor): | |
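    # torch.pixel_unshuffle is the inverse of pixel_shuffle:
    # (N, C, H, W) -> (N, C * r * r, H // r, W // r) with r = downscale_factor.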
dims = symbolic_helper._get_tensor_sizes(self) | |
if len(dims) != 4: | |
return symbolic_helper._unimplemented( | |
"pixel_shuffle", "only support 4d input", self | |
) | |
if any(i is None for i in dims[1:]): | |
# For dynamic input shapes, two reshapes are performed | |
reshape_h = symbolic_helper._reshape_helper( | |
g, | |
symbolic_helper._unsqueeze_helper(g, self, [3]), | |
g.op("Constant", value_t=torch.tensor([0, 0, -1, downscale_factor, 0])), | |
allowzero=0, | |
) | |
reshape_w = symbolic_helper._reshape_helper( | |
g, | |
reshape_h, | |
g.op("Constant", value_t=torch.tensor([0, 0, 0, 0, -1, downscale_factor])), | |
allowzero=0, | |
) | |
after_transpose = g.op("Transpose", reshape_w, perm_i=[0, 1, 3, 5, 2, 4]) | |
final_reshape = symbolic_helper._reshape_helper( | |
g, | |
after_transpose, | |
g.op("Constant", value_t=torch.tensor([0, -1, 1, 1, 0, 0])), | |
allowzero=0, | |
) | |
return symbolic_helper._squeeze_helper(g, final_reshape, [2, 3]) | |
else: | |
output_channel = dims[1] * downscale_factor * downscale_factor | |
after_view = symbolic_helper._reshape_helper( | |
g, | |
self, | |
g.op( | |
"Constant", | |
value_t=torch.tensor( | |
[ | |
-1, | |
dims[1], | |
dims[2] // downscale_factor, | |
downscale_factor, | |
dims[3] // downscale_factor, | |
downscale_factor, | |
] | |
), | |
), | |
allowzero=0, | |
) | |
after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 3, 5, 2, 4]) | |
return symbolic_helper._reshape_helper( | |
g, | |
after_transpose, | |
g.op( | |
"Constant", | |
value_t=torch.tensor( | |
[ | |
-1, | |
output_channel, | |
dims[2] // downscale_factor, | |
dims[3] // downscale_factor, | |
] | |
), | |
), | |
allowzero=0, | |
) | |
def _generic_rnn( | |
g: jit_utils.GraphContext, | |
variant, | |
input, | |
initial_states, | |
all_weights, | |
has_biases, | |
num_layers, | |
dropout, | |
train, | |
bidirectional, | |
batch_first=None, | |
batch_sizes=None, | |
): | |
    warnings.warn(
        "Exporting a model to ONNX with a batch_size other than 1, "
        + "with a variable-length sequence, and with "
        + variant
        + " can cause an error "
        + "when running the ONNX model with a different batch size. "
        + "Make sure to save the model with a batch size of 1, "
        + "or define the initial states (h0/c0) as inputs of the model. "
    )
onnxActivations = [ | |
"Relu", | |
"Tanh", | |
"Sigmoid", | |
"Affine", | |
"LeakyRelu", | |
"ThresholdedRelu", | |
"ScaledTanh", | |
"HardSigmoid", | |
"Elu", | |
"Softsign", | |
"Softplus", | |
] | |
variantToOnnxActivationMap = dict( | |
zip([act_fun.lower() for act_fun in onnxActivations], onnxActivations) | |
) | |
weights_per_layer = 4 if has_biases else 2 | |
    # This means that projections are used inside the LSTM, so we need to tell the user that they are not supported.
if variant == "LSTM" and len(all_weights) != num_layers * weights_per_layer * ( | |
1 + bidirectional | |
): | |
return symbolic_helper._unimplemented("LSTM", "LSTMs with projections", input) | |
assert len(all_weights) == num_layers * weights_per_layer * (1 + bidirectional) | |
layer_weights = [ | |
all_weights[i : i + weights_per_layer] | |
for i in range(0, len(all_weights), weights_per_layer) | |
] | |
if batch_first: | |
# batch, seq, feat -> seq, batch, feat | |
input = g.op("Transpose", input, perm_i=[1, 0, 2]) | |
if dropout and train: | |
return symbolic_helper._unimplemented( | |
"RNN/GRU/LSTM", "dropout in training mode", input | |
) | |
if variant.startswith("RNN"): | |
nonlinearity = variantToOnnxActivationMap[variant[4:].lower()] | |
variant = "RNN" | |
w_hh = all_weights[1] | |
hidden_size = symbolic_helper._get_tensor_dim_size(w_hh, 1) | |
if hidden_size is None: | |
return symbolic_helper._unimplemented( | |
"RNN/GRU/LSTM", "unknown hidden size", input | |
) | |
unidirectional = not bidirectional | |
prev_output = input | |
h_outs = [] | |
if variant == "RNN" or variant == "GRU": | |
h0 = initial_states | |
elif variant == "LSTM": | |
h0, c0 = initial_states | |
c_outs = [] | |
sequence_lens = unused(g) if batch_sizes is None else batch_sizes | |
if variant == "GRU": | |
# pytorch is reset, input, hidden | |
# onnx is input, reset, hidden | |
reform_permutation = [(1, 2), (0, 1), (2, 3)] | |
elif variant == "LSTM": | |
# pytorch is input, forget, cell, output. | |
# onnx is input, output, forget, cell. | |
reform_permutation = [(0, 1), (3, 4), (1, 3)] | |
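    # Illustration with hidden_size = n: each stacked weight/bias has 3n (GRU) or 4n (LSTM) rows;
    # reform_weights below concatenates the row blocks selected by these intervals,
    # i.e. rows [n, 2n) + [0, n) + [2n, 3n) for GRU and rows [0, n) + [3n, 4n) + [n, 3n) for LSTM.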
def reform_weights(g, w, n, intervals): | |
slices = [ | |
symbolic_helper._slice_helper(g, w, axes=[0], starts=[x * n], ends=[y * n]) | |
for x, y in intervals | |
] | |
return g.op("Concat", *slices, axis_i=0) | |
def transform_weights_no_bias(layer_index): | |
weights = layer_weights[layer_index] | |
if variant == "RNN": | |
weight_ih, weight_hh = weights | |
elif variant == "GRU" or variant == "LSTM": | |
weight_ih, weight_hh = ( | |
reform_weights(g, w, hidden_size, reform_permutation) for w in weights | |
) | |
return tuple( | |
symbolic_helper._unsqueeze_helper(g, x, [0]) for x in (weight_ih, weight_hh) # type: ignore[possibly-undefined] | |
) | |
def transform_weights(layer_index): | |
weights = layer_weights[layer_index] | |
if variant == "RNN": | |
weight_ih, weight_hh, bias_ih, bias_hh = weights | |
elif variant == "GRU" or variant == "LSTM": | |
weight_ih, weight_hh, bias_ih, bias_hh = ( | |
reform_weights(g, w, hidden_size, reform_permutation) for w in weights | |
) | |
bias_concat = g.op("Concat", bias_ih, bias_hh, axis_i=0) # type: ignore[possibly-undefined] | |
return tuple( | |
symbolic_helper._unsqueeze_helper(g, x, [0]) | |
for x in (weight_ih, weight_hh, bias_concat) # type: ignore[possibly-undefined] | |
) | |
def retrieve_state(x, start, end): | |
return ( | |
x | |
if num_layers == 1 | |
else symbolic_helper._slice_helper( | |
g, x, axes=[0], starts=[start], ends=[end] | |
) | |
) | |
for i in range(num_layers): | |
if unidirectional: | |
if weights_per_layer == 4: | |
weight_ih, weight_hh, bias_concat = transform_weights(i) | |
else: | |
weight_ih, weight_hh = transform_weights_no_bias(i) | |
bias_concat = unused(g) | |
state_indices = i, i + 1 | |
else: | |
if weights_per_layer == 4: | |
weight_ih_f, weight_hh_f, bias_f = transform_weights(2 * i) | |
weight_ih_b, weight_hh_b, bias_b = transform_weights(2 * i + 1) | |
bias_concat = g.op("Concat", bias_f, bias_b, axis_i=0) | |
else: | |
weight_ih_f, weight_hh_f = transform_weights_no_bias(2 * i) | |
weight_ih_b, weight_hh_b = transform_weights_no_bias(2 * i + 1) | |
bias_concat = unused(g) | |
weight_ih = g.op("Concat", weight_ih_f, weight_ih_b, axis_i=0) | |
weight_hh = g.op("Concat", weight_hh_f, weight_hh_b, axis_i=0) | |
state_indices = 2 * i, 2 * i + 2 | |
inputs = [prev_output, weight_ih, weight_hh, bias_concat, sequence_lens] | |
inputs.append(retrieve_state(h0, *state_indices)) # type: ignore[possibly-undefined] | |
if variant == "LSTM": | |
inputs.append(retrieve_state(c0, *state_indices)) # type: ignore[possibly-undefined] | |
extra_kwargs = {} if unidirectional else {"direction_s": "bidirectional"} | |
if variant == "RNN": | |
if bidirectional: | |
activation = [nonlinearity, nonlinearity] # type: ignore[possibly-undefined] | |
else: | |
activation = [nonlinearity] # type: ignore[possibly-undefined] | |
prev_output, h_out = g.op( | |
"RNN", | |
*inputs, | |
outputs=2, | |
hidden_size_i=hidden_size, | |
activations_s=activation, | |
**extra_kwargs, | |
) | |
elif variant == "GRU": | |
prev_output, h_out = g.op( | |
"GRU", | |
*inputs, | |
outputs=2, | |
hidden_size_i=hidden_size, | |
linear_before_reset_i=1, | |
**extra_kwargs, | |
) | |
elif variant == "LSTM": | |
prev_output, h_out, c_out = g.op( | |
"LSTM", *inputs, outputs=3, hidden_size_i=hidden_size, **extra_kwargs | |
) | |
if bidirectional: | |
# The ONNX RNN/GRU/LSTM produce an output of dimensions | |
# seq_len, num_directions, batch, hidden_size | |
# We have to convert to match pytorch's expected | |
# seq_len, batch, num_directions * hidden_size | |
# by first moving num_directions before hidden_size with | |
# Transpose, and then combining it with hidden_size | |
# with Reshape. | |
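            # e.g. (seq_len, 2, batch, hidden) --Transpose(0, 2, 1, 3)--> (seq_len, batch, 2, hidden)
            #      --Reshape(0, 0, -1)--> (seq_len, batch, 2 * hidden)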
prev_output = g.op("Transpose", prev_output, perm_i=[0, 2, 1, 3]) | |
prev_output = symbolic_helper._reshape_helper( | |
g, | |
prev_output, | |
g.op("Constant", value_t=torch.LongTensor([0, 0, -1])), | |
allowzero=0, | |
) | |
else: | |
prev_output = symbolic_helper._squeeze_helper(g, prev_output, [1]) | |
h_outs.append(h_out) # type: ignore[possibly-undefined] | |
if variant == "LSTM": | |
c_outs.append(c_out) # type: ignore[possibly-undefined] | |
if batch_first: | |
# seq, batch, num_directions * hidden_size -> batch, seq, num_directions * hidden_size | |
prev_output = g.op("Transpose", prev_output, perm_i=[1, 0, 2]) | |
h_outs = h_out if num_layers == 1 else g.op("Concat", *h_outs, axis_i=0) # type: ignore[possibly-undefined] | |
if variant == "RNN" or variant == "GRU": | |
return prev_output, h_outs | |
elif variant == "LSTM": | |
c_outs = c_out if num_layers == 1 else g.op("Concat", *c_outs, axis_i=0) # type: ignore[possibly-undefined] | |
return prev_output, h_outs, c_outs | |
def _lstm_full( | |
g: jit_utils.GraphContext, | |
input, | |
hidden_v, | |
weight_v, | |
has_biases, | |
num_layers, | |
dropout, | |
train, | |
bidirectional, | |
batch_first, | |
): | |
hidden, weight = symbolic_helper._unpack_list( | |
hidden_v | |
), symbolic_helper._unpack_list(weight_v) | |
return _generic_rnn( | |
g, | |
"LSTM", | |
input, | |
hidden, | |
weight, | |
has_biases, | |
num_layers, | |
dropout, | |
train, | |
bidirectional, | |
batch_first, | |
) | |
def _lstm_packed( | |
g: jit_utils.GraphContext, | |
input, | |
batch_sizes, | |
hidden_v, | |
weight_v, | |
has_biases, | |
num_layers, | |
dropout, | |
train, | |
bidirectional, | |
): | |
hidden, weight = symbolic_helper._unpack_list( | |
hidden_v | |
), symbolic_helper._unpack_list(weight_v) | |
return _generic_rnn( | |
g, | |
"LSTM", | |
input, | |
hidden, | |
weight, | |
has_biases, | |
num_layers, | |
dropout, | |
train, | |
bidirectional, | |
batch_sizes=batch_sizes, | |
) | |
def lstm(g: jit_utils.GraphContext, *args): | |
if symbolic_helper._is_tensor_list(args[3]): | |
return _lstm_packed(g, *args) | |
else: | |
return _lstm_full(g, *args) | |
def lstm_cell(g: jit_utils.GraphContext, self, hidden, w_ih, w_hh, b_ih, b_hh): | |
input = symbolic_helper._unsqueeze_helper(g, self, [0]) | |
hidden = symbolic_helper._unpack_list(hidden) | |
hidden = [symbolic_helper._unsqueeze_helper(g, x, [0]) for x in hidden] | |
weight = ( | |
(w_ih, w_hh, b_ih, b_hh) if symbolic_helper._is_tensor(b_ih) else (w_ih, w_hh) | |
) | |
    has_biases = symbolic_helper._is_tensor(b_ih)
_, h_outs, c_outs = _generic_rnn( | |
g, | |
"LSTM", | |
input, | |
hidden, | |
weight, | |
has_biases, | |
num_layers=1, | |
dropout=0, | |
train=0, | |
bidirectional=False, | |
batch_first=False, | |
) | |
return symbolic_helper._squeeze_helper( | |
g, h_outs, [0] | |
), symbolic_helper._squeeze_helper(g, c_outs, [0]) | |
def _one_hidden_rnn(kind: str): | |
def _rnn_full( | |
g, | |
input, | |
hidden, | |
weight_v, | |
has_biases, | |
num_layers, | |
dropout, | |
train, | |
bidirectional, | |
batch_first, | |
): | |
weight = symbolic_helper._unpack_list(weight_v) | |
return _generic_rnn( | |
g, | |
kind, | |
input, | |
hidden, | |
weight, | |
has_biases, | |
num_layers, | |
dropout, | |
train, | |
bidirectional, | |
batch_first, | |
) | |
def _rnn_packed( | |
g, | |
input, | |
batch_sizes, | |
hidden, | |
weight_v, | |
has_biases, | |
num_layers, | |
dropout, | |
train, | |
bidirectional, | |
): | |
weight = symbolic_helper._unpack_list(weight_v) | |
return _generic_rnn( | |
g, | |
kind, | |
input, | |
hidden, | |
weight, | |
has_biases, | |
num_layers, | |
dropout, | |
train, | |
bidirectional, | |
batch_sizes=batch_sizes, | |
) | |
def symbolic(g, *args): | |
if symbolic_helper._is_tensor_list(args[3]): | |
return _rnn_packed(g, *args) | |
else: | |
return _rnn_full(g, *args) | |
return symbolic | |
def _dim_arange(g: jit_utils.GraphContext, like, dim): | |
like_shape = g.op("Shape", like) | |
stop = g.op( | |
"Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0 | |
) | |
if symbolic_helper.is_caffe2_aten_fallback(): | |
return g.op("_caffe2::Range", stop) | |
else: | |
# aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) | |
return arange(g, stop, 4, None, None, None) | |
def detach(g: jit_utils.GraphContext, input): | |
# Erase aten::detach nodes because ONNX is inference only | |
return input | |
def contiguous(g: jit_utils.GraphContext, input, memory_format): | |
    if memory_format > 2:  # allowed values are any, preserve, and contiguous_format
raise errors.SymbolicValueError( | |
"onnx memory_format support is not implemented", input | |
) | |
return input | |
def _pack_padded_sequence(g: jit_utils.GraphContext, input, lengths, batch_first): | |
    # Currently there is no PackPadded operator in ONNX. We rely on an
    # optimization pass to remove this later. It is an error if a
    # PackPadded operator cannot be optimized out.
if batch_first: | |
input = g.op("Transpose", input, perm_i=[1, 0, 2]) | |
if not lengths.type().isSubtypeOf(torch._C.TensorType.get()): | |
raise errors.SymbolicValueError( | |
"'lengths' must be a Tensor for ONNX export", input | |
) | |
# We know it's a TensorType so this check is now safe. | |
# It's really only necessary because those operators expand to something that | |
# only works with int32 types in Caffe2... | |
if ( | |
_type_utils.JitScalarType.from_value( | |
lengths, _type_utils.JitScalarType.UNDEFINED | |
) | |
!= _type_utils.JitScalarType.INT | |
): | |
lengths = g.op("Cast", lengths, to_i=_C_onnx.TensorProtoDataType.INT32) | |
return g.op("prim::PackPadded", input, lengths, outputs=2) | |
def _pad_packed_sequence( | |
g: jit_utils.GraphContext, | |
data, | |
batch_sizes, | |
batch_first, | |
padding_value, | |
total_length, | |
): | |
# Ignore total_length as it is not supported in _symbolic_pad_packed_sequence | |
    # It is only useful/used when training with a data_parallel model, so
    # it shouldn't be relevant for ONNX anyway.
data, lengths = g.op("prim::PadPacked", data, batch_sizes, outputs=2) | |
if batch_first: | |
data = g.op("Transpose", data, perm_i=[1, 0, 2]) | |
return data, lengths | |
def randint(g: jit_utils.GraphContext, low, high, shapes, dtype, *options): | |
dtype = symbolic_helper._get_const(dtype, "i", "dtype") | |
low_i = symbolic_helper._get_const(low, "i", "low") | |
high_i = symbolic_helper._get_const(high, "i", "high") | |
if dtype is None: | |
scalar_type = _type_utils.JitScalarType.INT64 | |
else: | |
scalar_type = _type_utils.JitScalarType(dtype) | |
if low_i is None: | |
raise symbolic_helper._onnx_unsupported("randint", low) | |
if high_i is None: | |
raise symbolic_helper._onnx_unsupported("randint", high) | |
shape = symbolic_helper._maybe_get_const(shapes, "is") | |
if symbolic_helper._is_value(shape): | |
shape_const = g.op( | |
"ConstantOfShape", | |
shapes, | |
value_t=torch.tensor([0], dtype=torch.float), | |
) | |
randn = g.op( | |
"RandomUniformLike", | |
shape_const, | |
low_f=low_i, | |
high_f=high_i, | |
) | |
else: | |
randn = g.op( | |
"RandomUniform", | |
shape_i=shape, | |
low_f=low_i, | |
high_f=high_i, | |
) | |
# cast to integer type | |
int_dtype = _type_utils.JitScalarType.INT64 | |
randint = g.op("Cast", randn, to_i=int_dtype.onnx_type()) | |
if int_dtype != scalar_type: | |
randint = g.op("Cast", randint, to_i=scalar_type.onnx_type()) | |
return randint | |
def randint_like(g: jit_utils.GraphContext, self, low, high, dtype, *options): | |
dtype = symbolic_helper._get_const(dtype, "i", "dtype") | |
low_i = symbolic_helper._get_const(low, "i", "low") | |
high_i = symbolic_helper._get_const(high, "i", "high") | |
if dtype is None: | |
scalar_type = _type_utils.JitScalarType.INT64 | |
else: | |
scalar_type = _type_utils.JitScalarType(dtype) | |
if low_i is None: | |
raise symbolic_helper._onnx_unsupported("randint", low) | |
if high_i is None: | |
raise symbolic_helper._onnx_unsupported("randint", high) | |
randn = g.op( | |
"RandomUniformLike", | |
self, | |
low_f=low_i, | |
high_f=high_i, | |
) | |
# cast to integer type | |
int_dtype = _type_utils.JitScalarType.INT64 | |
randint = g.op("Cast", randn, to_i=int_dtype.onnx_type()) | |
if int_dtype != scalar_type: | |
randint = g.op("Cast", randint, to_i=scalar_type.onnx_type()) | |
return randint | |
def randn(g: jit_utils.GraphContext, shapes, dtype, *options): | |
dtype = symbolic_helper._get_const(dtype, "i", "dtype") | |
if dtype is None: | |
scalar_type = _type_utils.JitScalarType.FLOAT | |
else: | |
scalar_type = _type_utils.JitScalarType(dtype) | |
shape = symbolic_helper._maybe_get_const(shapes, "is") | |
if symbolic_helper._is_value(shape): | |
shape_const = g.op( | |
"ConstantOfShape", | |
shapes, | |
value_t=torch.tensor([0], dtype=torch.float), | |
) | |
return g.op( | |
"RandomNormalLike", | |
shape_const, | |
dtype_i=scalar_type.onnx_type(), | |
) | |
return g.op( | |
"RandomNormal", | |
shape_i=shape, | |
dtype_i=scalar_type.onnx_type(), | |
) | |
def rand(g: jit_utils.GraphContext, shapes, dtype, *options): | |
dtype = symbolic_helper._get_const(dtype, "i", "dtype") | |
if dtype is None: | |
scalar_type = _type_utils.JitScalarType.FLOAT | |
else: | |
scalar_type = _type_utils.JitScalarType(dtype) | |
shape = symbolic_helper._maybe_get_const(shapes, "is") | |
if symbolic_helper._is_value(shape): | |
shape_const = g.op( | |
"ConstantOfShape", | |
shapes, | |
value_t=torch.tensor([0], dtype=torch.float), | |
) | |
return g.op( | |
"RandomUniformLike", | |
shape_const, | |
dtype_i=scalar_type.onnx_type(), | |
) | |
return g.op( | |
"RandomUniform", | |
shape_i=shape, | |
dtype_i=scalar_type.onnx_type(), | |
) | |
def randn_like( | |
g: jit_utils.GraphContext, | |
self, | |
dtype, | |
layout=None, | |
device=None, | |
pin_memory=False, | |
memory_format=None, | |
): | |
dtype = symbolic_helper._get_const(dtype, "i", "dtype") | |
if dtype is None: | |
scalar_type = _type_utils.JitScalarType.from_value( | |
self, _type_utils.JitScalarType.FLOAT | |
) | |
else: | |
scalar_type = _type_utils.JitScalarType(dtype) | |
return g.op("RandomNormalLike", self, dtype_i=scalar_type.onnx_type()) | |
def rand_like( | |
g: jit_utils.GraphContext, | |
self, | |
dtype, | |
layout=None, | |
device=None, | |
pin_memory=False, | |
memory_format=None, | |
): | |
dtype = symbolic_helper._get_const(dtype, "i", "dtype") | |
if dtype is None: | |
dtype = _type_utils.JitScalarType.from_value( | |
self, _type_utils.JitScalarType.FLOAT | |
) | |
return g.op( | |
"RandomUniformLike", self, dtype_i=_type_utils.JitScalarType(dtype).onnx_type() | |
) | |
def rrelu(g: jit_utils.GraphContext, input, lower, upper, training, generator): | |
if not training: | |
slope = (upper + lower) / 2.0 | |
return g.op("LeakyRelu", input, alpha_f=slope) | |
p = g.op("RandomUniformLike", input, high_f=upper, low_f=lower) | |
return g.op("PRelu", input, p) | |
def bernoulli(g: jit_utils.GraphContext, input, p=None, generator=None, out=None): | |
if out is not None and not symbolic_helper._is_none(out): | |
symbolic_helper._unimplemented( | |
"Bernoulli", "out parameter is not supported for bernoulli", input | |
) | |
if generator is not None and not symbolic_helper._is_none(generator): | |
symbolic_helper._unimplemented( | |
"Bernoulli", "generator is not supported for bernoulli", input | |
) | |
dtype = _type_utils.JitScalarType.from_value( | |
input, _type_utils.JitScalarType.UNDEFINED | |
) | |
if dtype == _type_utils.JitScalarType.UNDEFINED: | |
return symbolic_helper._unimplemented( | |
"Bernoulli", "input dtype not accessible", input | |
) | |
rands = g.op( | |
"RandomUniformLike", | |
input, | |
high_f=1.0, | |
low_f=0.0, | |
dtype_i=dtype.onnx_type(), | |
) | |
prob = p if p is not None and not symbolic_helper._is_none(p) else input | |
output = g.op("Less", rands, prob) | |
return g.op("Cast", output, to_i=dtype.onnx_type()) | |
def log_sigmoid(g: jit_utils.GraphContext, input): | |
p = g.op("Sigmoid", input) | |
return g.op("Log", p) | |
def erf(g: jit_utils.GraphContext, input): | |
return g.op("Erf", input) | |
def flatten(g: jit_utils.GraphContext, input, start_dim, end_dim): | |
dim = symbolic_helper._get_tensor_rank(input) | |
if dim is None: | |
return symbolic_helper._unimplemented( | |
"dim", | |
"ONNX and PyTorch use different strategies to split the input. " | |
"Input rank must be known at export time.", | |
input, | |
) | |
if dim == 0: | |
return symbolic_helper._reshape_helper(g, input, [1]) | |
if dim == 1: | |
return g.op("Identity", input) | |
# TODO: remove this as onnx opset 11 spec allows negative axes | |
if end_dim < 0: | |
end_dim = dim + end_dim | |
# use ONNX's Flatten operator for cases where the output shape is 2D | |
if start_dim == 1 and end_dim == dim - 1: | |
return g.op("Flatten", input, axis_i=start_dim) | |
if start_dim == 0 and end_dim == dim - 2: | |
return g.op("Flatten", input, axis_i=end_dim + 1) | |
return symbolic_helper._flatten_helper(g, input, start_dim, end_dim, dim) | |
def nonzero(g: jit_utils.GraphContext, input): | |
"""Emitted from `torch.nonzero(x, as_tuple=False)`""" | |
return t(g, g.op("NonZero", input)) | |
# Emitted from `torch.nonzero(x, as_tuple=True)` | |
def nonzero_numpy(g: jit_utils.GraphContext, input, _outputs=None): | |
return unbind(g, nonzero(g, input), 1, _outputs=_outputs) | |
def isnan(g: jit_utils.GraphContext, input): | |
output = g.op("IsNaN", input) | |
return output | |
def _any(g: jit_utils.GraphContext, *args): | |
# aten::any(Tensor self) | |
if len(args) == 1: | |
input = args[0] | |
dim, keepdim = None, 0 | |
# aten::any(Tensor self, int[]? dim, bool keepdim) | |
else: | |
input, dim, keepdim = args | |
# Can be int list or single int | |
dim = symbolic_helper._parse_arg(dim, "t") | |
dim = [int(d) for d in dim.view(-1)] | |
keepdim = symbolic_helper._parse_arg(keepdim, "i") | |
input = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT64) | |
input_sum = symbolic_helper._reducesum_helper( | |
g, input, axes_i=dim, keepdims_i=keepdim | |
) | |
return gt(g, input_sum, g.op("Constant", value_t=torch.tensor(0, dtype=torch.long))) | |
def _all(g: jit_utils.GraphContext, *args): | |
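    # all(x) is computed via De Morgan's law as not(any(not(x))).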
input = g.op("Not", args[0]) | |
# aten::all(Tensor self) | |
if len(args) == 1: | |
return g.op("Not", _any(g, input)) | |
# aten::all(Tensor self, int[]? dim, bool keepdim) | |
else: | |
return g.op("Not", _any(g, input, args[1], args[2])) | |
def narrow(g: jit_utils.GraphContext, input, dim, start, length): | |
return symbolic_helper._slice_helper( | |
g, input, axes=[dim], starts=[start], ends=[start + length] | |
) | |
def argmax( | |
g: jit_utils.GraphContext, | |
input: torch._C.Value, | |
dim: torch._C.Value, | |
keepdim: bool, | |
): | |
return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMax") | |
def argmin( | |
g: jit_utils.GraphContext, | |
input: torch._C.Value, | |
dim: torch._C.Value, | |
keepdim: bool, | |
): | |
return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMin") | |
def scatter(g: jit_utils.GraphContext, self, dim, index, src): | |
src_type = _type_utils.JitScalarType.from_value( | |
src, _type_utils.JitScalarType.UNDEFINED | |
) | |
src = symbolic_helper._maybe_get_scalar(src) | |
if symbolic_helper._is_value(src): | |
return g.op("Scatter", self, index, src, axis_i=dim) | |
else: | |
# Check if scalar "src" has same type as self (PyTorch allows different | |
# type for scalar src (but not when src is tensor)). If not, insert Cast node. | |
self_scalar_type = _type_utils.JitScalarType.from_value(self) | |
if self_scalar_type != src_type: | |
src = g.op("Cast", src, to_i=self_scalar_type.onnx_type()) | |
return g.op("Scatter", self, index, expand_as(g, src, index), axis_i=dim) | |
def scatter_add(g: jit_utils.GraphContext, self, dim, index, src): | |
scalar_type = symbolic_helper._try_get_scalar_type(self) | |
if scalar_type is None: | |
return symbolic_helper._unimplemented( | |
"scatter_add", "input dtype not accessible", self | |
) | |
sizes = symbolic_helper._get_tensor_sizes(self, allow_nonstatic=False) | |
if sizes: | |
to_add = g.op("Constant", value_t=torch.zeros(sizes, dtype=scalar_type.dtype())) | |
else: | |
to_add = zeros_like(g, self, scalar_type) | |
to_add = symbolic_helper._scatter_helper(g, to_add, dim, index, src) | |
return add(g, self, to_add) | |
def log2(g: jit_utils.GraphContext, self): | |
_ln2 = 0.693147180559945309 | |
return g.op("Div", log(g, self), g.op("Constant", value_t=torch.tensor(_ln2))) | |
def is_floating_point(g: jit_utils.GraphContext, self): | |
if symbolic_helper._is_fp(self): | |
return g.op("Constant", value_t=torch.BoolTensor([1])) | |
return g.op("Constant", value_t=torch.BoolTensor([0])) | |
def __is_(g: jit_utils.GraphContext, self, other): | |
if symbolic_helper._is_none(other): | |
if symbolic_helper._is_none(self): | |
return g.op("Constant", value_t=torch.BoolTensor([1])) | |
return g.op("Constant", value_t=torch.BoolTensor([0])) | |
return eq(g, self, other) | |
def __isnot_(g: jit_utils.GraphContext, self, other): | |
return __is_(g, self, other) | |
def one_hot(g: jit_utils.GraphContext, self, num_classes): | |
values = g.op("Constant", value_t=torch.LongTensor([0, 1])) | |
# onnxruntime supports limited type combinations for OneHot. | |
if _type_utils.JitScalarType.from_value( | |
num_classes, _type_utils.JitScalarType.UNDEFINED | |
) in { | |
_type_utils.JitScalarType.UINT8, | |
_type_utils.JitScalarType.INT8, | |
_type_utils.JitScalarType.INT, | |
_type_utils.JitScalarType.INT16, | |
}: | |
num_classes = g.op("Cast", num_classes, to_i=_C_onnx.TensorProtoDataType.INT64) | |
return g.op("OneHot", self, num_classes, values, axis_i=-1) | |
def gather(g: jit_utils.GraphContext, self, dim, index, sparse_grad=False): | |
if symbolic_helper._maybe_get_const(sparse_grad, "i"): | |
return symbolic_helper._unimplemented("gather", "sparse_grad == True", self) | |
    # NOTE: This workaround is needed since GatherElements is only available
    # from opset 11 onwards, and Gather in ONNX is not the same as torch.gather.
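    # Rough sketch of the emulation: OneHot(index, depth=self.shape[dim]) adds a depth axis at `dim`;
    # multiplying it with `self` unsqueezed at dim + 1 and reduce-summing over `dim` selects
    # self[..., index[...], ...] along `dim`, i.e. the torch.gather result.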
scalar_type = _type_utils.JitScalarType.from_value(self) | |
values = g.op("Constant", value_t=torch.LongTensor([0, 1])) | |
depth = size(g, self, g.op("Constant", value_t=torch.LongTensor([dim]))) | |
index = g.op( | |
"Cast", | |
g.op("OneHot", index, depth, values, axis_i=dim), | |
to_i=scalar_type.onnx_type(), | |
) | |
mul = g.op("Mul", symbolic_helper._unsqueeze_helper(g, self, [dim + 1]), index) | |
return symbolic_helper._reducesum_helper(g, mul, axes_i=[dim], keepdims_i=0) | |
def _var_mean(g: jit_utils.GraphContext, input, dim, correction, keepdim): | |
if dim is None: | |
mean = g.op("ReduceMean", input, keepdims_i=0) | |
t_mean = mean | |
num_elements = numel(g, input) | |
else: | |
mean = g.op("ReduceMean", input, axes_i=dim, keepdims_i=keepdim) | |
t_mean = g.op("ReduceMean", input, axes_i=dim, keepdims_i=1) | |
        reduced_dims = g.op("Shape", input)
        # dim could contain one or multiple dimensions
        reduced_dims = g.op(
            "Gather",
            reduced_dims,
            g.op("Constant", value_t=torch.tensor(dim)),
            axis_i=0,
        )
        num_elements = g.op("ReduceProd", reduced_dims, keepdims_i=0)
sub_v = g.op("Sub", input, t_mean) | |
sqr_sub = g.op("Mul", sub_v, sub_v) | |
keepdim_mean = 0 if dim is None else keepdim | |
var = g.op("ReduceMean", sqr_sub, axes_i=dim, keepdims_i=keepdim_mean) | |
    # Correct the bias in calculating the variance by dividing by (N - correction) instead of N
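    # i.e. var_corrected = var_biased * N / (N - correction), so the sum of squared deviations
    # is divided by (N - correction) rather than N.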
if correction is None: | |
correction = 1 | |
if correction != 0: | |
num_elements = g.op( | |
"Cast", num_elements, to_i=_C_onnx.TensorProtoDataType.FLOAT | |
) | |
one = g.op("Constant", value_t=torch.tensor(correction, dtype=torch.float)) | |
mul = g.op("Mul", var, num_elements) | |
var = g.op("Div", mul, g.op("Sub", num_elements, one)) | |
return var, mean | |
def std(g: jit_utils.GraphContext, input, *args): | |
var, _ = var_mean(g, input, *args) | |
return g.op("Sqrt", var) | |
def var(g: jit_utils.GraphContext, input, *args): | |
var, _ = var_mean(g, input, *args) | |
return var | |
def var_mean(g: jit_utils.GraphContext, input, *args): | |
    # var_mean (and all variance-related functions) has multiple signatures, so we need to
    # manually figure out the correct arguments:
# aten::var_mean(Tensor self, bool unbiased) | |
# aten::var_mean(Tensor self, int[1] dim, bool unbiased, bool keepdim=False) | |
# aten::var_mean(Tensor self, int[1]? dim=None, *, int? correction=None, bool keepdim=False) | |
if len(args) == 1: | |
return _var_mean(g, input, None, args[0], None) | |
else: | |
return _var_mean(g, input, *args) | |
def std_mean(g: jit_utils.GraphContext, input, *args): | |
var, mean = var_mean(g, input, *args) | |
return g.op("Sqrt", var), mean | |
def logsumexp(g: jit_utils.GraphContext, input, dim, keepdim): | |
return g.op("ReduceLogSumExp", input, axes_i=dim, keepdims_i=keepdim) | |
def arange(g: jit_utils.GraphContext, *args): | |
if symbolic_helper.is_caffe2_aten_fallback(): | |
return g.at("arange", *args) | |
def _get_arange_dtype(dtype): | |
dtype = symbolic_helper._maybe_get_const(dtype, "i") | |
return dtype | |
def _float_step_convert(range_tensor): | |
if symbolic_helper._is_fp(range_tensor): | |
range_tensor = g.op( | |
"Cast", | |
g.op("Ceil", range_tensor), | |
to_i=_type_utils.JitScalarType.INT64.onnx_type(), | |
) | |
return range_tensor | |
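    # Opset 9 has no Range op, so arange(N) is emulated below as NonZero(Ones([N])),
    # which yields the indices 0 .. N - 1, optionally scaled by `step` and shifted by `start`.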
if len(args) == 2 or len(args) == 5: | |
if len(args) == 2: | |
# aten::arange(Scalar end, Tensor out) | |
dtype = None | |
else: | |
# aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) | |
dtype = _get_arange_dtype(args[1]) | |
dtype, end, start, step = symbolic_helper._arange_cast_helper( | |
g, end=args[0], dtype=dtype | |
) | |
end = symbolic_helper._unsqueeze_helper(g, end, [0]) | |
range_tensor = _float_step_convert(end) | |
arange_tensor = symbolic_helper._squeeze_helper( | |
g, nonzero(g, ones(g, range_tensor, dtype, None, None)), [1] | |
) | |
return g.op( | |
"Cast", arange_tensor, to_i=_type_utils.JitScalarType(dtype).onnx_type() | |
) | |
elif len(args) == 4 or len(args) == 7: | |
if len(args) == 4: | |
# aten::arange(Scalar start, Scalar end, Scalar step, Tensor out) | |
dtype = None | |
else: | |
# aten::arange(Scalar start, Scalar end, Scalar step, ScalarType dtype, Layout, Device, bool pin_memory) | |
dtype = _get_arange_dtype(args[3]) | |
dtype, end, start, step = symbolic_helper._arange_cast_helper( | |
g, start=args[0], end=args[1], step=args[2], dtype=dtype | |
) | |
step = symbolic_helper._unsqueeze_helper(g, step, [0]) | |
end = symbolic_helper._unsqueeze_helper(g, end, [0]) | |
start = symbolic_helper._unsqueeze_helper(g, start, [0]) | |
range_tensor = _float_step_convert(g.op("Div", g.op("Sub", end, start), step)) | |
arange_tensor = symbolic_helper._squeeze_helper( | |
g, nonzero(g, ones(g, range_tensor, None, None, None)), [1] | |
) | |
arange_tensor = g.op("Add", g.op("Mul", arange_tensor, step), start) | |
return g.op( | |
"Cast", arange_tensor, to_i=_type_utils.JitScalarType(dtype).onnx_type() | |
) | |
elif len(args) == 6: | |
# aten::arange(Scalar start, Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) | |
dtype = _get_arange_dtype(args[2]) | |
dtype, end, start, step = symbolic_helper._arange_cast_helper( | |
g, start=args[0], end=args[1], dtype=dtype | |
) | |
end = symbolic_helper._unsqueeze_helper(g, end, [0]) | |
start = symbolic_helper._unsqueeze_helper(g, start, [0]) | |
range_tensor = _float_step_convert(g.op("Sub", end, start)) | |
arange_tensor = g.op( | |
"Add", | |
symbolic_helper._squeeze_helper( | |
g, nonzero(g, ones(g, range_tensor, dtype, *(args[3:]))), [1] | |
), | |
start, | |
) | |
return g.op( | |
"Cast", arange_tensor, to_i=_type_utils.JitScalarType(dtype).onnx_type() | |
) | |
return symbolic_helper._unimplemented("aten::arange", f"with {len(args)} arguments") | |
def linspace( | |
g: jit_utils.GraphContext, start, end, steps, dtype, layout, device, pin_memory | |
): | |
range_tensor = symbolic_helper._arange_helper(g, steps, None) | |
step = div( | |
g, | |
sub(g, end, start), | |
sub(g, steps, g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))), | |
) | |
return add(g, mul(g, range_tensor, step), start) | |
def lift(g: jit_utils.GraphContext, self): | |
# at::lift() is a no-op from the perspective of tracing for onnx | |
return self | |
def masked_fill(g: jit_utils.GraphContext, self, mask, value): | |
mask = g.op("Cast", mask, to_i=_C_onnx.TensorProtoDataType.BOOL) | |
value = symbolic_helper._maybe_get_scalar(value) | |
return g.op("Where", mask, symbolic_helper._if_scalar_type_as(value, self), self) | |
def masked_fill_(g: jit_utils.GraphContext, self, mask, value): | |
return masked_fill(g, self, mask, value) | |
def index(g: jit_utils.GraphContext, self, index): | |
if symbolic_helper.is_caffe2_aten_fallback(): | |
return g.at("index", self, index, overload_name="Tensor") | |
if symbolic_helper._is_packed_list(index): | |
indices = symbolic_helper._unpack_list(index) | |
else: | |
indices = [index] | |
def try_mask_to_index(index): | |
if not symbolic_helper._is_none(index) and ( | |
_type_utils.JitScalarType.from_value( | |
index, _type_utils.JitScalarType.UNDEFINED | |
) | |
== _type_utils.JitScalarType.UINT8 | |
or symbolic_helper._is_bool(index) | |
): | |
if g.opset < 9: | |
raise errors.SymbolicValueError( | |
"Exporting masked indices are only supported after ONNX opset 9.", | |
self, | |
) | |
warnings.warn( | |
"Exporting aten::index operator with indices of type Byte. " | |
"Only 1-D indices are supported. In any other case, " | |
"this will produce an incorrect ONNX graph." | |
) | |
index = symbolic_helper._squeeze_helper(g, nonzero(g, index), [1]) | |
return index | |
indices = [try_mask_to_index(idx) for idx in indices] | |
if len(indices) == 1: | |
return symbolic_helper._select_helper( | |
g, self, 0, indices[0], apply_reshape=False | |
) | |
else: | |
# Multiple tensors as indices. Each tensor could either be | |
# 1. prim::Constant() | |
# representing ":" in python indexing. E.g. tensor[:, :] | |
# 2. prim::Constant[value=...] or tensor output | |
# representing advanced indexing. E.g. tensor[[0, 1], [2, 0]]. | |
# For more info on advanced indexing, | |
# check https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing | |
# Consider a general case of | |
# t: [x_1, y_1, y_2, ..., x_m, ..., y_n] | |
# where t is a tensor of rank m+n, {x_i} are axes where tensor index is provided, and {y_i} are axes for ":". | |
# Same results can be achieved through transposing t into | |
# t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n] | |
        # and using gatherND. However, ONNX does not have gatherND; to use a 1-d Gather we'll need to flatten t
# and process the tensor indices. | |
# t: [x_1 * x_2 * ... * x_m, y_1 * y_2 * ... * y_n] | |
# tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)) | |
# After gather, reshape and transpose back. | |
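        # e.g. for t of shape (4, 5) and t[[0, 2], [1, 3]]: the flattened indices are
        # [0 * 5 + 1, 2 * 5 + 3] = [1, 13], and gathering from the flattened t
        # returns [t[0, 1], t[2, 3]].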
adv_idx_indices = [ | |
i for i, idx in enumerate(indices) if not symbolic_helper._is_none(idx) | |
] | |
if len(adv_idx_indices) == 0: | |
return self | |
elif len(adv_idx_indices) == 1: | |
return index_select( | |
g, self, adv_idx_indices[0], indices[adv_idx_indices[0]] | |
) | |
else: | |
rank = symbolic_helper._get_tensor_rank(self) | |
if rank is None: | |
return symbolic_helper._unimplemented( | |
"aten::index", | |
"operator of advanced indexing on tensor of unknown rank. " | |
"Try turning on shape inference during export: " | |
"torch.onnx._export(..., onnx_shape_inference=True).", | |
self, | |
) | |
# TODO: If indexing is supported natively in ONNX in future opsets, | |
# update the warning to recommend exporting with higher opset version. | |
warnings.warn( | |
"Exporting aten::index operator of advanced indexing in opset " | |
f"{GLOBALS.export_onnx_opset_version}" | |
" is achieved by combination of multiple ONNX operators, " | |
"including Reshape, Transpose, Concat, and Gather. " | |
"If indices include negative values, the exported graph will produce incorrect results." | |
) | |
adv_idx_count = len(adv_idx_indices) | |
shape_tensor = _shape_as_tensor(g, self) | |
dim_tensor_list = [ | |
g.op( | |
"Gather", | |
shape_tensor, | |
g.op("Constant", value_t=torch.LongTensor([dim])), | |
axis_i=0, | |
) | |
for dim in range(rank) | |
] | |
self = g.op( | |
"Transpose", | |
self, | |
perm_i=adv_idx_indices | |
+ [i for i in range(rank) if i not in adv_idx_indices], | |
) | |
self = g.op("Flatten", self, axis_i=adv_idx_count) | |
# Note that tensor indices will be broadcasted while accumulating. Thus we get the final subarray shape as well. | |
cum_adv_index = indices[adv_idx_indices[-1]] | |
multiplier = dim_tensor_list[adv_idx_indices[-1]] | |
for i in range(adv_idx_count - 2, -1, -1): | |
adv_index = g.op("Mul", indices[adv_idx_indices[i]], multiplier) | |
cum_adv_index = g.op("Add", cum_adv_index, adv_index) | |
multiplier = g.op( | |
"Mul", multiplier, dim_tensor_list[adv_idx_indices[i]] | |
) | |
# perform gather | |
self = index_select(g, self, 0, cum_adv_index) | |
cum_adv_index_shape_tensor = _shape_as_tensor(g, cum_adv_index) | |
# check if all advanced indices are consecutive. | |
# Refer to https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#combining-advanced-and-basic-indexing | |
# to understand how the subarray position is decided. | |
if adv_idx_indices == list( | |
range(adv_idx_indices[0], adv_idx_indices[-1] + 1) | |
): | |
# unfold regular index axes | |
folded_adv_idx_shape_list = [ | |
g.op("Constant", value_t=torch.LongTensor([-1])) | |
] + [ | |
dim_tensor_list[i] for i in range(rank) if i not in adv_idx_indices | |
] | |
folded_adv_idx_shape = g.op( | |
"Concat", *folded_adv_idx_shape_list, axis_i=0 | |
) | |
self = symbolic_helper._reshape_helper(g, self, folded_adv_idx_shape) | |
# Transpose folded advanced indexed axis to its original location. | |
adv_idx_permute = ( | |
list(range(1, adv_idx_indices[0] + 1)) | |
+ [0] | |
+ list(range(adv_idx_indices[0] + 1, rank - adv_idx_count + 1)) | |
) | |
self = g.op("Transpose", self, perm_i=adv_idx_permute) | |
# unfold advanced index axes | |
final_shape_list = ( | |
[dim_tensor_list[i] for i in range(adv_idx_indices[0])] | |
+ [cum_adv_index_shape_tensor] | |
+ [ | |
dim_tensor_list[i] | |
for i in range(adv_idx_indices[0], rank) | |
if i not in adv_idx_indices | |
] | |
) | |
final_shape = g.op("Concat", *final_shape_list, axis_i=0) | |
else: | |
final_shape = g.op( | |
"Concat", | |
cum_adv_index_shape_tensor, | |
*[ | |
dim_tensor_list[i] | |
for i in range(rank) | |
if i not in adv_idx_indices | |
], | |
axis_i=0, | |
) | |
return symbolic_helper._reshape_helper(g, self, final_shape) | |
def linalg_norm( | |
g: jit_utils.GraphContext, | |
self: torch._C.Value, | |
ord: torch._C.Value, | |
dim: Optional[Sequence[int]], | |
keepdim: bool, | |
dtype: torch._C.Value, | |
): | |
# Conditions based on https://pytorch.org/docs/stable/generated/torch.linalg.norm.html | |
ord_value = None | |
if dim is None: | |
if symbolic_helper._is_none(ord): | |
self = symbolic_helper._reshape_helper(g, self, [-1]) | |
ord = g.op("Constant", value_t=torch.LongTensor([2])) | |
self_dim = symbolic_helper._get_tensor_rank(self) | |
if self_dim is None: | |
return symbolic_helper._unimplemented( | |
"dim", "Input rank must be known at export time.", self | |
) | |
if self_dim == 1: | |
ord_value = symbolic_helper._parse_arg(ord, "f") | |
else: | |
dim = [0, 1] | |
else: | |
if len(dim) == 1: | |
if symbolic_helper._is_none(ord): | |
ord = g.op("Constant", value_t=torch.LongTensor([2])) | |
ord_value = symbolic_helper._parse_arg(ord, "f") | |
if ord_value: | |
return linalg_vector_norm(g, self, ord_value, dim, keepdim, dtype) | |
return linalg_matrix_norm(g, self, ord, dim, keepdim, dtype) | |
def linalg_vector_norm( | |
g: jit_utils.GraphContext, | |
self: torch._C.Value, | |
ord: float, | |
dim: Optional[Sequence[int]], | |
keepdim: bool, | |
dtype: torch._C.Value, | |
): | |
# Conditions based on https://pytorch.org/docs/stable/generated/torch.linalg.vector_norm.html | |
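    # Summary of the lowering below: ord=inf -> ReduceMax(|x|), ord=-inf -> ReduceMin(|x|),
    # ord=1 -> ReduceL1, ord=2 -> ReduceL2, otherwise (sum |x| ** ord) ** (1 / ord).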
if symbolic_helper._is_none(dim): | |
self = symbolic_helper._reshape_helper(g, self, [-1]) | |
keepdim = False | |
if ord == math.inf: | |
result = g.op("ReduceMax", g.op("Abs", self), axes_i=dim, keepdims_i=keepdim) | |
elif ord == -math.inf: | |
result = g.op("ReduceMin", g.op("Abs", self), axes_i=dim, keepdims_i=keepdim) | |
elif ord == 0: | |
return symbolic_helper._onnx_opset_unsupported_detailed( | |
"linalg_vector_norm", 9, 11, "ord=0 not supported", self | |
) | |
elif ord == 1: | |
result = _reduce_op_symbolic("ReduceL1")(g, self, dim=dim, keepdim=keepdim) | |
elif ord == 2: | |
result = _reduce_op_symbolic("ReduceL2")(g, self, dim=dim, keepdim=keepdim) | |
else: | |
ord_op = g.op("Constant", value_t=torch.tensor(ord, dtype=torch.float32)) | |
result = symbolic_helper._reducesum_helper( | |
g, g.op("Pow", g.op("Abs", self), ord_op), axes_i=dim, keepdims_i=keepdim | |
) | |
result = g.op( | |
"Pow", | |
result, | |
g.op( | |
"Div", | |
g.op("Constant", value_t=torch.tensor(1, dtype=torch.float32)), | |
ord_op, | |
), | |
) | |
if not symbolic_helper._is_none(dtype): | |
dtype = symbolic_helper._get_const(dtype, "i", "dtype") | |
result = g.op("Cast", result, to_i=_type_utils.JitScalarType(dtype).onnx_type()) # type: ignore[arg-type] | |
return result | |
def linalg_matrix_norm( | |
g: jit_utils.GraphContext, | |
self: torch._C.Value, | |
ord: torch._C.Value, | |
dim: List[int], | |
keepdim: bool, | |
dtype: torch._C.Value, | |
): | |
# Conditions based on https://pytorch.org/docs/stable/generated/torch.linalg.matrix_norm.html | |
ord_value = symbolic_helper._parse_arg(ord, "s") | |
if ord_value == "fro": | |
return frobenius_norm(g, self, dim, keepdim) | |
elif ord_value == "nuc": | |
return symbolic_helper._unimplemented("linalg.matrix_norm", "ord==nuc", self) | |
else: | |
ord_value = symbolic_helper._parse_arg(ord, "f") | |
if ord_value is None: | |
return frobenius_norm(g, self, dim, keepdim) | |
if ord_value == 2 or ord_value == -2: | |
# ord = 2/-2 unimplemented due to lack of operators | |
# used to calculate singular values | |
return symbolic_helper._unimplemented("linalg.matrix_norm", "ord==2", self) | |
# Wrap the dim vector to handle negative dim values | |
self_dim = symbolic_helper._get_tensor_rank(self) | |
if self_dim is None: | |
return symbolic_helper._unimplemented( | |
"linalg.matrix_norm", "Input rank must be known at export time.", self | |
) | |
# Common implementation for cases with | |
# ord = 1/-1 and ord = inf/-inf | |
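    # ord = +/-1 is the max/min absolute column sum: reduce-sum |x| over dim[0] (rows), then
    # max/min over dim[1] (columns). ord = +/-inf is the max/min absolute row sum, hence the
    # two dims are swapped below; dim[1] is shifted down when keepdim is False because the
    # first reduction removes an axis.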
if dim[0] < 0: | |
dim[0] += self_dim | |
if dim[1] < 0: | |
dim[1] += self_dim | |
if ord_value == math.inf or ord_value == -math.inf: | |
dim[0], dim[1] = dim[1], dim[0] | |
if dim[1] > dim[0] and not keepdim: | |
dim[1] -= 1 | |
sum = symbolic_helper._reducesum_helper( | |
g, g.op("Abs", self), axes_i=[dim[0]], keepdims_i=keepdim | |
) | |
if ord_value > 0: | |
result, indices = max( | |
g, | |
sum, | |
dim_or_y=g.op("Constant", value_t=torch.LongTensor([dim[1]])), | |
keepdim=keepdim, | |
) | |
else: | |
result, indices = min( | |
g, | |
sum, | |
dim_or_y=g.op("Constant", value_t=torch.LongTensor([dim[1]])), | |
keepdim=keepdim, | |
) | |
return result | |
def linalg_cross(g: jit_utils.GraphContext, input, other, dim=-1): | |
return cross(g, input, other, dim) | |
def frobenius_norm(g: jit_utils.GraphContext, self, dim=None, keepdim=False): | |
sqr = g.op("Mul", self, self) | |
sumsqr = symbolic_helper._reducesum_helper(g, sqr, axes_i=dim, keepdims_i=keepdim) | |
return g.op("Sqrt", sumsqr) | |
def multinomial( | |
g: jit_utils.GraphContext, input, num_samples, replacement=False, generator=None | |
): | |
if generator is not None and not symbolic_helper._is_none(generator): | |
symbolic_helper._unimplemented( | |
"Multinomial", "generator is not supported for multinomial", input | |
) | |
if not replacement and num_samples > 1: | |
symbolic_helper._unimplemented( | |
"Multinomial", | |
"replacement=False when num_samples > 1 is not supported for multinomial", | |
input, | |
) | |
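    # torch.multinomial takes (unnormalized) probability weights, while the ONNX Multinomial op
    # expects unnormalized log-probabilities, hence the Log below.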
log_input = log(g, input) | |
return g.op( | |
"Multinomial", | |
log_input, | |
dtype_i=_C_onnx.TensorProtoDataType.INT64, | |
sample_size_i=num_samples, | |
) | |
def baddbmm(g: jit_utils.GraphContext, self, batch1, batch2, beta, alpha): | |
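    # baddbmm(self, batch1, batch2, beta, alpha) = beta * self + alpha * (batch1 @ batch2),
    # computed batch-wise.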
scalar_type = _type_utils.JitScalarType.from_value(self) | |
batch_mul = matmul(g, batch1, batch2) | |
mul_a = mul( | |
g, | |
batch_mul, | |
g.op("Cast", alpha, to_i=scalar_type.onnx_type()), | |
) | |
mul_b = mul( | |
g, | |
self, | |
g.op("Cast", beta, to_i=scalar_type.onnx_type()), | |
) | |
return add(g, mul_a, mul_b) | |
def meshgrid(g: jit_utils.GraphContext, tensor_list, indexing: Optional[str] = None): | |
if indexing is None: | |
indexing = "ij" | |
elif indexing not in {"ij", "xy"}: | |
raise errors.SymbolicValueError( | |
f"Unsupported indexing: {indexing}", tensor_list | |
) | |
unpacked_tensor_list = symbolic_helper._unpack_list(tensor_list) | |
if indexing == "xy": | |
unpacked_tensor_list[:2] = unpacked_tensor_list[1::-1] | |
tensors = [ | |
symbolic_helper._reshape_helper( | |
g, t, g.op("Constant", value_t=torch.LongTensor([-1])) | |
) | |
for t in unpacked_tensor_list | |
] | |
tensors_shape = [g.op("Shape", t) for t in tensors] | |
out_shape = g.op("Concat", *tensors_shape, axis_i=0) | |
out = [] | |
for i, t in enumerate(tensors): | |
shape_i = [g.op("Constant", value_t=torch.ones(1, dtype=torch.int64))] * len( | |
tensors | |
) | |
shape_i[i] = tensors_shape[i] | |
t_reshaped = _reshape_from_tensor(g, t, g.op("Concat", *shape_i, axis_i=0)) | |
out.append(g.op("Expand", t_reshaped, out_shape)) | |
if indexing == "xy": | |
out[0], out[1] = out[1], out[0] | |
return g.op("prim::ListConstruct", *out) | |
def remainder(g: jit_utils.GraphContext, input, other): | |
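    # remainder(a, b) = a - floor(a / b) * b, so the result takes the sign of the divisor,
    # matching torch.remainder.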
div = _floor_divide(g, input, other) | |
quo = g.op("Mul", div, other) | |
return g.op("Sub", input, quo) | |
def gelu(g: jit_utils.GraphContext, self: torch._C.Value, approximate: str = "none"): | |
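    # approximate == "tanh": 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x ** 3)))
    # otherwise (exact):     0.5 * x * (1 + erf(x / sqrt(2)))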
if approximate == "tanh": | |
kBeta = math.sqrt(2 / math.pi) | |
kKappa = 0.044715 | |
beta = torch.tensor(kBeta, dtype=torch.double) | |
kappa = torch.tensor(kKappa, dtype=torch.double) | |
one = torch.tensor(1.0, dtype=torch.double) | |
half = torch.tensor(0.5, dtype=torch.double) | |
self_cube = mul(g, self, mul(g, self, self)) | |
inner = mul(g, beta, add(g, self, mul(g, kappa, self_cube))) | |
return mul(g, half, mul(g, self, add(g, one, g.op("Tanh", inner)))) | |
else: | |
_sqrt2 = 1.4142135623730951 | |
erf = g.op("Erf", g.op("Div", self, torch.tensor(_sqrt2, dtype=torch.double))) | |
erf_plusone = add( | |
g, erf, g.op("Constant", value_t=torch.tensor(1, dtype=torch.double)) | |
) | |
return mul( | |
g, | |
mul(g, self, erf_plusone), | |
g.op("Constant", value_t=torch.tensor(0.5, dtype=torch.double)), | |
) | |
def group_norm( | |
g: jit_utils.GraphContext, input, num_groups, weight, bias, eps, cudnn_enabled | |
): | |
if symbolic_helper.is_caffe2_aten_fallback(): | |
return g.at( | |
"group_norm", | |
input, | |
weight, | |
bias, | |
num_groups_i=num_groups, | |
eps_f=eps, | |
cudnn_enabled_i=cudnn_enabled, | |
) | |
channel_size = symbolic_helper._get_tensor_dim_size(input, 1) | |
if channel_size is not None: | |
assert channel_size % num_groups == 0 | |
input_rank = symbolic_helper._get_tensor_rank(input) | |
if input_rank is None: | |
return symbolic_helper._unimplemented("group_norm", "unknown input rank", input) | |
# 0 in the shape list keeps dimension value unchanged. | |
shape = [0, num_groups, -1] | |
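    # e.g. an input of shape (N, C, H, W) becomes (N, num_groups, (C // num_groups) * H * W).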
input_reshaped = symbolic_helper._reshape_helper( | |
g, input, g.op("Constant", value_t=torch.LongTensor(shape)) | |
) | |
    # C is always divisible by num_groups.
    # Due to the shape difference, we need to apply weight and bias after
    # the instance norm computation and reshape.
weight_ = g.op( | |
"Constant", | |
value_t=torch.tensor( | |
[1.0] * num_groups, | |
dtype=_type_utils.JitScalarType.from_value(input).dtype(), | |
), | |
) | |
bias_ = g.op( | |
"Constant", | |
value_t=torch.tensor( | |
[0.0] * num_groups, | |
dtype=_type_utils.JitScalarType.from_value(input).dtype(), | |
), | |
) | |
norm_reshaped = g.op( | |
"InstanceNormalization", input_reshaped, weight_, bias_, epsilon_f=eps | |
) | |
norm = symbolic_helper._reshape_helper(g, norm_reshaped, g.op("Shape", input)) | |
if weight is None or weight.node().mustBeNone(): | |
weight_value = torch.tensor( | |
[1.0], dtype=_type_utils.JitScalarType.from_value(input).dtype() | |
) | |
weight = g.op("Constant", value_t=weight_value) | |
if bias is None or bias.node().mustBeNone(): | |
bias_value = torch.tensor( | |
[0.0], dtype=_type_utils.JitScalarType.from_value(input).dtype() | |
) | |
bias = g.op("Constant", value_t=bias_value) | |
# Norm has shape [N, C, *] so we reshape weight and bias to [C, *] | |
axes = list(range(1, input_rank - 1)) | |
return add( | |
g, | |
mul(g, norm, symbolic_helper._unsqueeze_helper(g, weight, axes)), | |
symbolic_helper._unsqueeze_helper(g, bias, axes), | |
) | |
def _weight_norm(g: jit_utils.GraphContext, weight_v, weight_g, dim): | |
rank = symbolic_helper._get_tensor_rank(weight_v) | |
if rank is not None: | |
# W = g * ((v) / ||v||) | |
# Compute norm_except_dim for l2 norm. dim = None means over all dims | |
# torch's weight_norm module sets dim = -1 if it's None. | |
        # This conflicts with the logic for negative axes to access dims backwards
        # TODO: Might need a fix in torch's weight_norm module
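        # e.g. for a (out_features, in_features) weight with dim=0, axes becomes [1], so each
        # output row is normalized by its own L2 norm.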
axes = list(range(rank)) | |
if dim is not None: | |
if dim < -1: | |
dim += rank | |
if dim != -1: | |
axes.remove(dim) | |
norm_v = norm(g, weight_v, 2, axes, 1) | |
div = g.op("Div", weight_v, norm_v) | |
return g.op("Mul", div, weight_g) | |
if symbolic_helper.is_caffe2_aten_fallback(): | |
return g.at("_weight_norm", weight_v, weight_g, dim_i=dim) | |
raise errors.SymbolicValueError( | |
"Unsupported: ONNX export of _weight_norm for tensor of unknown rank.", | |
weight_v, | |
) | |
def dim(g: jit_utils.GraphContext, self): | |
"""Implement the dim functionality available for a pytorch tensor in ONNX""" | |
    # ONNX does not support dim directly in this opset, so we use two ops to get the info
shape = g.op("Shape", self) | |
return g.op("Size", shape) | |
def __contains_(g: jit_utils.GraphContext, self, element): | |
unpacked_list = symbolic_helper._unpack_list(self) | |
if all( | |
symbolic_helper._is_constant(x) for x in unpacked_list | |
) and symbolic_helper._is_constant(element): | |
return g.op( | |
"Constant", | |
value_t=torch.tensor( | |
symbolic_helper._node_get(element.node(), "value") | |
in (symbolic_helper._node_get(x.node(), "value") for x in unpacked_list) | |
), | |
) | |
raise errors.SymbolicValueError( | |
"Unsupported: ONNX export of __contains__ for non-constant list or element.", | |
self, | |
) | |
def __getitem_(g: jit_utils.GraphContext, self, i): | |
return select(g, self, g.op("Constant", value_t=torch.tensor([0])), i) | |
def item(g: jit_utils.GraphContext, self): | |
return self | |
def take(g: jit_utils.GraphContext, self, index): | |
self_flattened = symbolic_helper._reshape_helper( | |
g, self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)) | |
) | |
out = index_select(g, self_flattened, 0, index) | |
out = reshape_as(g, out, index) | |
return out | |
def _kl_div_log_target_impl(g: jit_utils.GraphContext, input, target): | |
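    # Pointwise term with a log-space target: exp(target) * (target - input),
    # where `input` holds log-probabilities.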
diff_ = sub(g, target, input) | |
exp_ = exp(g, target) | |
output = mul(g, exp_, diff_) | |
return output | |
def _kl_div_non_log_target_impl(g: jit_utils.GraphContext, input, target): | |
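    # Pointwise term with a probability-space target: target * (log(target) - input),
    # masked to 0 where target <= 0.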
log_ = log(g, target) | |
diff_ = sub(g, log_, input) | |
output_pos = mul(g, target, diff_) | |
zeros_ = zeros_like(g, output_pos) | |
mask_ = gt(g, target, g.op("Constant", value_t=torch.tensor(0))) | |
output = where(g, mask_, output_pos, zeros_) | |
return output | |
def kl_div(g: jit_utils.GraphContext, input, target, reduction, log_target): | |
if log_target: | |
output = _kl_div_log_target_impl(g, input, target) | |
else: | |
output = _kl_div_non_log_target_impl(g, input, target) | |
if reduction == 0: | |
return output | |
elif reduction == 1: | |
return g.op("ReduceMean", output, keepdims_i=0) | |
elif reduction == 2: | |
return symbolic_helper._reducesum_helper(g, output, keepdims_i=0) | |
else: | |
return symbolic_helper._onnx_unsupported( | |
"kl_div with reduction other than none, mean, or sum.", input | |
) | |
def mse_loss(g: jit_utils.GraphContext, input, target, reduction): | |
output = mul(g, sub(g, input, target), sub(g, input, target)) | |
if reduction == 0: | |
return output | |
elif reduction == 1: | |
return g.op("ReduceMean", output, keepdims_i=0) | |
elif reduction == 2: | |
return symbolic_helper._reducesum_helper(g, output, keepdims_i=0) | |
else: | |
return symbolic_helper._onnx_unsupported( | |
"mse_loss with reduction other than none, mean, or sum.", input | |
) | |
def as_strided(g: jit_utils.GraphContext, self, sizes, strides, offset=None): | |
sizes = symbolic_helper._maybe_get_const(sizes, "is") | |
rank = len(strides) | |
self_1d = symbolic_helper._reshape_helper( | |
g, self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)) | |
) | |
ind: Optional[torch.Tensor] | |
if not symbolic_helper._is_value(sizes): | |
ind = torch.tensor([0], dtype=torch.long) | |
for i, (size, stride) in enumerate(zip(sizes, strides)): | |
r_size = [1] * rank | |
r_size[i] = -1 | |
ind = ind + torch.arange(size).view(r_size) * stride | |
if offset: | |
ind = ind + offset | |
return g.op("Gather", self_1d, g.op("Constant", value_t=ind)) | |
else: | |
ind = None | |
for i, stride in enumerate(strides): | |
r_size = [1] * rank | |
r_size[i] = -1 | |
size = select( | |
g, | |
sizes, | |
g.op("Constant", value_t=torch.tensor([0])), | |
g.op("Constant", value_t=torch.tensor(i)), | |
) | |
tmp_ind = symbolic_helper._reshape_helper( | |
g, | |
arange(g, size, 4, None, None, None), | |
g.op("Constant", value_t=torch.tensor(r_size)), | |
) | |
tmp_ind = g.op( | |
"Mul", tmp_ind, g.op("Constant", value_t=torch.tensor([stride])) | |
) | |
if ind is None: | |
ind = tmp_ind | |
else: | |
ind = g.op("Add", ind, tmp_ind) | |
if offset: | |
ind = g.op("Add", ind, g.op("Constant", torch.tensor([offset]))) | |
return g.op("Gather", self_1d, ind) | |
def __derive_index(g: jit_utils.GraphContext, index, start, step): | |
return g.op("Add", start, g.op("Mul", index, step)) | |
# Source code for aten op can be found here: pytorch/torch/csrc/jit/runtime/register_prim_ops.cpp | |
# if (step > 0 && lo < hi) { | |
# push(stack, 1 + (hi - 1 - lo) / step); | |
# } else if (step < 0 && lo > hi) { | |
# push(stack, 1 + (lo - 1 - hi) / (0 - step)); | |
# } else { | |
# push(stack, 0); | |
# } | |
def __range_length(g: jit_utils.GraphContext, lo, hi, step): | |
sub = g.op("Sub", hi, lo) | |
div = g.op("Ceil", true_divide(g, sub, step)) | |
return g.op("Cast", div, to_i=_C_onnx.TensorProtoDataType.INT64) | |
def linear(g: jit_utils.GraphContext, input, weight, bias): | |
rank = symbolic_helper._get_tensor_rank(input) | |
weight = t(g, weight) | |
if rank == 2 and not bias.node().mustBeNone(): | |
alpha = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) | |
beta = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) | |
output = addmm(g, bias, input, weight, alpha, beta) | |
else: | |
output = matmul(g, input, weight) | |
if not bias.node().mustBeNone(): | |
output = add(g, bias, output) | |
return output | |
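# Hedged note on the two paths above: a rank-2 input with a bias is exported
# through addmm (i.e. bias + input @ weight.T as a single fused node), while any
# other rank falls back to MatMul against the transposed weight followed by an
# Add of the bias when one is present. Shape-wise, input (N, in_features) and
# weight (out_features, in_features) yield an output of (N, out_features).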
def hann_window( | |
g: jit_utils.GraphContext, | |
window_length, | |
periodic=True, | |
dtype: Optional[int] = None, | |
layout=None, | |
device=None, | |
pin_memory=None, | |
requires_grad=False, | |
): | |
if dtype is None: | |
dtype_ = torch.get_default_dtype() | |
if not dtype_ or not dtype_.is_floating_point: | |
dtype_ = torch.float | |
scalar_type = _type_utils.JitScalarType.from_dtype(dtype_) | |
else: | |
scalar_type = _type_utils.JitScalarType(dtype) | |
n_array = arange(g, window_length, 4, None, None, None) | |
output = g.op("Cast", n_array, to_i=_C_onnx.TensorProtoDataType.FLOAT) | |
output = mul( | |
g, g.op("Constant", value_t=torch.tensor(math.pi, dtype=torch.float)), output | |
) | |
if periodic is False: | |
window_length = sub( | |
g, window_length, g.op("Constant", value_t=torch.tensor(1, dtype=torch.int)) | |
) | |
output = div(g, output, window_length) | |
output = g.op( | |
"Cast", | |
square(g, sin(g, output)), | |
to_i=scalar_type.onnx_type(), | |
) | |
return output | |
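# Worked equation for the window above (illustrative): with N = window_length
# for periodic=True, or N = window_length - 1 otherwise, the graph computes
#
#     w[n] = sin(pi * n / N) ** 2 = 0.5 - 0.5 * cos(2 * pi * n / N)
#
# which is the standard Hann window; e.g. a periodic window of length 4 is
# [0.0, 0.5, 1.0, 0.5].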
def mv(g: jit_utils.GraphContext, self, vec): | |
return matmul(g, self, vec) | |
def dot(g: jit_utils.GraphContext, self, other): | |
return matmul(g, self, other) | |
def movedim(g: jit_utils.GraphContext, self, source, destination): | |
# This is a pythonic implementation mostly taken from aten/src/ATen/native/TensorShape.cpp::movedim | |
source = source.view(-1) | |
destination = destination.view(-1) | |
assert source.size() == destination.size() | |
if (source == destination).all(): | |
return self | |
self_rank = symbolic_helper._get_tensor_rank(self) | |
assert self_rank is not None | |
perm = list(range(self_rank)) | |
src_dims = perm.copy() | |
dst_dims = perm.copy() | |
for src, dst in zip(source.tolist(), destination.tolist()): | |
perm[dst] = src | |
src_dims[src] = -1 | |
dst_dims[dst] = -1 | |
src_dims = [dim for dim in src_dims if dim != -1] | |
dst_dims = [dim for dim in dst_dims if dim != -1] | |
for src, dst in zip(src_dims, dst_dims): | |
perm[dst] = src | |
return g.op("Transpose", self, perm_i=perm) | |
def fill(g: jit_utils.GraphContext, self, value): | |
scalar_type = _type_utils.JitScalarType.from_value( | |
self, _type_utils.JitScalarType.FLOAT | |
) | |
return full_like(g, self, value, scalar_type) | |
def index_add(g: jit_utils.GraphContext, self, dim, index, other, alpha=None): | |
warnings.warn( | |
"Warning: ONNX export does not support duplicated values in 'index' field, " | |
+ "this will cause the ONNX model to be incorrect." | |
) | |
# ONNX does not support "alpha" argument, unlike aten index_add | |
# See: https://github.com/pytorch/pytorch/pull/65993#issuecomment-953151102 for more context | |
if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1: | |
return symbolic_helper._unimplemented("index_add", "alpha != 1", self) | |
dim = symbolic_helper._maybe_get_const(dim, "i") | |
if dim is None: | |
raise errors.SymbolicValueError( | |
"ONNX export does NOT support exporting 'index_add_()' function with " | |
"unknown 'dim' value.", | |
self, | |
) | |
self_dim_rank = symbolic_helper._get_tensor_rank(self) | |
other_dim_rank = symbolic_helper._get_tensor_rank(other) | |
if self_dim_rank is None or other_dim_rank is None: | |
raise errors.SymbolicValueError( | |
"ONNX export does NOT support exporting 'index_add_()' function while " | |
"the rank of self tensor or tensor to be added is unknown.", | |
self, | |
) | |
if other_dim_rank != self_dim_rank: | |
delta = self_dim_rank - other_dim_rank | |
for i in range(delta): | |
other = symbolic_helper._unsqueeze_helper( | |
g, other, [symbolic_helper._get_tensor_rank(other)] | |
) | |
other_dim_size = symbolic_helper._get_tensor_dim_size(other, dim) | |
self_dim_size = symbolic_helper._get_tensor_dim_size(self, dim) | |
if (other_dim_size is not None) and (self_dim_size is not None): | |
if other_dim_size > self_dim_size: | |
raise errors.SymbolicValueError( | |
"ONNX export does not support exporting 'index_add_()' function with " | |
"duplicated values in 'index' parameter yet.", | |
self, | |
) | |
# Construct a new shape. It is almost the same as self, except that the size of the
# 'dim' dimension is 1, so that 'other' can be expanded to it as expected.
new_shape_axes = list(range(self_dim_rank)) | |
new_shape_starts = [0 for i in range(self_dim_rank)] | |
new_shape_ends = [sys.maxsize if (i != dim) else 1 for i in range(self_dim_rank)] | |
new_shape = symbolic_helper._slice_helper( | |
g, self, axes=new_shape_axes, starts=new_shape_starts, ends=new_shape_ends | |
) | |
other = expand_as(g, other, new_shape) | |
for i in range(dim): | |
index = symbolic_helper._unsqueeze_helper(g, index, [0]) | |
for i in range(self_dim_rank - dim - 1): | |
index = symbolic_helper._unsqueeze_helper( | |
g, index, [symbolic_helper._get_tensor_rank(index)] | |
) | |
return scatter_add(g, self, dim, expand_as(g, index, other), other) | |
def roll(g: jit_utils.GraphContext, self, shifts, dims): | |
assert len(shifts) == len(dims) | |
result = self | |
for i in range(len(shifts)): | |
shapes = [] | |
shape = symbolic_helper._slice_helper( | |
g, result, axes=[dims[i]], starts=[-shifts[i]], ends=[sys.maxsize] | |
) | |
shapes.append(shape) | |
shape = symbolic_helper._slice_helper( | |
g, result, axes=[dims[i]], starts=[0], ends=[-shifts[i]] | |
) | |
shapes.append(shape) | |
result = g.op("Concat", *shapes, axis_i=dims[i]) | |
return result | |
def cross(g: jit_utils.GraphContext, input, other, dim=None): | |
dim = symbolic_helper._get_dim_for_cross(input, dim) | |
# If we have two tensors such that
# A = [a, b, c] and B = [d, e, f], we roll each tensor so that,
# after the first roll,
# A' = [b, c, a], B' = [f, d, e], and we calculate (b*f, c*d, a*e)
roll_x_1 = roll(g, input, [2], [dim]) | |
roll_y_1 = roll(g, other, [1], [dim]) | |
# After second roll, | |
# A' = [c, a, b], B' = [e, f, d], so that we calculate (c*e, a*f, b*d) | |
roll_x_2 = roll(g, input, [1], [dim]) | |
roll_y_2 = roll(g, other, [2], [dim]) | |
# cross product is calculated as | |
# result = [(b*f - c*e), (c*d - a*f), (a*e - b*d)] | |
return sub(g, mul(g, roll_x_1, roll_y_1), mul(g, roll_x_2, roll_y_2)) | |
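# Worked example of the roll-based identity above (illustrative): with
# A = [1, 2, 3] and B = [4, 5, 6],
#
#     roll(A, 2) * roll(B, 1) = [2, 3, 1] * [6, 4, 5] = [12, 12, 5]
#     roll(A, 1) * roll(B, 2) = [3, 1, 2] * [5, 6, 4] = [15, 6, 8]
#
# and the difference [-3, 6, -3] equals torch.cross(A, B).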
def cdist( | |
g: jit_utils.GraphContext, | |
x1, | |
x2, | |
p=2.0, | |
compute_mode="use_mm_for_euclid_dist_if_necessary", | |
): | |
# x1.shape = (B, P, D), x2.shape = (B, R, D)
# In order to respect numpy-style broadcasting, as described in
# https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md,
# we unsqueeze both input tensors.
# The 'compute_mode' argument is currently ignored; we default to using
# matrix multiplication to calculate the Euclidean distance.
rank = symbolic_helper._get_tensor_rank(x1) | |
assert rank is not None | |
broadcasted_x1 = symbolic_helper._unsqueeze_helper(g, x1, [rank - 1]) | |
broadcasted_x2 = symbolic_helper._unsqueeze_helper(g, x2, [rank - 2]) | |
return pairwise_distance( | |
g, broadcasted_x1, broadcasted_x2, p, eps=1e-06, keepdim=False | |
) | |
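# Shape sketch for the broadcast above (illustrative): with x1 of shape
# (B, P, D) and x2 of shape (B, R, D), the two unsqueezes produce (B, P, 1, D)
# and (B, 1, R, D); pairwise_distance then broadcasts over the middle axes and
# reduces over D, yielding the expected (B, P, R) distance matrix.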
def lerp(g: jit_utils.GraphContext, self, end, weight): | |
# Use a conditional for better numerical precision. This has been discussed in
# https://github.com/pytorch/pytorch/pull/18871 | |
diff = g.op("Sub", end, self) | |
return where( | |
g, | |
g.op("Less", weight, g.op("Constant", value_t=torch.tensor(0.5))), | |
g.op("Add", self, g.op("Mul", weight, diff)), | |
g.op( | |
"Sub", | |
end, | |
g.op( | |
"Mul", | |
diff, | |
g.op("Sub", g.op("Constant", value_t=torch.tensor(1.0)), weight), | |
), | |
), | |
) | |
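# Worked equation (illustrative): both branches above compute the same value,
#
#     lerp(self, end, w) = self + w * (end - self)        # used when w < 0.5
#                        = end - (1 - w) * (end - self)   # used when w >= 0.5
#
# choosing the form anchored at the nearer endpoint for better floating-point
# accuracy; e.g. w=0.25 takes the first form and w=0.75 the second.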
def broadcast_tensors(g: jit_utils.GraphContext, self): | |
all_tensors = symbolic_helper._unpack_list(self) | |
t_with_final_shape = zeros_like(g, all_tensors[0]) | |
# The Add operator supports multidirectional broadcasting, so we leverage it here
# to infer the final shape generated by the broadcast.
for t in all_tensors: | |
t_with_final_shape = add(g, t_with_final_shape, t) | |
t_list = [expand_as(g, t, t_with_final_shape) for t in all_tensors] | |
return g.op("prim::ListConstruct", *t_list) | |
def is_pinned(g: jit_utils.GraphContext, self, device=None): | |
# Unused by ONNX. | |
return None | |
def prim_constant_split(g: jit_utils.GraphContext, self, split_size, dim): | |
size = symbolic_helper._get_tensor_dim_size(self, dim) | |
if size is None: | |
return symbolic_helper._unimplemented( | |
"prim::ConstantSplit", "unknown dimension size", self | |
) | |
splits = [split_size] * (size // split_size) | |
leftover = size % split_size | |
if leftover: | |
splits.append(leftover) | |
return g.op("Split", self, split_i=splits, axis_i=dim, outputs=len(splits)) | |
# TODO: It would be better to export this as a chunk directly, as this is | |
# less sensitive to changes in input size. | |
# TODO: Once we have proper scoping, stop reimplementing chunk, delete this | |
# method, and use the desugared version | |
def prim_constant_chunk(g: jit_utils.GraphContext, self, chunks, dim): | |
dim_size = symbolic_helper._get_tensor_dim_size(self, dim) | |
if dim_size is None: | |
return symbolic_helper._unimplemented( | |
"prim::ConstantChunk", "unknown dimension size", self | |
) | |
split_size = (dim_size + chunks - 1) // chunks | |
return prim_constant_split(g, self, split_size, dim) | |
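# Worked example for the two helpers above (illustrative): chunking a dimension
# of size 10 into chunks=3 pieces uses split_size = (10 + 3 - 1) // 3 = 4, and
# prim_constant_split then emits a single Split node with splits=[4, 4, 2] and
# three outputs.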
def prim_shape(g: jit_utils.GraphContext, self): | |
return g.op("Shape", self) | |
def prim_max(g: jit_utils.GraphContext, self, other): | |
return _op_with_optional_float_cast(g, "Max", self, other, opset_before=12) | |
def prim_min(g: jit_utils.GraphContext, self, other=None): | |
if not other: | |
if symbolic_helper._is_packed_list(self): | |
self = stack(g, self, g.op("Constant", value_t=torch.tensor([0]))) | |
return min(g, self) | |
return min(g, self, other) | |
def prim_data(g: jit_utils.GraphContext, self): | |
return self | |
def prim_layout(g: jit_utils.GraphContext, self): | |
# Always return 'torch.strided'. Other layout types are not supported by JIT 'TensorType'. | |
# Layout class defined in 'c10/core/Layout.h'. | |
return g.op("Constant", value_t=torch.tensor(0)) | |
def prim_list_construct(g: jit_utils.GraphContext, *inputs, **kwargs): | |
return None | |
def prim_list_unpack( | |
g: jit_utils.GraphContext, *inputs, **kwargs | |
) -> Optional[List[_C.Value]]: | |
if len(inputs) == 1 and inputs[0].node().kind() == "prim::ListConstruct": | |
# Cancel the previous node if it is ListConstruct by returning its inputs | |
# TODO(justinchuby): Use a public method in the helper module | |
return symbolic_helper._unpack_list(inputs[0]) | |
return None | |
def prim_tuple_construct(g: jit_utils.GraphContext, *inputs, **kwargs): | |
return None | |
def prim_uninitialized(g: jit_utils.GraphContext, *inputs, **kwargs): | |
return None | |
# This exists to refine the type of the Value:
# if x is an optional Tensor, unchecked_cast will cast
# x to Tensor, so the rest of the graph knows that x is a Tensor.
# It does nothing at runtime and is a no-op in ONNX.
def prim_unchecked_cast(g: jit_utils.GraphContext, self): | |
return self | |
def prim_dtype(g: jit_utils.GraphContext, self): | |
scalar_type = symbolic_helper._try_get_scalar_type(self) | |
if scalar_type is None: | |
scalar_type = _type_utils.JitScalarType.FLOAT | |
# This node records a torch dtype as int | |
return g.op("Constant", value_t=torch.tensor(scalar_type)) | |
def prim_tolist(g: jit_utils.GraphContext, input, dim_val, elem_ty_val): | |
"""tolist is currently supported only for 1D input tensors. | |
dim_val and elem_ty_val represent dimension and type annotations | |
that need to match dimension and type of the input tensor. | |
""" | |
dim = symbolic_helper._maybe_get_const(dim_val, "i") | |
if dim > 1: | |
return symbolic_helper._unimplemented("prim::tolist", "dim_val > 1", input) | |
return input | |
# ----------------------------------------------------------------------------- | |
# Symbolic functions that need extra context | |
# ----------------------------------------------------------------------------- | |
def prim_device(g: jit_utils.GraphContext, *inputs, **kwargs) -> None: | |
output_type = g.original_node.output().type() | |
if isinstance(output_type, _C.DeviceObjType): | |
return None | |
return symbolic_helper._unimplemented( | |
"prim::device", | |
f"output type should be 'DeviceObjType', not '{output_type.kind()}'", | |
g.original_node.output(), | |
) | |
def prim_loop(g: jit_utils.GraphContext, *inputs, **attrs) -> List[_C.Value]: | |
node = g.original_node | |
env = g.env | |
params_dict = g.params_dict | |
operator_export_type = GLOBALS.operator_export_type | |
opset_version = GLOBALS.export_onnx_opset_version | |
old_blocks = tuple(node.blocks()) | |
new_op_outputs, new_block_contexts, new_node = jit_utils.add_op_with_blocks( | |
g, "Loop", *inputs, outputs=node.outputsSize(), n_blocks=len(old_blocks) | |
) | |
for old_block, new_block_context in zip(old_blocks, new_block_contexts): | |
# Copy input metadata to subblock | |
# | |
# prim::Loop(iter, cond, input_1, ..., input_n) | |
# block0(iter, input_1, ..., input_n) | |
# | |
# For `Loop` node, copy metadata for `iter`, `input_1`, ..., `input_n`. | |
for i, b_in in enumerate(old_block.inputs()): | |
if i == 0 and i < len(inputs): | |
b_in.setType(inputs[i].type()) | |
# Optional block inputs may switch between None and not-None inside
# the loop body, so even if the loop input is not optional, the block
# input may still need to be optional.
if ( | |
i > 0 | |
and (i + 1) < len(inputs) | |
and not isinstance(b_in.type(), _C.OptionalType) | |
): | |
b_in.setType(inputs[i + 1].type()) | |
torch._C._jit_pass_onnx_block( | |
old_block, | |
new_block_context.block, | |
operator_export_type, | |
env, | |
False, | |
) | |
fixed_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node( | |
new_node, opset_version | |
) | |
# Run shape type inference for Loop after subblock is converted. | |
if GLOBALS.onnx_shape_inference: | |
torch._C._jit_pass_onnx_node_shape_type_inference( | |
new_node, params_dict, opset_version | |
) | |
return fixed_outputs | |
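# For reference (hedged summary of the ONNX spec, not derived from this file):
# the target onnx::Loop body has the form
#     body(iteration_num, condition_in, loop_carried...) ->
#         (condition_out, loop_carried..., scan_outputs...)
# The metadata copied above onto `iter` and each `input_i` keeps the block
# inputs typed consistently with these loop-carried values.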
def prim_if(g: jit_utils.GraphContext, *inputs, **attrs) -> List[_C.Value]: | |
n = g.original_node | |
block = g.block | |
env = g.env | |
params_dict = g.params_dict | |
operator_export_type = GLOBALS.operator_export_type | |
opset_version = GLOBALS.export_onnx_opset_version | |
static_if = inputs[0].node().kind() == "onnx::Constant" | |
if static_if: | |
# Fold static if | |
# | |
# The torch IR | |
# graph(%embedding_matrix.1 : Float(10, 15, strides=[15, 1], requires_grad=0, device=cpu), | |
# %input.1 : Long(6, strides=[1], requires_grad=0, device=cpu), ... | |
# %65 : Bool(requires_grad=0, device=cpu) = prim::Constant[value={0}]() | |
# %21 : Long(device=cpu) = aten::eq(%20, %64) | |
# %22 : Long(device=cpu) = prim::If(%21) | |
# block0(): | |
# %23 : Long(device=cpu) = aten::is_floating_point(%input.1) | |
# -> (%23) | |
# block1(): | |
# -> (%65) | |
# %input.53 : Tensor, %weight : Tensor = prim::If(%22) | |
# block0(): | |
# -> (%embedding_matrix.1, %input.1) | |
# block1(): | |
# -> (%input.1, %embedding_matrix.1) | |
# %26 : int[] = aten::size(%input.53) | |
# | |
# The converted ONNX graph | |
# %10 : Bool(device=cpu) = onnx::Constant[value={0}]() | |
# %14 : Bool(device=cpu) = onnx::Equal(%13, %8) | |
# %15 : Bool(requires_grad=0, device=cpu) = onnx::Constant[value={0}]() | |
# %16 : Long(1, strides=[1], device=cpu) = onnx::Shape(%input.1) | |
input_flag = symbolic_helper._node_get(inputs[0].node(), "value").tolist() | |
const_value = ( | |
all(input_flag) if isinstance(input_flag, list) else bool(input_flag) | |
) | |
block_idx = 0 if const_value else 1 | |
current_b = list(n.blocks())[block_idx] | |
env = torch._C._jit_pass_onnx_block( | |
current_b, | |
block, | |
operator_export_type, | |
env, | |
True, | |
) | |
if_output_list = list(n.outputs()) | |
current_b_list = list(current_b.outputs()) | |
final_b_list = [] | |
for idx in range(len(if_output_list)): | |
if current_b_list[idx] not in env: | |
raise errors.SymbolicValueError( | |
f"The sub block ATen output {current_b_list[idx]} is not in env.", | |
current_b_list[idx], | |
) # type:ignore[operator] | |
onnx_b = env[current_b_list[idx]] | |
final_b_list.append(onnx_b) | |
return final_b_list | |
else: | |
old_blocks = tuple(n.blocks()) | |
new_op_outputs, new_block_contexts, new_node = jit_utils.add_op_with_blocks( | |
g, "If", *inputs, outputs=n.outputsSize(), n_blocks=len(old_blocks) | |
) | |
for old_block, new_block_context in zip(old_blocks, new_block_contexts): | |
torch._C._jit_pass_onnx_block( | |
old_block, | |
new_block_context.block, | |
operator_export_type, | |
env, | |
False, | |
) | |
fixed_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node( | |
new_node, opset_version | |
) | |
# Run shape type inference for If after subblock is converted. | |
if GLOBALS.onnx_shape_inference: | |
torch._C._jit_pass_onnx_node_shape_type_inference( | |
new_node, params_dict, opset_version | |
) | |
return fixed_outputs | |
def prim_constant(g: jit_utils.GraphContext, *inputs, **attrs): | |
node = g.original_node | |
if node.mustBeNone(): | |
return None | |
# This must go before checking for string values, because some device constants | |
# have string values, but we want to keep them as unconverted Device types so | |
# that eq() can work on them. | |
if isinstance(node.output().type(), _C.DeviceObjType): | |
return None | |
if node.kindOf("value") == "t": | |
return g.op("Constant", value_t=symbolic_helper._node_get(node, "value")) | |
if node.kindOf("value") == "s": | |
return g.op("Constant", value_s=symbolic_helper._node_get(node, "value")) | |
if node.output().type().isSubtypeOf( | |
_C.ListType.ofInts() | |
) or node.output().type().isSubtypeOf(_C.ListType.ofFloats()): | |
return g.op( | |
"Constant", value_t=torch.tensor(symbolic_helper._node_get(node, "value")) | |
) | |
if node.output().type().isSubtypeOf(_C.ListType.ofStrings()): | |
str_constants = [ | |
g.op("Constant", value_s=s) | |
for s in symbolic_helper._node_get(node, "value") | |
] | |
return g.op("prim::ListConstruct", *str_constants) | |
raise errors.SymbolicValueError( | |
f"Unsupported prim::Constant kind: '{node.kindOf('value')}'. " | |
f"Please send a bug report at {_constants.PYTORCH_GITHUB_ISSUES_URL}.", | |
node.output(), | |
) | |
def prim_type(g: jit_utils.GraphContext, device_value: _C.Value, *args, **kwargs): | |
if device_value.node().kind() == "prim::device": | |
device = jit_utils.get_device_from_value(device_value.node().input()) | |
if device is not None: | |
return g.op("Constant", value_s=str(device)) | |
return symbolic_helper._unimplemented( | |
"prim::type", | |
"Device type cannot be statically determined.", | |
device_value, | |
) | |
def onnx_placeholder(g: jit_utils.GraphContext, *inputs, **attrs): | |
node = g.original_node | |
block = g.block | |
env = g.env | |
return torch._C._jit_onnx_convert_pattern_from_subblock(block, node, env) | |
def noop_complex_operators(g: jit_utils.GraphContext, input: _C.Value): | |
# ONNX does not have operators to *directly* manipulate real/imaginary components.
# However, a few torch APIs (e.g. .tolist()) use complex operations even when the input is real,
# which results in failures due to missing operators for complex numbers.
# `aten::resolve_conj` and `aten::resolve_neg` can safely be implemented as no-ops.
return input | |
def unsupported_complex_operators(g: jit_utils.GraphContext, input: _C.Value): | |
# ONNX does not have operators to *directly* manipulate real/imaginary components.
# However, a few torch APIs (e.g. .tolist()) use complex operations even when the input is real,
# which results in failures due to missing operators for complex numbers.
# `aten::_conj` and `aten::conj_physical`, on the other hand, are rejected when the input is complex;
if symbolic_helper.is_complex_value(input): | |
# FIXME(justinchuby): report correct name for symbolic being executed | |
return symbolic_helper._onnx_unsupported( | |
"aten::_conj, aten::conj_physical", | |
input, | |
) | |
# for real inputs only, they can safely be implemented as no-ops.
return noop_complex_operators(g, input) | |
def logit(g: jit_utils.GraphContext, self: torch._C.Value, eps: torch._C.Value): | |
one = g.op("Constant", value_t=torch.tensor(1.0)) | |
if not symbolic_helper._is_none(eps): | |
eps = g.op( | |
"Cast", eps, to_i=_type_utils.JitScalarType.from_value(self).onnx_type() | |
) | |
one_sub_eps = g.op("Sub", one, eps) | |
self_less_equal_one_sub_eps = g.op("Greater", one_sub_eps, self) | |
temporary_self = g.op("Where", self_less_equal_one_sub_eps, self, one_sub_eps) | |
temporary_self_less_eps = g.op("Less", temporary_self, eps) | |
z = g.op("Where", temporary_self_less_eps, eps, temporary_self) | |
else: | |
z = self | |
sub = g.op("Sub", one, z) | |
div = g.op("Div", z, sub) | |
return g.op("Log", div) | |