"""This file exports ONNX ops for opset 14. Note [ONNX operators that are added/updated in opset 14] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ New operators: HardSwish, Trilu Updated operators: Reshape Add, Sub, Mul, Div GRU, LSTM, RNN BatchNorm, Cumsum, Relu """ # EDITING THIS FILE? READ THIS FIRST! # see Note [Edit Symbolic Files] in README.md from __future__ import annotations import functools from typing import Optional import torch from torch.onnx import _constants, _type_utils, symbolic_helper from torch.onnx._globals import GLOBALS from torch.onnx._internal import _beartype, jit_utils, registration __all__ = [ "hardswish", "tril", "triu", "reshape", "batch_norm", "quantized_hardswish", "scaled_dot_product_attention", ] _onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=14) @_onnx_symbolic("aten::hardswish") @symbolic_helper.parse_args("v") @_beartype.beartype def hardswish(g: jit_utils.GraphContext, self): return g.op("HardSwish", self) @_onnx_symbolic("aten::tril") @_beartype.beartype def tril(g: jit_utils.GraphContext, self, diagonal, out=None): return g.op("Trilu", self, diagonal, upper_i=0) @_onnx_symbolic("aten::triu") @_beartype.beartype def triu(g: jit_utils.GraphContext, self, diagonal, out=None): return g.op("Trilu", self, diagonal, upper_i=1) @_onnx_symbolic("aten::reshape") @symbolic_helper.quantized_args(True) @symbolic_helper.parse_args("v", "v") @_beartype.beartype def reshape(g: jit_utils.GraphContext, self, shape): # NOTE: Due to bug in ORT https://github.com/microsoft/onnxruntime/issues/10664 # Reshape export cannot utilize the new allowzero attribute introduced in opset 14. return symbolic_helper._reshape_helper(g, self, shape, allowzero=0) @_onnx_symbolic("aten::batch_norm") @symbolic_helper.parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i") @_beartype.beartype def batch_norm( g: jit_utils.GraphContext, input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled, ): if ( torch.is_autocast_enabled() and not symbolic_helper.args_have_same_dtype( [input, weight, bias, running_mean, running_var] ) and GLOBALS.export_onnx_opset_version < 15 ): return symbolic_helper._onnx_opset_unsupported_detailed( "BatchNormalization", 14, 15, "All input tensors must have the same `dtype`." " Turn off Autocast or export using opset version 15.", input, ) symbolic_helper.check_training_mode(training, "batch_norm") weight, bias, running_mean, running_var = symbolic_helper._batchnorm_helper( g, input, weight, bias, running_mean, running_var ) out = g.op( "BatchNormalization", input, weight, bias, running_mean, running_var, epsilon_f=eps, momentum_f=1 - momentum, training_mode_i=0 if not training else 1, outputs=1 if not training else 3, ) if not training: return out else: res, new_running_mean, new_running_var = out new_running_mean.setType(running_mean.type()) new_running_var.setType(running_var.type()) return res @_onnx_symbolic("quantized::hardswish") @_beartype.beartype def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point): x, _, _, _ = symbolic_helper.dequantize_helper(g, x) output = hardswish(g, x) return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) # Ported from # https://github.com/microsoft/onnxscript/blob/6b1b81700b4523f31d8c6d3321e5d8ef5d42b764/onnxscript/function_libs/torch_aten/ops/nn.py#L1504 # aten_scaled_dot_product_attention # NOTE: Need op.Trilu @_onnx_symbolic("aten::scaled_dot_product_attention") @symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v") @_beartype.beartype def scaled_dot_product_attention( g: jit_utils.GraphContext, query: torch._C.Value, key: torch._C.Value, value: torch._C.Value, attn_mask: Optional[torch._C.Value] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[torch._C.Value] = None, ): assert (not is_causal) or ( is_causal and symbolic_helper._is_none(attn_mask) ), "is_causal and attn_mask cannot be set at the same time" scale = symbolic_helper._maybe_get_const(scale, "f") if symbolic_helper._is_none(scale): scale = _attention_scale(g, query) if is_causal: attn_mask = _causal_attention_mask(g, query, key) # Swap the last two axes of key # NOTE: onnx-script has different logic here, because the attribute perms in # transpose needs list of ints key_shape_builtin = symbolic_helper._get_tensor_rank(key) key_transposed_axes = list(range(key_shape_builtin)) key_transposed_axes[-1], key_transposed_axes[-2] = ( key_transposed_axes[-2], key_transposed_axes[-1], ) key_transposed = g.op("Transpose", key, perm_i=key_transposed_axes) # https://github.com/pytorch/pytorch/blob/12da0c70378b5be9135c6fda62a9863bce4a4818/aten/src/ATen/native/transformers/attention.cpp#L653 # Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for math query_scaled = g.op("Mul", query, g.op("Sqrt", scale)) key_transposed_scaled = g.op("Mul", key_transposed, g.op("Sqrt", scale)) mul_qk = g.op("MatMul", query_scaled, key_transposed_scaled) if symbolic_helper._is_none(attn_mask): mul_qk_add = mul_qk elif ( _type_utils.JitScalarType.from_value(attn_mask) == _type_utils.JitScalarType.BOOL ): # Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf')) const_zero = g.op("Constant", value_t=torch.tensor([0.0])) const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")])) attn_mask = g.op("Where", attn_mask, const_zero, const_neg_inf) mul_qk_add = g.op("Add", mul_qk, attn_mask) elif _type_utils.JitScalarType.from_value(attn_mask) in ( _type_utils.JitScalarType.FLOAT, _type_utils.JitScalarType.HALF, _type_utils.JitScalarType.BFLOAT16, ): mul_qk_add = g.op("Add", mul_qk, attn_mask) else: raise ValueError( f"Unsupported type for attn_mask: {_type_utils.JitScalarType.from_value(attn_mask)}" ) attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1) if dropout_p != 0: attn_weight = g.op( "Dropout", attn_weight, g.op("Constant", value_t=torch.tensor(dropout_p, dtype=torch.float)), ) return g.op("MatMul", attn_weight, value) @_beartype.beartype def _attention_scale( g: jit_utils.GraphContext, query: torch._C.Value ) -> torch._C.Value: """Calculate the scale factor for the attention result. Args: query: Tensor of shape [..., L, E] Returns: Scalar scale factor := 1 / math.sqrt(query.size(-1)) """ query_shape = g.op("Shape", query) query_shape_last = g.op( "Slice", query_shape, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)), g.op( "Constant", value_t=torch.tensor([_constants.INT64_MAX], dtype=torch.int64) ), ) embedding_size = g.op( "Cast", query_shape_last, to_i=_type_utils.JitScalarType.from_value(query).onnx_type(), ) const_one = g.op("Constant", value_t=torch.tensor([1.0], dtype=torch.float)) scale = g.op("Div", const_one, g.op("Sqrt", embedding_size)) # Add a Cast to convert the scale back to original type scale = g.op( "Cast", scale, to_i=_type_utils.JitScalarType.from_value(query).onnx_type(), ) return scale @_beartype.beartype def _causal_attention_mask( g: jit_utils.GraphContext, query: torch._C.Value, key: torch._C.Value ) -> torch._C.Value: """Create a causal mask for the given query and key tensors. Equivalent to:: mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) attn_mask = torch.zeros(L, S, dtype=torch.float) attn_mask = attn_mask.masked_fill(not mask, -float('inf')) Args: query: Tensor of shape [..., L, E] key: Tensor of shape [..., S, E] Returns: Tensor of shape [L, S] """ query_shape = g.op("Shape", query) key_shape = g.op("Shape", key) last_idx = g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)) second_last_idx = g.op("Constant", value_t=torch.tensor([-2], dtype=torch.int64)) target_length = g.op("Slice", query_shape, second_last_idx, last_idx) source_length = g.op("Slice", key_shape, second_last_idx, last_idx) # attn_mask = torch.ones(L, S) := { size = g.op("Concat", target_length, source_length, axis_i=0) const_one = g.op("Constant", value_t=torch.tensor([1.0])) attn_mask = g.op("Expand", const_one, size) # } attn_mask = g.op("Trilu", attn_mask, upper_i=0) # The causal mask has 0s in the lower triangle and -inf in the upper triangle. const_zero = g.op("Constant", value_t=torch.tensor([0.0])) const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")])) attn_mask = g.op( "Where", g.op("Equal", attn_mask, const_zero), const_neg_inf, const_zero ) return attn_mask