"""This file exports ONNX ops for opset 17.

Note [ONNX Operators that are added/updated in opset 17]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-17-of-the-default-onnx-operator-set
New operators:
    BlackmanWindow
    DFT
    HammingWindow
    HannWindow
    LayerNormalization
    MelWeightMatrix
    STFT
    SequenceMap
"""
import functools
from typing import Optional, Sequence

import torch
from torch import _C
from torch.onnx import _type_utils, errors, symbolic_helper
from torch.onnx._internal import _beartype, jit_utils, registration

# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md

# Public symbolic functions exported by this opset module.
__all__ = ["layer_norm", "stft"]

# Every symbolic in this file registers against ONNX opset 17.
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=17)
def layer_norm(
    g: jit_utils.GraphContext,
    input: _C.Value,
    normalized_shape: Sequence[int],
    weight: _C.Value,
    bias: _C.Value,
    eps: float,
    cudnn_enable: bool,
):
    """Export ``aten::layer_norm`` as the ONNX ``LayerNormalization`` op.

    LayerNormalization normalizes over the last ``len(normalized_shape)``
    dimensions, so the ONNX ``axis`` attribute is the negative length of
    ``normalized_shape``. ``cudnn_enable`` has no ONNX counterpart and is
    ignored.
    """
    # First normalized dimension, counted from the end of the input shape.
    norm_axis = -len(normalized_shape)

    # Fall back to FLOAT when the input's scalar type cannot be inferred.
    scalar_type = _type_utils.JitScalarType.from_value(
        input, _type_utils.JitScalarType.FLOAT
    )

    # ONNX LayerNormalization takes explicit scale/bias inputs, so supply
    # identity constants (ones / zeros) when the caller omitted them.
    if symbolic_helper._is_none(weight):
        weight = g.op(
            "Constant",
            value_t=torch.ones(normalized_shape, dtype=scalar_type.dtype()),
        )
    if symbolic_helper._is_none(bias):
        bias = g.op(
            "Constant",
            value_t=torch.zeros(normalized_shape, dtype=scalar_type.dtype()),
        )

    return g.op(
        "LayerNormalization",
        input,
        weight,
        bias,
        epsilon_f=eps,
        axis_i=norm_axis,
    )
def _compute_edge_sizes(n_fft, window_size): | |
"""Helper function to compute the sizes of the edges (left and right) | |
of a given window centered within an FFT size.""" | |
left = (n_fft - window_size) // 2 | |
right = n_fft - left - window_size | |
return left, right | |
def stft(
    g: jit_utils.GraphContext,
    input: _C.Value,
    n_fft: int,
    hop_length: Optional[int] = None,
    win_length: Optional[int] = None,
    window: Optional[_C.Value] = None,
    normalized: bool = False,
    onesided: Optional[bool] = True,
    return_complex: Optional[bool] = False,
) -> _C.Value:
    """Associates `torch.stft` with the `STFT` ONNX operator.

    Note that torch.stft calls _VF.stft, without centering or padding options.
    Hence, this function does not contain these two arguments.
    See torch.stft source code for more info.

    Args:
        g: Graph to write the ONNX representation into
        input: Input tensor for the transformation
        n_fft: FFT size
        hop_length: Size of the hop. Defaults to `n_fft // 4`
        win_length: Size of the analysis window. Defaults to `n_fft`
        window: Analysis window. Defaults to a window of all ones
        normalized: Whether to return a normalized STFT
        onesided: Whether to return only half (+1) of the results, given the
            symmetry of the STFT
        return_complex: Whether to return the complex value (Note: Must be
            `False` or `None`)

    Returns:
        op: Operator for torch.stft associated with STFT (ONNX)

    Raises:
        errors.SymbolicValueError: If `return_complex` is truthy, the input
            rank is known and greater than 2, or `win_length > n_fft`.
    """
    # Checks
    if return_complex:
        raise errors.SymbolicValueError(
            msg="STFT does not currently support complex types", value=input
        )

    # Get STFT sizes. torch.stft defaults hop_length to n_fft // 4.
    frame_step_value = hop_length if hop_length is not None else n_fft // 4
    frame_step_const = g.op(
        "Constant", value_t=torch.tensor(frame_step_value, dtype=torch.int64)
    )
    frame_length_const = g.op(
        "Constant", value_t=torch.tensor(n_fft, dtype=torch.int64)
    )

    # Pre-process input if needed
    signal = input
    signal_rank = symbolic_helper._get_tensor_rank(signal)
    if signal_rank == 1:
        # ONNX STFT expects a batched signal; add a leading batch dimension
        # (removed again after the transform below).
        signal = g.op(
            "Unsqueeze",
            signal,
            g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)),
        )
    elif signal_rank is not None and signal_rank > 2:
        # Bug fix: the original compared `signal_rank > 2` unguarded, which
        # raises TypeError when the rank is unknown (None). An unknown rank
        # now falls through instead of crashing the exporter.
        raise errors.SymbolicValueError(
            msg="STFT can only take inputs of 1 [signal] or 2 [batch, signal] dimensions. "
            f"Current rank of signal is {signal_rank}, please reduce it.",
            value=input,
        )

    # Get window and make sure it's the same size as `win_length` or `n_fft`
    n_win = symbolic_helper._get_tensor_dim_size(window, dim=0)
    if n_win is not None:
        win_length_default = win_length if win_length else n_fft
        # Bug fix: the original assert had a stray trailing comma, which made
        # the failure message a one-element tuple instead of a string.
        # NOTE(review): this validation still disappears under `python -O`.
        assert n_win == win_length_default, (
            "Analysis window size must equal `win_length` or `n_fft`. "
            f"Please, set `win_length` or `n_fft` to match `window` size ({n_win})"
        )

        # Center window around zeros if needed (required by ONNX's STFT)
        if n_win < n_fft:
            left, right = _compute_edge_sizes(n_fft, n_win)
            left_win = g.op("Constant", value_t=torch.zeros(left))
            right_win = g.op("Constant", value_t=torch.zeros(right))
            window = g.op("Concat", left_win, window, right_win, axis_i=0)

    # Create window, if needed
    if symbolic_helper._is_none(window):
        if win_length:
            if win_length > n_fft:
                raise errors.SymbolicValueError(
                    msg="The analysis window can't be longer than the size of the FFT. "
                    f"Please set `win_length` ({win_length}) to `n_fft` ({n_fft}) or less.",
                    value=input,
                )

            # Center a rectangular window of `win_length` ones inside the
            # `n_fft`-sized frame, zero-padded on both edges.
            left, right = _compute_edge_sizes(n_fft, win_length)
            torch_window = torch.hstack(
                (torch.zeros(left), torch.ones(win_length), torch.zeros(right))
            )
        else:
            # Rectangle window
            torch_window = torch.ones(n_fft)
        # Internal invariant: the constructed window always fills the frame.
        assert torch_window.shape[0] == n_fft
        window = g.op("Constant", value_t=torch_window)

    # ONNX STFT requires the window element type to match the signal's.
    window = g.op(
        "Cast", window, to_i=_type_utils.JitScalarType.from_value(signal).onnx_type()
    )

    # Run STFT. `onesided=None` means the torch default, which is onesided.
    result = g.op(
        "STFT",
        signal,
        frame_step_const,
        window,
        frame_length_const,
        onesided_i=1 if onesided is None or onesided else 0,
    )

    # Transpose to mimic torch.stft's behavior
    result = g.op("Transpose", result, perm_i=[0, 2, 1, 3])

    # Remove batch dimension, if needed
    if signal_rank == 1:
        result = g.op(
            "Squeeze",
            result,
            g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)),
        )

    # Normalize, if needed (torch.stft's normalized=True divides by sqrt(n_fft))
    if normalized:
        sqrt_nfft = torch.sqrt(torch.tensor(n_fft, dtype=signal.type().dtype()))
        result = g.op("Div", result, g.op("Constant", value_t=sqrt_nfft))

    return result